Commit 3de17d865c779d6346524723ebcc5653f2f745fc

Authored by Scott Klum
1 parent 3ff62749

Minor cascade optimizations

openbr/plugins/classification/cascade.cpp
1 1 #include <opencv2/imgproc/imgproc.hpp>
  2 +#include <opencv2/highgui/highgui.hpp>
2 3  
3 4 #include <openbr/plugins/openbr_internal.h>
4 5 #include <openbr/core/common.h>
... ... @@ -118,7 +119,7 @@ class CascadeClassifier : public Classifier
118 119 TemplateList posImages, negImages;
119 120 TemplateList posSamples, negSamples;
120 121  
121   - QList<int> indices;
  122 + QList<int> negIndices, posIndices;
122 123 int negIndex, posIndex, samplingRound;
123 124  
124 125 QMutex samplingMutex, miningMutex;
... ... @@ -132,7 +133,7 @@ class CascadeClassifier : public Classifier
132 133 {
133 134 if (posIndex >= posImages.size())
134 135 return false;
135   - img = posImages[indices[posIndex++]];
  136 + img = posImages[posIndices[posIndex++]];
136 137 return true;
137 138 }
138 139  
... ... @@ -144,7 +145,7 @@ class CascadeClassifier : public Classifier
144 145 // Grab negative from list
145 146 int count = negImages.size();
146 147 for (int i = 0; i < count; i++) {
147   - negative = negImages[negIndex++];
  148 + negative = negImages[negIndices[negIndex++]];
148 149  
149 150 samplingRound += negIndex / count;
150 151 samplingRound = samplingRound % (size.width * size.height);
... ... @@ -203,18 +204,17 @@ class CascadeClassifier : public Classifier
203 204 << "\nTotal positive images:" << posImages.size()
204 205 << "\nTotal negative images:" << negImages.size();
205 206  
206   - indices = Common::RandSample(posImages.size(), posImages.size(), true);
  207 + posIndices = Common::RandSample(posImages.size(), posImages.size(), true);
  208 + negIndices = Common::RandSample(negImages.size(), negImages.size(), true);
207 209  
208 210 stages.reserve(numStages);
209 211 for (int i = 0; i < numStages; i++) {
210   - Classifier *next_stage = Classifier::make(stageDescription, NULL);
211   - stages.append(next_stage);
212   - }
213   -
214   - for (int i = 0; i < stages.size(); i++) {
215 212 qDebug() << "===== TRAINING" << i << "stage =====";
216 213 qDebug() << "<BEGIN";
217 214  
  215 + Classifier *next_stage = Classifier::make(stageDescription, NULL);
  216 + stages.append(next_stage);
  217 +
218 218 float currFAR = getSamples();
219 219  
220 220 if (currFAR < maxFAR && !requireAllStages) {
... ... @@ -222,7 +222,7 @@ class CascadeClassifier : public Classifier
222 222 return;
223 223 }
224 224  
225   - stages[i]->train(posSamples + negSamples);
  225 + stages.last()->train(posSamples + negSamples);
226 226  
227 227 qDebug() << "END>";
228 228 }
... ...