Commit d16f9df0090c6fa75380510df3e899f37a28ae86
1 parent
63f5d257
Changes to training subsampling
Remove variables related to subsampling training data from Transform, move them to a new transform called DownsampleTraining, and perform subsampling in that transform, rather than as part of IndependentTransform (which depended on the subsampling variables in Transform). The prior approach was incompatible with using explicit class variable names, rather than assuming a fixed label variable. This is because the actual downsampling was carried out as part of IndependentTransform, which is not a visible part of the algorithm grammar. Removing this logic from Independent/Transform is generally desirable if only to avoid cluttering Transform with member variables that are only used some of the time, and used in a (fairly) restrictive way. The current approach of makring DownsampleTraining as an indepent transform is still limited since the downsample logic (still) cannot be applied without the split used in Independent, and also DownsampleTraining can only be implemented as a wrapper for another transform (because the argument to train is const, and performing the donwsample logic in project would also impact the testing case (since we have no way to exclude a transform used in training from use in testing, for a fixed algorithm)).
Showing
5 changed files
with
53 additions
and
43 deletions
openbr/openbr_plugin.cpp
| ... | ... | @@ -1083,9 +1083,6 @@ Transform::Transform(bool _independent, bool _trainable) |
| 1083 | 1083 | { |
| 1084 | 1084 | independent = _independent; |
| 1085 | 1085 | trainable = _trainable; |
| 1086 | - classes = std::numeric_limits<int>::max(); | |
| 1087 | - instances = std::numeric_limits<int>::max(); | |
| 1088 | - fraction = 1; | |
| 1089 | 1086 | } |
| 1090 | 1087 | |
| 1091 | 1088 | Transform *Transform::make(QString str, QObject *parent) |
| ... | ... | @@ -1141,9 +1138,6 @@ Transform *Transform::make(QString str, QObject *parent) |
| 1141 | 1138 | Transform *Transform::clone() const |
| 1142 | 1139 | { |
| 1143 | 1140 | Transform *clone = Factory<Transform>::make(file.flat()); |
| 1144 | - clone->classes = classes; | |
| 1145 | - clone->instances = instances; | |
| 1146 | - clone->fraction = fraction; | |
| 1147 | 1141 | return clone; |
| 1148 | 1142 | } |
| 1149 | 1143 | ... | ... |
openbr/openbr_plugin.h
| ... | ... | @@ -1068,12 +1068,6 @@ class BR_EXPORT Transform : public Object |
| 1068 | 1068 | Q_OBJECT |
| 1069 | 1069 | |
| 1070 | 1070 | public: |
| 1071 | - Q_PROPERTY(int classes READ get_classes WRITE set_classes RESET reset_classes STORED false) | |
| 1072 | - Q_PROPERTY(int instances READ get_instances WRITE set_instances RESET reset_instances STORED false) | |
| 1073 | - Q_PROPERTY(float fraction READ get_fraction WRITE set_fraction RESET reset_fraction STORED false) | |
| 1074 | - BR_PROPERTY(int, classes, std::numeric_limits<int>::max()) | |
| 1075 | - BR_PROPERTY(int, instances, std::numeric_limits<int>::max()) | |
| 1076 | - BR_PROPERTY(float, fraction, 1) | |
| 1077 | 1071 | bool independent, trainable; |
| 1078 | 1072 | |
| 1079 | 1073 | virtual ~Transform() {} | ... | ... |
openbr/plugins/algorithms.cpp
| ... | ... | @@ -48,7 +48,7 @@ class AlgorithmsInitializer : public Initializer |
| 48 | 48 | // Video |
| 49 | 49 | Globals->abbreviations.insert("DisplayVideo", "Stream([Show(false)+Discard])"); |
| 50 | 50 | Globals->abbreviations.insert("PerFrameDetection", "Stream([SaveMat(original)+Cvt(Gray)+Cascade(FrontalFace)+ASEFEyes+RestoreMat(original)+Draw(inPlace=true),Show(false)+Discard])"); |
| 51 | - Globals->abbreviations.insert("AgeGenderDemo", "Stream([SaveMat(original)+Cvt(Gray)+Cascade(FrontalFace)+Expand+<FaceClassificationRegistration>+<FaceClassificationExtraction>+(<AgeRegressor>+Rename(Subject,Age)+Discard)/(<GenderClassifier>+Rename(Subject,Gender)+Discard)+RestoreMat(original)+Draw(inPlace=true)+DrawPropertiesPoint([Age,Gender],Affine_0,inPlace=true)+SaveMat(original)+Discard+Contract,RestoreMat(original)+FPSCalc+Show(false,[AvgFPS,Age,Gender])+Discard])"); | |
| 51 | + Globals->abbreviations.insert("AgeGenderDemo", "Stream([SaveMat(original)+Cvt(Gray)+Cascade(FrontalFace)+Expand+<FaceClassificationRegistration>+<FaceClassificationExtraction>+<AgeRegressor>/<GenderClassifier>+Discard+RestoreMat(original)+Draw(inPlace=true)+DrawPropertiesPoint([Age,Gender],Affine_0,inPlace=true)+SaveMat(original)+Discard+Contract,RestoreMat(original)+FPSCalc+Show(false,[AvgFPS,Age,Gender])+Discard])"); | |
| 52 | 52 | |
| 53 | 53 | // Generic Image Processing |
| 54 | 54 | Globals->abbreviations.insert("SIFT", "Open+KeyPointDetector(SIFT)+KeyPointDescriptor(SIFT):KeyPointMatcher(BruteForce)"); |
| ... | ... | @@ -74,14 +74,14 @@ class AlgorithmsInitializer : public Initializer |
| 74 | 74 | Globals->abbreviations.insert("FaceDetection", "(Open+Cvt(Gray)+Cascade(FrontalFace))"); |
| 75 | 75 | Globals->abbreviations.insert("DenseLBP", "(Blur(1.1)+Gamma(0.2)+DoG(1,2)+ContrastEq(0.1,10)+LBP(1,2)+RectRegions(8,8,6,6)+Hist(59))"); |
| 76 | 76 | Globals->abbreviations.insert("DenseSIFT", "(Grid(10,10)+SIFTDescriptor(12)+ByRow)"); |
| 77 | - Globals->abbreviations.insert("FaceRecognitionRegistration", "(ASEFEyes+Affine(88,88,0.25,0.35)+FTE(DFFS,instances=1))"); | |
| 78 | - Globals->abbreviations.insert("FaceRecognitionExtraction", "(Mask+DenseSIFT/DenseLBP+PCA(0.95,instances=1)+Normalize(L2)+Cat)"); | |
| 79 | - Globals->abbreviations.insert("FaceRecognitionEmbedding", "(Dup(12)+RndSubspace(0.05,1)+LDA(0.98,instances=-2)+Cat+PCA(768,instances=1))"); | |
| 77 | + Globals->abbreviations.insert("FaceRecognitionRegistration", "(ASEFEyes+Affine(88,88,0.25,0.35)+DownsampleTraining(FTE(DFFS),instances=1))"); | |
| 78 | + Globals->abbreviations.insert("FaceRecognitionExtraction", "(Mask+DenseSIFT/DenseLBP+DownsampleTraining(PCA(0.95),instances=1)+Normalize(L2)+Cat)"); | |
| 79 | + Globals->abbreviations.insert("FaceRecognitionEmbedding", "(Dup(12)+RndSubspace(0.05,1)+DownsampleTraining(LDA(0.98),instances=-2)+Cat+DownsampleTraining(PCA(768),instances=1))"); | |
| 80 | 80 | Globals->abbreviations.insert("FaceRecognitionQuantization", "(Normalize(L1)+Quantize)"); |
| 81 | 81 | Globals->abbreviations.insert("FaceClassificationRegistration", "(ASEFEyes+Affine(56,72,0.33,0.45)+FTE(DFFS))"); |
| 82 | - Globals->abbreviations.insert("FaceClassificationExtraction", "((Grid(7,7)+SIFTDescriptor(8)+ByRow)/DenseLBP+PCA(0.95,instances=-1)+Cat)"); | |
| 83 | - Globals->abbreviations.insert("AgeRegressor", "Center(Range,instances=-1)+SVM(RBF,EPS_SVR,inputVariable=Age,instances=100)"); | |
| 84 | - Globals->abbreviations.insert("GenderClassifier", "Center(Range,instances=-1)+SVM(RBF,C_SVC,inputVariable=Gender,instances=4000)"); | |
| 82 | + Globals->abbreviations.insert("FaceClassificationExtraction", "((Grid(7,7)+SIFTDescriptor(8)+ByRow)/DenseLBP+DownsampleTraining(PCA(0.95),instances=-1, inputVariable=Gender)+Cat)"); | |
| 83 | + Globals->abbreviations.insert("AgeRegressor", "DownsampleTraining(Center(Range),instances=-1)+DownsampleTraining(SVM(RBF,EPS_SVR,inputVariable=Age),instances=100, inputVariable=Age)"); | |
| 84 | + Globals->abbreviations.insert("GenderClassifier", "DownsampleTraining(Center(Range),instances=-1, inputVariable=Gender)+DownsampleTraining(SVM(RBF,C_SVC,inputVariable=Gender),instances=4000, inputVariable=Gender)"); | |
| 85 | 85 | Globals->abbreviations.insert("UCharL1", "Unit(ByteL1)"); |
| 86 | 86 | } |
| 87 | 87 | }; | ... | ... |
openbr/plugins/independent.cpp
| ... | ... | @@ -9,16 +9,16 @@ using namespace cv; |
| 9 | 9 | namespace br |
| 10 | 10 | { |
| 11 | 11 | |
| 12 | -static TemplateList Downsample(const TemplateList &templates, const Transform *transform, const QString & inputVariable) | |
| 12 | +static TemplateList Downsample(const TemplateList &templates, int classes, int instances, float fraction, const QString & inputVariable) | |
| 13 | 13 | { |
| 14 | 14 | // Return early when no downsampling is required |
| 15 | - if ((transform->classes == std::numeric_limits<int>::max()) && | |
| 16 | - (transform->instances == std::numeric_limits<int>::max()) && | |
| 17 | - (transform->fraction >= 1)) | |
| 15 | + if ((classes == std::numeric_limits<int>::max()) && | |
| 16 | + (instances == std::numeric_limits<int>::max()) && | |
| 17 | + (fraction >= 1)) | |
| 18 | 18 | return templates; |
| 19 | 19 | |
| 20 | - const bool atLeast = transform->instances < 0; | |
| 21 | - const int instances = abs(transform->instances); | |
| 20 | + const bool atLeast = instances < 0; | |
| 21 | + instances = abs(instances); | |
| 22 | 22 | |
| 23 | 23 | QList<QString> allLabels = templates.get<QString>(inputVariable); |
| 24 | 24 | QList<QString> uniqueLabels = allLabels.toSet().toList(); |
| ... | ... | @@ -26,20 +26,20 @@ static TemplateList Downsample(const TemplateList &templates, const Transform *t |
| 26 | 26 | |
| 27 | 27 | QMap<QString,int> counts = templates.countValues<QString>(inputVariable, instances != std::numeric_limits<int>::max()); |
| 28 | 28 | |
| 29 | - if ((instances != std::numeric_limits<int>::max()) && (transform->classes != std::numeric_limits<int>::max())) | |
| 29 | + if ((instances != std::numeric_limits<int>::max()) && (classes != std::numeric_limits<int>::max())) | |
| 30 | 30 | foreach (const QString & label, counts.keys()) |
| 31 | 31 | if (counts[label] < instances) |
| 32 | 32 | counts.remove(label); |
| 33 | 33 | |
| 34 | 34 | uniqueLabels = counts.keys(); |
| 35 | - if ((transform->classes != std::numeric_limits<int>::max()) && (uniqueLabels.size() < transform->classes)) | |
| 36 | - qWarning("Downsample requested %d classes but only %d are available.", transform->classes, uniqueLabels.size()); | |
| 35 | + if ((classes != std::numeric_limits<int>::max()) && (uniqueLabels.size() < classes)) | |
| 36 | + qWarning("Downsample requested %d classes but only %d are available.", classes, uniqueLabels.size()); | |
| 37 | 37 | |
| 38 | 38 | Common::seedRNG(); |
| 39 | 39 | QList<QString> selectedLabels = uniqueLabels; |
| 40 | - if (transform->classes < uniqueLabels.size()) { | |
| 40 | + if (classes < uniqueLabels.size()) { | |
| 41 | 41 | std::random_shuffle(selectedLabels.begin(), selectedLabels.end()); |
| 42 | - selectedLabels = selectedLabels.mid(0, transform->classes); | |
| 42 | + selectedLabels = selectedLabels.mid(0, classes); | |
| 43 | 43 | } |
| 44 | 44 | |
| 45 | 45 | TemplateList downsample; |
| ... | ... | @@ -56,14 +56,45 @@ static TemplateList Downsample(const TemplateList &templates, const Transform *t |
| 56 | 56 | downsample.append(templates.value(indices[j])); |
| 57 | 57 | } |
| 58 | 58 | |
| 59 | - if (transform->fraction < 1) { | |
| 59 | + if (fraction < 1) { | |
| 60 | 60 | std::random_shuffle(downsample.begin(), downsample.end()); |
| 61 | - downsample = downsample.mid(0, downsample.size()*transform->fraction); | |
| 61 | + downsample = downsample.mid(0, downsample.size()*fraction); | |
| 62 | 62 | } |
| 63 | 63 | |
| 64 | 64 | return downsample; |
| 65 | 65 | } |
| 66 | 66 | |
| 67 | +class DownsampleTrainingTransform : public Transform | |
| 68 | +{ | |
| 69 | + Q_OBJECT | |
| 70 | + Q_PROPERTY(br::Transform* transform READ get_transform WRITE set_transform RESET reset_transform STORED true) | |
| 71 | + Q_PROPERTY(int classes READ get_classes WRITE set_classes RESET reset_classes STORED false) | |
| 72 | + Q_PROPERTY(int instances READ get_instances WRITE set_instances RESET reset_instances STORED false) | |
| 73 | + Q_PROPERTY(float fraction READ get_fraction WRITE set_fraction RESET reset_fraction STORED false) | |
| 74 | + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | |
| 75 | + BR_PROPERTY(br::Transform*, transform, NULL) | |
| 76 | + BR_PROPERTY(int, classes, std::numeric_limits<int>::max()) | |
| 77 | + BR_PROPERTY(int, instances, std::numeric_limits<int>::max()) | |
| 78 | + BR_PROPERTY(float, fraction, 1) | |
| 79 | + BR_PROPERTY(QString, inputVariable, "Label") | |
| 80 | + | |
| 81 | + void project(const Template & src, Template & dst) const | |
| 82 | + { | |
| 83 | + transform->project(src,dst); | |
| 84 | + } | |
| 85 | + | |
| 86 | + | |
| 87 | + void train(const TemplateList &data) | |
| 88 | + { | |
| 89 | + if (!transform || !transform->trainable) | |
| 90 | + return; | |
| 91 | + | |
| 92 | + TemplateList downsampled = Downsample(data, classes, instances, fraction, inputVariable); | |
| 93 | + transform->train(downsampled); | |
| 94 | + } | |
| 95 | +}; | |
| 96 | +BR_REGISTER(Transform, DownsampleTrainingTransform) | |
| 97 | + | |
| 67 | 98 | /*! |
| 68 | 99 | * \ingroup transforms |
| 69 | 100 | * \brief Clones the transform so that it can be applied independently. |
| ... | ... | @@ -74,9 +105,7 @@ class IndependentTransform : public MetaTransform |
| 74 | 105 | { |
| 75 | 106 | Q_OBJECT |
| 76 | 107 | Q_PROPERTY(br::Transform* transform READ get_transform WRITE set_transform RESET reset_transform STORED false) |
| 77 | - Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | |
| 78 | 108 | BR_PROPERTY(br::Transform*, transform, NULL) |
| 79 | - BR_PROPERTY(QString, inputVariable, "Label") | |
| 80 | 109 | |
| 81 | 110 | QList<Transform*> transforms; |
| 82 | 111 | |
| ... | ... | @@ -124,13 +153,10 @@ class IndependentTransform : public MetaTransform |
| 124 | 153 | while (transforms.size() < templatesList.size()) |
| 125 | 154 | transforms.append(transform->clone()); |
| 126 | 155 | |
| 127 | - for (int i=0; i<templatesList.size(); i++) | |
| 128 | - templatesList[i] = Downsample(templatesList[i], transforms[i], inputVariable); | |
| 129 | - | |
| 130 | 156 | QFutureSynchronizer<void> futures; |
| 131 | 157 | for (int i=0; i<templatesList.size(); i++) |
| 132 | - futures.addFuture(QtConcurrent::run(_train, transforms[i], &templatesList[i])); | |
| 133 | - futures.waitForFinished(); | |
| 158 | + futures.addFuture(QtConcurrent::run(_train, transforms[i], &templatesList[i])); | |
| 159 | + futures.waitForFinished(); | |
| 134 | 160 | } |
| 135 | 161 | |
| 136 | 162 | void project(const Template &src, Template &dst) const | ... | ... |
openbr/plugins/openbr_internal.h