Commit d16f9df0090c6fa75380510df3e899f37a28ae86

Authored by Charles Otto
1 parent 63f5d257

Changes to training subsampling

Remove variables related to subsampling training data from Transform, move them
to a new transform called DownsampleTraining, and perform subsampling in that
transform, rather than as part of IndependentTransform (which depended on the
subsampling variables in Transform).

The prior approach was incompatible with using explicit class variable names,
rather than assuming a fixed label variable. This is because the actual
downsampling was carried out as part of IndependentTransform, which is not a
visible part of the algorithm grammar. Removing this logic from
Independent/Transform is generally desirable if only to avoid cluttering
Transform with member variables that are only used some of the time, and used
in a (fairly) restrictive way.

The current approach of makring DownsampleTraining as an indepent transform
is still limited since the downsample logic (still) cannot be applied without
the split used in Independent, and also DownsampleTraining can only be
implemented as a wrapper for another transform (because the argument to train
is const, and performing the donwsample logic in project would also impact the
testing case (since we have no way to exclude a transform used in training from
use in testing, for a fixed algorithm)).
openbr/openbr_plugin.cpp
... ... @@ -1083,9 +1083,6 @@ Transform::Transform(bool _independent, bool _trainable)
1083 1083 {
1084 1084 independent = _independent;
1085 1085 trainable = _trainable;
1086   - classes = std::numeric_limits<int>::max();
1087   - instances = std::numeric_limits<int>::max();
1088   - fraction = 1;
1089 1086 }
1090 1087  
1091 1088 Transform *Transform::make(QString str, QObject *parent)
... ... @@ -1141,9 +1138,6 @@ Transform *Transform::make(QString str, QObject *parent)
1141 1138 Transform *Transform::clone() const
1142 1139 {
1143 1140 Transform *clone = Factory<Transform>::make(file.flat());
1144   - clone->classes = classes;
1145   - clone->instances = instances;
1146   - clone->fraction = fraction;
1147 1141 return clone;
1148 1142 }
1149 1143  
... ...
openbr/openbr_plugin.h
... ... @@ -1068,12 +1068,6 @@ class BR_EXPORT Transform : public Object
1068 1068 Q_OBJECT
1069 1069  
1070 1070 public:
1071   - Q_PROPERTY(int classes READ get_classes WRITE set_classes RESET reset_classes STORED false)
1072   - Q_PROPERTY(int instances READ get_instances WRITE set_instances RESET reset_instances STORED false)
1073   - Q_PROPERTY(float fraction READ get_fraction WRITE set_fraction RESET reset_fraction STORED false)
1074   - BR_PROPERTY(int, classes, std::numeric_limits<int>::max())
1075   - BR_PROPERTY(int, instances, std::numeric_limits<int>::max())
1076   - BR_PROPERTY(float, fraction, 1)
1077 1071 bool independent, trainable;
1078 1072  
1079 1073 virtual ~Transform() {}
... ...
openbr/plugins/algorithms.cpp
... ... @@ -48,7 +48,7 @@ class AlgorithmsInitializer : public Initializer
48 48 // Video
49 49 Globals->abbreviations.insert("DisplayVideo", "Stream([Show(false)+Discard])");
50 50 Globals->abbreviations.insert("PerFrameDetection", "Stream([SaveMat(original)+Cvt(Gray)+Cascade(FrontalFace)+ASEFEyes+RestoreMat(original)+Draw(inPlace=true),Show(false)+Discard])");
51   - Globals->abbreviations.insert("AgeGenderDemo", "Stream([SaveMat(original)+Cvt(Gray)+Cascade(FrontalFace)+Expand+<FaceClassificationRegistration>+<FaceClassificationExtraction>+(<AgeRegressor>+Rename(Subject,Age)+Discard)/(<GenderClassifier>+Rename(Subject,Gender)+Discard)+RestoreMat(original)+Draw(inPlace=true)+DrawPropertiesPoint([Age,Gender],Affine_0,inPlace=true)+SaveMat(original)+Discard+Contract,RestoreMat(original)+FPSCalc+Show(false,[AvgFPS,Age,Gender])+Discard])");
  51 + Globals->abbreviations.insert("AgeGenderDemo", "Stream([SaveMat(original)+Cvt(Gray)+Cascade(FrontalFace)+Expand+<FaceClassificationRegistration>+<FaceClassificationExtraction>+<AgeRegressor>/<GenderClassifier>+Discard+RestoreMat(original)+Draw(inPlace=true)+DrawPropertiesPoint([Age,Gender],Affine_0,inPlace=true)+SaveMat(original)+Discard+Contract,RestoreMat(original)+FPSCalc+Show(false,[AvgFPS,Age,Gender])+Discard])");
52 52  
53 53 // Generic Image Processing
54 54 Globals->abbreviations.insert("SIFT", "Open+KeyPointDetector(SIFT)+KeyPointDescriptor(SIFT):KeyPointMatcher(BruteForce)");
... ... @@ -74,14 +74,14 @@ class AlgorithmsInitializer : public Initializer
74 74 Globals->abbreviations.insert("FaceDetection", "(Open+Cvt(Gray)+Cascade(FrontalFace))");
75 75 Globals->abbreviations.insert("DenseLBP", "(Blur(1.1)+Gamma(0.2)+DoG(1,2)+ContrastEq(0.1,10)+LBP(1,2)+RectRegions(8,8,6,6)+Hist(59))");
76 76 Globals->abbreviations.insert("DenseSIFT", "(Grid(10,10)+SIFTDescriptor(12)+ByRow)");
77   - Globals->abbreviations.insert("FaceRecognitionRegistration", "(ASEFEyes+Affine(88,88,0.25,0.35)+FTE(DFFS,instances=1))");
78   - Globals->abbreviations.insert("FaceRecognitionExtraction", "(Mask+DenseSIFT/DenseLBP+PCA(0.95,instances=1)+Normalize(L2)+Cat)");
79   - Globals->abbreviations.insert("FaceRecognitionEmbedding", "(Dup(12)+RndSubspace(0.05,1)+LDA(0.98,instances=-2)+Cat+PCA(768,instances=1))");
  77 + Globals->abbreviations.insert("FaceRecognitionRegistration", "(ASEFEyes+Affine(88,88,0.25,0.35)+DownsampleTraining(FTE(DFFS),instances=1))");
  78 + Globals->abbreviations.insert("FaceRecognitionExtraction", "(Mask+DenseSIFT/DenseLBP+DownsampleTraining(PCA(0.95),instances=1)+Normalize(L2)+Cat)");
  79 + Globals->abbreviations.insert("FaceRecognitionEmbedding", "(Dup(12)+RndSubspace(0.05,1)+DownsampleTraining(LDA(0.98),instances=-2)+Cat+DownsampleTraining(PCA(768),instances=1))");
80 80 Globals->abbreviations.insert("FaceRecognitionQuantization", "(Normalize(L1)+Quantize)");
81 81 Globals->abbreviations.insert("FaceClassificationRegistration", "(ASEFEyes+Affine(56,72,0.33,0.45)+FTE(DFFS))");
82   - Globals->abbreviations.insert("FaceClassificationExtraction", "((Grid(7,7)+SIFTDescriptor(8)+ByRow)/DenseLBP+PCA(0.95,instances=-1)+Cat)");
83   - Globals->abbreviations.insert("AgeRegressor", "Center(Range,instances=-1)+SVM(RBF,EPS_SVR,inputVariable=Age,instances=100)");
84   - Globals->abbreviations.insert("GenderClassifier", "Center(Range,instances=-1)+SVM(RBF,C_SVC,inputVariable=Gender,instances=4000)");
  82 + Globals->abbreviations.insert("FaceClassificationExtraction", "((Grid(7,7)+SIFTDescriptor(8)+ByRow)/DenseLBP+DownsampleTraining(PCA(0.95),instances=-1, inputVariable=Gender)+Cat)");
  83 + Globals->abbreviations.insert("AgeRegressor", "DownsampleTraining(Center(Range),instances=-1)+DownsampleTraining(SVM(RBF,EPS_SVR,inputVariable=Age),instances=100, inputVariable=Age)");
  84 + Globals->abbreviations.insert("GenderClassifier", "DownsampleTraining(Center(Range),instances=-1, inputVariable=Gender)+DownsampleTraining(SVM(RBF,C_SVC,inputVariable=Gender),instances=4000, inputVariable=Gender)");
85 85 Globals->abbreviations.insert("UCharL1", "Unit(ByteL1)");
86 86 }
87 87 };
... ...
openbr/plugins/independent.cpp
... ... @@ -9,16 +9,16 @@ using namespace cv;
9 9 namespace br
10 10 {
11 11  
12   -static TemplateList Downsample(const TemplateList &templates, const Transform *transform, const QString & inputVariable)
  12 +static TemplateList Downsample(const TemplateList &templates, int classes, int instances, float fraction, const QString & inputVariable)
13 13 {
14 14 // Return early when no downsampling is required
15   - if ((transform->classes == std::numeric_limits<int>::max()) &&
16   - (transform->instances == std::numeric_limits<int>::max()) &&
17   - (transform->fraction >= 1))
  15 + if ((classes == std::numeric_limits<int>::max()) &&
  16 + (instances == std::numeric_limits<int>::max()) &&
  17 + (fraction >= 1))
18 18 return templates;
19 19  
20   - const bool atLeast = transform->instances < 0;
21   - const int instances = abs(transform->instances);
  20 + const bool atLeast = instances < 0;
  21 + instances = abs(instances);
22 22  
23 23 QList<QString> allLabels = templates.get<QString>(inputVariable);
24 24 QList<QString> uniqueLabels = allLabels.toSet().toList();
... ... @@ -26,20 +26,20 @@ static TemplateList Downsample(const TemplateList &amp;templates, const Transform *t
26 26  
27 27 QMap<QString,int> counts = templates.countValues<QString>(inputVariable, instances != std::numeric_limits<int>::max());
28 28  
29   - if ((instances != std::numeric_limits<int>::max()) && (transform->classes != std::numeric_limits<int>::max()))
  29 + if ((instances != std::numeric_limits<int>::max()) && (classes != std::numeric_limits<int>::max()))
30 30 foreach (const QString & label, counts.keys())
31 31 if (counts[label] < instances)
32 32 counts.remove(label);
33 33  
34 34 uniqueLabels = counts.keys();
35   - if ((transform->classes != std::numeric_limits<int>::max()) && (uniqueLabels.size() < transform->classes))
36   - qWarning("Downsample requested %d classes but only %d are available.", transform->classes, uniqueLabels.size());
  35 + if ((classes != std::numeric_limits<int>::max()) && (uniqueLabels.size() < classes))
  36 + qWarning("Downsample requested %d classes but only %d are available.", classes, uniqueLabels.size());
37 37  
38 38 Common::seedRNG();
39 39 QList<QString> selectedLabels = uniqueLabels;
40   - if (transform->classes < uniqueLabels.size()) {
  40 + if (classes < uniqueLabels.size()) {
41 41 std::random_shuffle(selectedLabels.begin(), selectedLabels.end());
42   - selectedLabels = selectedLabels.mid(0, transform->classes);
  42 + selectedLabels = selectedLabels.mid(0, classes);
43 43 }
44 44  
45 45 TemplateList downsample;
... ... @@ -56,14 +56,45 @@ static TemplateList Downsample(const TemplateList &amp;templates, const Transform *t
56 56 downsample.append(templates.value(indices[j]));
57 57 }
58 58  
59   - if (transform->fraction < 1) {
  59 + if (fraction < 1) {
60 60 std::random_shuffle(downsample.begin(), downsample.end());
61   - downsample = downsample.mid(0, downsample.size()*transform->fraction);
  61 + downsample = downsample.mid(0, downsample.size()*fraction);
62 62 }
63 63  
64 64 return downsample;
65 65 }
66 66  
  67 +class DownsampleTrainingTransform : public Transform
  68 +{
  69 + Q_OBJECT
  70 + Q_PROPERTY(br::Transform* transform READ get_transform WRITE set_transform RESET reset_transform STORED true)
  71 + Q_PROPERTY(int classes READ get_classes WRITE set_classes RESET reset_classes STORED false)
  72 + Q_PROPERTY(int instances READ get_instances WRITE set_instances RESET reset_instances STORED false)
  73 + Q_PROPERTY(float fraction READ get_fraction WRITE set_fraction RESET reset_fraction STORED false)
  74 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
  75 + BR_PROPERTY(br::Transform*, transform, NULL)
  76 + BR_PROPERTY(int, classes, std::numeric_limits<int>::max())
  77 + BR_PROPERTY(int, instances, std::numeric_limits<int>::max())
  78 + BR_PROPERTY(float, fraction, 1)
  79 + BR_PROPERTY(QString, inputVariable, "Label")
  80 +
  81 + void project(const Template & src, Template & dst) const
  82 + {
  83 + transform->project(src,dst);
  84 + }
  85 +
  86 +
  87 + void train(const TemplateList &data)
  88 + {
  89 + if (!transform || !transform->trainable)
  90 + return;
  91 +
  92 + TemplateList downsampled = Downsample(data, classes, instances, fraction, inputVariable);
  93 + transform->train(downsampled);
  94 + }
  95 +};
  96 +BR_REGISTER(Transform, DownsampleTrainingTransform)
  97 +
67 98 /*!
68 99 * \ingroup transforms
69 100 * \brief Clones the transform so that it can be applied independently.
... ... @@ -74,9 +105,7 @@ class IndependentTransform : public MetaTransform
74 105 {
75 106 Q_OBJECT
76 107 Q_PROPERTY(br::Transform* transform READ get_transform WRITE set_transform RESET reset_transform STORED false)
77   - Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
78 108 BR_PROPERTY(br::Transform*, transform, NULL)
79   - BR_PROPERTY(QString, inputVariable, "Label")
80 109  
81 110 QList<Transform*> transforms;
82 111  
... ... @@ -124,13 +153,10 @@ class IndependentTransform : public MetaTransform
124 153 while (transforms.size() < templatesList.size())
125 154 transforms.append(transform->clone());
126 155  
127   - for (int i=0; i<templatesList.size(); i++)
128   - templatesList[i] = Downsample(templatesList[i], transforms[i], inputVariable);
129   -
130 156 QFutureSynchronizer<void> futures;
131 157 for (int i=0; i<templatesList.size(); i++)
132   - futures.addFuture(QtConcurrent::run(_train, transforms[i], &templatesList[i]));
133   - futures.waitForFinished();
  158 + futures.addFuture(QtConcurrent::run(_train, transforms[i], &templatesList[i]));
  159 + futures.waitForFinished();
134 160 }
135 161  
136 162 void project(const Template &src, Template &dst) const
... ...
openbr/plugins/openbr_internal.h
... ... @@ -225,10 +225,6 @@ public:
225 225 }
226 226  
227 227 output->file = this->file;
228   - output->classes = classes;
229   - output->instances = instances;
230   - output->fraction = fraction;
231   -
232 228 output->init();
233 229  
234 230 return output;
... ...