Commit d16f9df0090c6fa75380510df3e899f37a28ae86

Authored by Charles Otto
1 parent 63f5d257

Changes to training subsampling

Remove variables related to subsampling training data from Transform, move them
to a new transform called DownsampleTraining, and perform subsampling in that
transform, rather than as part of IndependentTransform (which depended on the
subsampling variables in Transform).

The prior approach was incompatible with using explicit class variable names,
rather than assuming a fixed label variable. This is because the actual
downsampling was carried out as part of IndependentTransform, which is not a
visible part of the algorithm grammar. Removing this logic from
Independent/Transform is generally desirable if only to avoid cluttering
Transform with member variables that are only used some of the time, and used
in a (fairly) restrictive way.

The current approach of makring DownsampleTraining as an indepent transform
is still limited since the downsample logic (still) cannot be applied without
the split used in Independent, and also DownsampleTraining can only be
implemented as a wrapper for another transform (because the argument to train
is const, and performing the donwsample logic in project would also impact the
testing case (since we have no way to exclude a transform used in training from
use in testing, for a fixed algorithm)).
openbr/openbr_plugin.cpp
@@ -1083,9 +1083,6 @@ Transform::Transform(bool _independent, bool _trainable) @@ -1083,9 +1083,6 @@ Transform::Transform(bool _independent, bool _trainable)
1083 { 1083 {
1084 independent = _independent; 1084 independent = _independent;
1085 trainable = _trainable; 1085 trainable = _trainable;
1086 - classes = std::numeric_limits<int>::max();  
1087 - instances = std::numeric_limits<int>::max();  
1088 - fraction = 1;  
1089 } 1086 }
1090 1087
1091 Transform *Transform::make(QString str, QObject *parent) 1088 Transform *Transform::make(QString str, QObject *parent)
@@ -1141,9 +1138,6 @@ Transform *Transform::make(QString str, QObject *parent) @@ -1141,9 +1138,6 @@ Transform *Transform::make(QString str, QObject *parent)
1141 Transform *Transform::clone() const 1138 Transform *Transform::clone() const
1142 { 1139 {
1143 Transform *clone = Factory<Transform>::make(file.flat()); 1140 Transform *clone = Factory<Transform>::make(file.flat());
1144 - clone->classes = classes;  
1145 - clone->instances = instances;  
1146 - clone->fraction = fraction;  
1147 return clone; 1141 return clone;
1148 } 1142 }
1149 1143
openbr/openbr_plugin.h
@@ -1068,12 +1068,6 @@ class BR_EXPORT Transform : public Object @@ -1068,12 +1068,6 @@ class BR_EXPORT Transform : public Object
1068 Q_OBJECT 1068 Q_OBJECT
1069 1069
1070 public: 1070 public:
1071 - Q_PROPERTY(int classes READ get_classes WRITE set_classes RESET reset_classes STORED false)  
1072 - Q_PROPERTY(int instances READ get_instances WRITE set_instances RESET reset_instances STORED false)  
1073 - Q_PROPERTY(float fraction READ get_fraction WRITE set_fraction RESET reset_fraction STORED false)  
1074 - BR_PROPERTY(int, classes, std::numeric_limits<int>::max())  
1075 - BR_PROPERTY(int, instances, std::numeric_limits<int>::max())  
1076 - BR_PROPERTY(float, fraction, 1)  
1077 bool independent, trainable; 1071 bool independent, trainable;
1078 1072
1079 virtual ~Transform() {} 1073 virtual ~Transform() {}
openbr/plugins/algorithms.cpp
@@ -48,7 +48,7 @@ class AlgorithmsInitializer : public Initializer @@ -48,7 +48,7 @@ class AlgorithmsInitializer : public Initializer
48 // Video 48 // Video
49 Globals->abbreviations.insert("DisplayVideo", "Stream([Show(false)+Discard])"); 49 Globals->abbreviations.insert("DisplayVideo", "Stream([Show(false)+Discard])");
50 Globals->abbreviations.insert("PerFrameDetection", "Stream([SaveMat(original)+Cvt(Gray)+Cascade(FrontalFace)+ASEFEyes+RestoreMat(original)+Draw(inPlace=true),Show(false)+Discard])"); 50 Globals->abbreviations.insert("PerFrameDetection", "Stream([SaveMat(original)+Cvt(Gray)+Cascade(FrontalFace)+ASEFEyes+RestoreMat(original)+Draw(inPlace=true),Show(false)+Discard])");
51 - Globals->abbreviations.insert("AgeGenderDemo", "Stream([SaveMat(original)+Cvt(Gray)+Cascade(FrontalFace)+Expand+<FaceClassificationRegistration>+<FaceClassificationExtraction>+(<AgeRegressor>+Rename(Subject,Age)+Discard)/(<GenderClassifier>+Rename(Subject,Gender)+Discard)+RestoreMat(original)+Draw(inPlace=true)+DrawPropertiesPoint([Age,Gender],Affine_0,inPlace=true)+SaveMat(original)+Discard+Contract,RestoreMat(original)+FPSCalc+Show(false,[AvgFPS,Age,Gender])+Discard])"); 51 + Globals->abbreviations.insert("AgeGenderDemo", "Stream([SaveMat(original)+Cvt(Gray)+Cascade(FrontalFace)+Expand+<FaceClassificationRegistration>+<FaceClassificationExtraction>+<AgeRegressor>/<GenderClassifier>+Discard+RestoreMat(original)+Draw(inPlace=true)+DrawPropertiesPoint([Age,Gender],Affine_0,inPlace=true)+SaveMat(original)+Discard+Contract,RestoreMat(original)+FPSCalc+Show(false,[AvgFPS,Age,Gender])+Discard])");
52 52
53 // Generic Image Processing 53 // Generic Image Processing
54 Globals->abbreviations.insert("SIFT", "Open+KeyPointDetector(SIFT)+KeyPointDescriptor(SIFT):KeyPointMatcher(BruteForce)"); 54 Globals->abbreviations.insert("SIFT", "Open+KeyPointDetector(SIFT)+KeyPointDescriptor(SIFT):KeyPointMatcher(BruteForce)");
@@ -74,14 +74,14 @@ class AlgorithmsInitializer : public Initializer @@ -74,14 +74,14 @@ class AlgorithmsInitializer : public Initializer
74 Globals->abbreviations.insert("FaceDetection", "(Open+Cvt(Gray)+Cascade(FrontalFace))"); 74 Globals->abbreviations.insert("FaceDetection", "(Open+Cvt(Gray)+Cascade(FrontalFace))");
75 Globals->abbreviations.insert("DenseLBP", "(Blur(1.1)+Gamma(0.2)+DoG(1,2)+ContrastEq(0.1,10)+LBP(1,2)+RectRegions(8,8,6,6)+Hist(59))"); 75 Globals->abbreviations.insert("DenseLBP", "(Blur(1.1)+Gamma(0.2)+DoG(1,2)+ContrastEq(0.1,10)+LBP(1,2)+RectRegions(8,8,6,6)+Hist(59))");
76 Globals->abbreviations.insert("DenseSIFT", "(Grid(10,10)+SIFTDescriptor(12)+ByRow)"); 76 Globals->abbreviations.insert("DenseSIFT", "(Grid(10,10)+SIFTDescriptor(12)+ByRow)");
77 - Globals->abbreviations.insert("FaceRecognitionRegistration", "(ASEFEyes+Affine(88,88,0.25,0.35)+FTE(DFFS,instances=1))");  
78 - Globals->abbreviations.insert("FaceRecognitionExtraction", "(Mask+DenseSIFT/DenseLBP+PCA(0.95,instances=1)+Normalize(L2)+Cat)");  
79 - Globals->abbreviations.insert("FaceRecognitionEmbedding", "(Dup(12)+RndSubspace(0.05,1)+LDA(0.98,instances=-2)+Cat+PCA(768,instances=1))"); 77 + Globals->abbreviations.insert("FaceRecognitionRegistration", "(ASEFEyes+Affine(88,88,0.25,0.35)+DownsampleTraining(FTE(DFFS),instances=1))");
  78 + Globals->abbreviations.insert("FaceRecognitionExtraction", "(Mask+DenseSIFT/DenseLBP+DownsampleTraining(PCA(0.95),instances=1)+Normalize(L2)+Cat)");
  79 + Globals->abbreviations.insert("FaceRecognitionEmbedding", "(Dup(12)+RndSubspace(0.05,1)+DownsampleTraining(LDA(0.98),instances=-2)+Cat+DownsampleTraining(PCA(768),instances=1))");
80 Globals->abbreviations.insert("FaceRecognitionQuantization", "(Normalize(L1)+Quantize)"); 80 Globals->abbreviations.insert("FaceRecognitionQuantization", "(Normalize(L1)+Quantize)");
81 Globals->abbreviations.insert("FaceClassificationRegistration", "(ASEFEyes+Affine(56,72,0.33,0.45)+FTE(DFFS))"); 81 Globals->abbreviations.insert("FaceClassificationRegistration", "(ASEFEyes+Affine(56,72,0.33,0.45)+FTE(DFFS))");
82 - Globals->abbreviations.insert("FaceClassificationExtraction", "((Grid(7,7)+SIFTDescriptor(8)+ByRow)/DenseLBP+PCA(0.95,instances=-1)+Cat)");  
83 - Globals->abbreviations.insert("AgeRegressor", "Center(Range,instances=-1)+SVM(RBF,EPS_SVR,inputVariable=Age,instances=100)");  
84 - Globals->abbreviations.insert("GenderClassifier", "Center(Range,instances=-1)+SVM(RBF,C_SVC,inputVariable=Gender,instances=4000)"); 82 + Globals->abbreviations.insert("FaceClassificationExtraction", "((Grid(7,7)+SIFTDescriptor(8)+ByRow)/DenseLBP+DownsampleTraining(PCA(0.95),instances=-1, inputVariable=Gender)+Cat)");
  83 + Globals->abbreviations.insert("AgeRegressor", "DownsampleTraining(Center(Range),instances=-1)+DownsampleTraining(SVM(RBF,EPS_SVR,inputVariable=Age),instances=100, inputVariable=Age)");
  84 + Globals->abbreviations.insert("GenderClassifier", "DownsampleTraining(Center(Range),instances=-1, inputVariable=Gender)+DownsampleTraining(SVM(RBF,C_SVC,inputVariable=Gender),instances=4000, inputVariable=Gender)");
85 Globals->abbreviations.insert("UCharL1", "Unit(ByteL1)"); 85 Globals->abbreviations.insert("UCharL1", "Unit(ByteL1)");
86 } 86 }
87 }; 87 };
openbr/plugins/independent.cpp
@@ -9,16 +9,16 @@ using namespace cv; @@ -9,16 +9,16 @@ using namespace cv;
9 namespace br 9 namespace br
10 { 10 {
11 11
12 -static TemplateList Downsample(const TemplateList &templates, const Transform *transform, const QString & inputVariable) 12 +static TemplateList Downsample(const TemplateList &templates, int classes, int instances, float fraction, const QString & inputVariable)
13 { 13 {
14 // Return early when no downsampling is required 14 // Return early when no downsampling is required
15 - if ((transform->classes == std::numeric_limits<int>::max()) &&  
16 - (transform->instances == std::numeric_limits<int>::max()) &&  
17 - (transform->fraction >= 1)) 15 + if ((classes == std::numeric_limits<int>::max()) &&
  16 + (instances == std::numeric_limits<int>::max()) &&
  17 + (fraction >= 1))
18 return templates; 18 return templates;
19 19
20 - const bool atLeast = transform->instances < 0;  
21 - const int instances = abs(transform->instances); 20 + const bool atLeast = instances < 0;
  21 + instances = abs(instances);
22 22
23 QList<QString> allLabels = templates.get<QString>(inputVariable); 23 QList<QString> allLabels = templates.get<QString>(inputVariable);
24 QList<QString> uniqueLabels = allLabels.toSet().toList(); 24 QList<QString> uniqueLabels = allLabels.toSet().toList();
@@ -26,20 +26,20 @@ static TemplateList Downsample(const TemplateList &amp;templates, const Transform *t @@ -26,20 +26,20 @@ static TemplateList Downsample(const TemplateList &amp;templates, const Transform *t
26 26
27 QMap<QString,int> counts = templates.countValues<QString>(inputVariable, instances != std::numeric_limits<int>::max()); 27 QMap<QString,int> counts = templates.countValues<QString>(inputVariable, instances != std::numeric_limits<int>::max());
28 28
29 - if ((instances != std::numeric_limits<int>::max()) && (transform->classes != std::numeric_limits<int>::max())) 29 + if ((instances != std::numeric_limits<int>::max()) && (classes != std::numeric_limits<int>::max()))
30 foreach (const QString & label, counts.keys()) 30 foreach (const QString & label, counts.keys())
31 if (counts[label] < instances) 31 if (counts[label] < instances)
32 counts.remove(label); 32 counts.remove(label);
33 33
34 uniqueLabels = counts.keys(); 34 uniqueLabels = counts.keys();
35 - if ((transform->classes != std::numeric_limits<int>::max()) && (uniqueLabels.size() < transform->classes))  
36 - qWarning("Downsample requested %d classes but only %d are available.", transform->classes, uniqueLabels.size()); 35 + if ((classes != std::numeric_limits<int>::max()) && (uniqueLabels.size() < classes))
  36 + qWarning("Downsample requested %d classes but only %d are available.", classes, uniqueLabels.size());
37 37
38 Common::seedRNG(); 38 Common::seedRNG();
39 QList<QString> selectedLabels = uniqueLabels; 39 QList<QString> selectedLabels = uniqueLabels;
40 - if (transform->classes < uniqueLabels.size()) { 40 + if (classes < uniqueLabels.size()) {
41 std::random_shuffle(selectedLabels.begin(), selectedLabels.end()); 41 std::random_shuffle(selectedLabels.begin(), selectedLabels.end());
42 - selectedLabels = selectedLabels.mid(0, transform->classes); 42 + selectedLabels = selectedLabels.mid(0, classes);
43 } 43 }
44 44
45 TemplateList downsample; 45 TemplateList downsample;
@@ -56,14 +56,45 @@ static TemplateList Downsample(const TemplateList &amp;templates, const Transform *t @@ -56,14 +56,45 @@ static TemplateList Downsample(const TemplateList &amp;templates, const Transform *t
56 downsample.append(templates.value(indices[j])); 56 downsample.append(templates.value(indices[j]));
57 } 57 }
58 58
59 - if (transform->fraction < 1) { 59 + if (fraction < 1) {
60 std::random_shuffle(downsample.begin(), downsample.end()); 60 std::random_shuffle(downsample.begin(), downsample.end());
61 - downsample = downsample.mid(0, downsample.size()*transform->fraction); 61 + downsample = downsample.mid(0, downsample.size()*fraction);
62 } 62 }
63 63
64 return downsample; 64 return downsample;
65 } 65 }
66 66
  67 +class DownsampleTrainingTransform : public Transform
  68 +{
  69 + Q_OBJECT
  70 + Q_PROPERTY(br::Transform* transform READ get_transform WRITE set_transform RESET reset_transform STORED true)
  71 + Q_PROPERTY(int classes READ get_classes WRITE set_classes RESET reset_classes STORED false)
  72 + Q_PROPERTY(int instances READ get_instances WRITE set_instances RESET reset_instances STORED false)
  73 + Q_PROPERTY(float fraction READ get_fraction WRITE set_fraction RESET reset_fraction STORED false)
  74 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
  75 + BR_PROPERTY(br::Transform*, transform, NULL)
  76 + BR_PROPERTY(int, classes, std::numeric_limits<int>::max())
  77 + BR_PROPERTY(int, instances, std::numeric_limits<int>::max())
  78 + BR_PROPERTY(float, fraction, 1)
  79 + BR_PROPERTY(QString, inputVariable, "Label")
  80 +
  81 + void project(const Template & src, Template & dst) const
  82 + {
  83 + transform->project(src,dst);
  84 + }
  85 +
  86 +
  87 + void train(const TemplateList &data)
  88 + {
  89 + if (!transform || !transform->trainable)
  90 + return;
  91 +
  92 + TemplateList downsampled = Downsample(data, classes, instances, fraction, inputVariable);
  93 + transform->train(downsampled);
  94 + }
  95 +};
  96 +BR_REGISTER(Transform, DownsampleTrainingTransform)
  97 +
67 /*! 98 /*!
68 * \ingroup transforms 99 * \ingroup transforms
69 * \brief Clones the transform so that it can be applied independently. 100 * \brief Clones the transform so that it can be applied independently.
@@ -74,9 +105,7 @@ class IndependentTransform : public MetaTransform @@ -74,9 +105,7 @@ class IndependentTransform : public MetaTransform
74 { 105 {
75 Q_OBJECT 106 Q_OBJECT
76 Q_PROPERTY(br::Transform* transform READ get_transform WRITE set_transform RESET reset_transform STORED false) 107 Q_PROPERTY(br::Transform* transform READ get_transform WRITE set_transform RESET reset_transform STORED false)
77 - Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)  
78 BR_PROPERTY(br::Transform*, transform, NULL) 108 BR_PROPERTY(br::Transform*, transform, NULL)
79 - BR_PROPERTY(QString, inputVariable, "Label")  
80 109
81 QList<Transform*> transforms; 110 QList<Transform*> transforms;
82 111
@@ -124,13 +153,10 @@ class IndependentTransform : public MetaTransform @@ -124,13 +153,10 @@ class IndependentTransform : public MetaTransform
124 while (transforms.size() < templatesList.size()) 153 while (transforms.size() < templatesList.size())
125 transforms.append(transform->clone()); 154 transforms.append(transform->clone());
126 155
127 - for (int i=0; i<templatesList.size(); i++)  
128 - templatesList[i] = Downsample(templatesList[i], transforms[i], inputVariable);  
129 -  
130 QFutureSynchronizer<void> futures; 156 QFutureSynchronizer<void> futures;
131 for (int i=0; i<templatesList.size(); i++) 157 for (int i=0; i<templatesList.size(); i++)
132 - futures.addFuture(QtConcurrent::run(_train, transforms[i], &templatesList[i]));  
133 - futures.waitForFinished(); 158 + futures.addFuture(QtConcurrent::run(_train, transforms[i], &templatesList[i]));
  159 + futures.waitForFinished();
134 } 160 }
135 161
136 void project(const Template &src, Template &dst) const 162 void project(const Template &src, Template &dst) const
openbr/plugins/openbr_internal.h
@@ -225,10 +225,6 @@ public: @@ -225,10 +225,6 @@ public:
225 } 225 }
226 226
227 output->file = this->file; 227 output->file = this->file;
228 - output->classes = classes;  
229 - output->instances = instances;  
230 - output->fraction = fraction;  
231 -  
232 output->init(); 228 output->init();
233 229
234 return output; 230 return output;