From d16f9df0090c6fa75380510df3e899f37a28ae86 Mon Sep 17 00:00:00 2001 From: Charles Otto Date: Sat, 27 Jul 2013 17:38:17 -0400 Subject: [PATCH] Changes to training subsampling --- openbr/openbr_plugin.cpp | 6 ------ openbr/openbr_plugin.h | 6 ------ openbr/plugins/algorithms.cpp | 14 +++++++------- openbr/plugins/independent.cpp | 66 ++++++++++++++++++++++++++++++++++++++++++++++-------------------- openbr/plugins/openbr_internal.h | 4 ---- 5 files changed, 53 insertions(+), 43 deletions(-) diff --git a/openbr/openbr_plugin.cpp b/openbr/openbr_plugin.cpp index 1a6494f..d681843 100644 --- a/openbr/openbr_plugin.cpp +++ b/openbr/openbr_plugin.cpp @@ -1083,9 +1083,6 @@ Transform::Transform(bool _independent, bool _trainable) { independent = _independent; trainable = _trainable; - classes = std::numeric_limits::max(); - instances = std::numeric_limits::max(); - fraction = 1; } Transform *Transform::make(QString str, QObject *parent) @@ -1141,9 +1138,6 @@ Transform *Transform::make(QString str, QObject *parent) Transform *Transform::clone() const { Transform *clone = Factory::make(file.flat()); - clone->classes = classes; - clone->instances = instances; - clone->fraction = fraction; return clone; } diff --git a/openbr/openbr_plugin.h b/openbr/openbr_plugin.h index 7b27901..6820818 100644 --- a/openbr/openbr_plugin.h +++ b/openbr/openbr_plugin.h @@ -1068,12 +1068,6 @@ class BR_EXPORT Transform : public Object Q_OBJECT public: - Q_PROPERTY(int classes READ get_classes WRITE set_classes RESET reset_classes STORED false) - Q_PROPERTY(int instances READ get_instances WRITE set_instances RESET reset_instances STORED false) - Q_PROPERTY(float fraction READ get_fraction WRITE set_fraction RESET reset_fraction STORED false) - BR_PROPERTY(int, classes, std::numeric_limits::max()) - BR_PROPERTY(int, instances, std::numeric_limits::max()) - BR_PROPERTY(float, fraction, 1) bool independent, trainable; virtual ~Transform() {} diff --git a/openbr/plugins/algorithms.cpp b/openbr/plugins/algorithms.cpp index f41b96e..1a3d8a8 100644 --- a/openbr/plugins/algorithms.cpp +++ b/openbr/plugins/algorithms.cpp @@ -48,7 +48,7 @@ class AlgorithmsInitializer : public Initializer // Video Globals->abbreviations.insert("DisplayVideo", "Stream([Show(false)+Discard])"); Globals->abbreviations.insert("PerFrameDetection", "Stream([SaveMat(original)+Cvt(Gray)+Cascade(FrontalFace)+ASEFEyes+RestoreMat(original)+Draw(inPlace=true),Show(false)+Discard])"); - Globals->abbreviations.insert("AgeGenderDemo", "Stream([SaveMat(original)+Cvt(Gray)+Cascade(FrontalFace)+Expand+++(+Rename(Subject,Age)+Discard)/(+Rename(Subject,Gender)+Discard)+RestoreMat(original)+Draw(inPlace=true)+DrawPropertiesPoint([Age,Gender],Affine_0,inPlace=true)+SaveMat(original)+Discard+Contract,RestoreMat(original)+FPSCalc+Show(false,[AvgFPS,Age,Gender])+Discard])"); + Globals->abbreviations.insert("AgeGenderDemo", "Stream([SaveMat(original)+Cvt(Gray)+Cascade(FrontalFace)+Expand+++/+Discard+RestoreMat(original)+Draw(inPlace=true)+DrawPropertiesPoint([Age,Gender],Affine_0,inPlace=true)+SaveMat(original)+Discard+Contract,RestoreMat(original)+FPSCalc+Show(false,[AvgFPS,Age,Gender])+Discard])"); // Generic Image Processing Globals->abbreviations.insert("SIFT", "Open+KeyPointDetector(SIFT)+KeyPointDescriptor(SIFT):KeyPointMatcher(BruteForce)"); @@ -74,14 +74,14 @@ class AlgorithmsInitializer : public Initializer Globals->abbreviations.insert("FaceDetection", "(Open+Cvt(Gray)+Cascade(FrontalFace))"); Globals->abbreviations.insert("DenseLBP", "(Blur(1.1)+Gamma(0.2)+DoG(1,2)+ContrastEq(0.1,10)+LBP(1,2)+RectRegions(8,8,6,6)+Hist(59))"); Globals->abbreviations.insert("DenseSIFT", "(Grid(10,10)+SIFTDescriptor(12)+ByRow)"); - Globals->abbreviations.insert("FaceRecognitionRegistration", "(ASEFEyes+Affine(88,88,0.25,0.35)+FTE(DFFS,instances=1))"); - Globals->abbreviations.insert("FaceRecognitionExtraction", "(Mask+DenseSIFT/DenseLBP+PCA(0.95,instances=1)+Normalize(L2)+Cat)"); - Globals->abbreviations.insert("FaceRecognitionEmbedding", "(Dup(12)+RndSubspace(0.05,1)+LDA(0.98,instances=-2)+Cat+PCA(768,instances=1))"); + Globals->abbreviations.insert("FaceRecognitionRegistration", "(ASEFEyes+Affine(88,88,0.25,0.35)+DownsampleTraining(FTE(DFFS),instances=1))"); + Globals->abbreviations.insert("FaceRecognitionExtraction", "(Mask+DenseSIFT/DenseLBP+DownsampleTraining(PCA(0.95),instances=1)+Normalize(L2)+Cat)"); + Globals->abbreviations.insert("FaceRecognitionEmbedding", "(Dup(12)+RndSubspace(0.05,1)+DownsampleTraining(LDA(0.98),instances=-2)+Cat+DownsampleTraining(PCA(768),instances=1))"); Globals->abbreviations.insert("FaceRecognitionQuantization", "(Normalize(L1)+Quantize)"); Globals->abbreviations.insert("FaceClassificationRegistration", "(ASEFEyes+Affine(56,72,0.33,0.45)+FTE(DFFS))"); - Globals->abbreviations.insert("FaceClassificationExtraction", "((Grid(7,7)+SIFTDescriptor(8)+ByRow)/DenseLBP+PCA(0.95,instances=-1)+Cat)"); - Globals->abbreviations.insert("AgeRegressor", "Center(Range,instances=-1)+SVM(RBF,EPS_SVR,inputVariable=Age,instances=100)"); - Globals->abbreviations.insert("GenderClassifier", "Center(Range,instances=-1)+SVM(RBF,C_SVC,inputVariable=Gender,instances=4000)"); + Globals->abbreviations.insert("FaceClassificationExtraction", "((Grid(7,7)+SIFTDescriptor(8)+ByRow)/DenseLBP+DownsampleTraining(PCA(0.95),instances=-1, inputVariable=Gender)+Cat)"); + Globals->abbreviations.insert("AgeRegressor", "DownsampleTraining(Center(Range),instances=-1)+DownsampleTraining(SVM(RBF,EPS_SVR,inputVariable=Age),instances=100, inputVariable=Age)"); + Globals->abbreviations.insert("GenderClassifier", "DownsampleTraining(Center(Range),instances=-1, inputVariable=Gender)+DownsampleTraining(SVM(RBF,C_SVC,inputVariable=Gender),instances=4000, inputVariable=Gender)"); Globals->abbreviations.insert("UCharL1", "Unit(ByteL1)"); } }; diff --git a/openbr/plugins/independent.cpp b/openbr/plugins/independent.cpp index 00770bf..fd650dc 100644 --- a/openbr/plugins/independent.cpp +++ b/openbr/plugins/independent.cpp @@ -9,16 +9,16 @@ using namespace cv; namespace br { -static TemplateList Downsample(const TemplateList &templates, const Transform *transform, const QString & inputVariable) +static TemplateList Downsample(const TemplateList &templates, int classes, int instances, float fraction, const QString & inputVariable) { // Return early when no downsampling is required - if ((transform->classes == std::numeric_limits::max()) && - (transform->instances == std::numeric_limits::max()) && - (transform->fraction >= 1)) + if ((classes == std::numeric_limits::max()) && + (instances == std::numeric_limits::max()) && + (fraction >= 1)) return templates; - const bool atLeast = transform->instances < 0; - const int instances = abs(transform->instances); + const bool atLeast = instances < 0; + instances = abs(instances); QList allLabels = templates.get(inputVariable); QList uniqueLabels = allLabels.toSet().toList(); @@ -26,20 +26,20 @@ static TemplateList Downsample(const TemplateList &templates, const Transform *t QMap counts = templates.countValues(inputVariable, instances != std::numeric_limits::max()); - if ((instances != std::numeric_limits::max()) && (transform->classes != std::numeric_limits::max())) + if ((instances != std::numeric_limits::max()) && (classes != std::numeric_limits::max())) foreach (const QString & label, counts.keys()) if (counts[label] < instances) counts.remove(label); uniqueLabels = counts.keys(); - if ((transform->classes != std::numeric_limits::max()) && (uniqueLabels.size() < transform->classes)) - qWarning("Downsample requested %d classes but only %d are available.", transform->classes, uniqueLabels.size()); + if ((classes != std::numeric_limits::max()) && (uniqueLabels.size() < classes)) + qWarning("Downsample requested %d classes but only %d are available.", classes, uniqueLabels.size()); Common::seedRNG(); QList selectedLabels = uniqueLabels; - if (transform->classes < uniqueLabels.size()) { + if (classes < uniqueLabels.size()) { std::random_shuffle(selectedLabels.begin(), selectedLabels.end()); - selectedLabels = selectedLabels.mid(0, transform->classes); + selectedLabels = selectedLabels.mid(0, classes); } TemplateList downsample; @@ -56,14 +56,45 @@ static TemplateList Downsample(const TemplateList &templates, const Transform *t downsample.append(templates.value(indices[j])); } - if (transform->fraction < 1) { + if (fraction < 1) { std::random_shuffle(downsample.begin(), downsample.end()); - downsample = downsample.mid(0, downsample.size()*transform->fraction); + downsample = downsample.mid(0, downsample.size()*fraction); } return downsample; } +class DownsampleTrainingTransform : public Transform +{ + Q_OBJECT + Q_PROPERTY(br::Transform* transform READ get_transform WRITE set_transform RESET reset_transform STORED true) + Q_PROPERTY(int classes READ get_classes WRITE set_classes RESET reset_classes STORED false) + Q_PROPERTY(int instances READ get_instances WRITE set_instances RESET reset_instances STORED false) + Q_PROPERTY(float fraction READ get_fraction WRITE set_fraction RESET reset_fraction STORED false) + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) + BR_PROPERTY(br::Transform*, transform, NULL) + BR_PROPERTY(int, classes, std::numeric_limits::max()) + BR_PROPERTY(int, instances, std::numeric_limits::max()) + BR_PROPERTY(float, fraction, 1) + BR_PROPERTY(QString, inputVariable, "Label") + + void project(const Template & src, Template & dst) const + { + transform->project(src,dst); + } + + + void train(const TemplateList &data) + { + if (!transform || !transform->trainable) + return; + + TemplateList downsampled = Downsample(data, classes, instances, fraction, inputVariable); + transform->train(downsampled); + } +}; +BR_REGISTER(Transform, DownsampleTrainingTransform) + /*! * \ingroup transforms * \brief Clones the transform so that it can be applied independently. @@ -74,9 +105,7 @@ class IndependentTransform : public MetaTransform { Q_OBJECT Q_PROPERTY(br::Transform* transform READ get_transform WRITE set_transform RESET reset_transform STORED false) - Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) BR_PROPERTY(br::Transform*, transform, NULL) - BR_PROPERTY(QString, inputVariable, "Label") QList transforms; @@ -124,13 +153,10 @@ class IndependentTransform : public MetaTransform while (transforms.size() < templatesList.size()) transforms.append(transform->clone()); - for (int i=0; i futures; for (int i=0; ifile = this->file; - output->classes = classes; - output->instances = instances; - output->fraction = fraction; - output->init(); return output; -- libgit2 0.21.4