Commit 3de17d865c779d6346524723ebcc5653f2f745fc
1 parent
3ff62749
Minor cascade optimizations
Showing
1 changed file
with
10 additions
and
10 deletions
openbr/plugins/classification/cascade.cpp
| 1 | 1 | #include <opencv2/imgproc/imgproc.hpp> |
| 2 | +#include <opencv2/highgui/highgui.hpp> | |
| 2 | 3 | |
| 3 | 4 | #include <openbr/plugins/openbr_internal.h> |
| 4 | 5 | #include <openbr/core/common.h> |
| ... | ... | @@ -118,7 +119,7 @@ class CascadeClassifier : public Classifier |
| 118 | 119 | TemplateList posImages, negImages; |
| 119 | 120 | TemplateList posSamples, negSamples; |
| 120 | 121 | |
| 121 | - QList<int> indices; | |
| 122 | + QList<int> negIndices, posIndices; | |
| 122 | 123 | int negIndex, posIndex, samplingRound; |
| 123 | 124 | |
| 124 | 125 | QMutex samplingMutex, miningMutex; |
| ... | ... | @@ -132,7 +133,7 @@ class CascadeClassifier : public Classifier |
| 132 | 133 | { |
| 133 | 134 | if (posIndex >= posImages.size()) |
| 134 | 135 | return false; |
| 135 | - img = posImages[indices[posIndex++]]; | |
| 136 | + img = posImages[posIndices[posIndex++]]; | |
| 136 | 137 | return true; |
| 137 | 138 | } |
| 138 | 139 | |
| ... | ... | @@ -144,7 +145,7 @@ class CascadeClassifier : public Classifier |
| 144 | 145 | // Grab negative from list |
| 145 | 146 | int count = negImages.size(); |
| 146 | 147 | for (int i = 0; i < count; i++) { |
| 147 | - negative = negImages[negIndex++]; | |
| 148 | + negative = negImages[negIndices[negIndex++]]; | |
| 148 | 149 | |
| 149 | 150 | samplingRound += negIndex / count; |
| 150 | 151 | samplingRound = samplingRound % (size.width * size.height); |
| ... | ... | @@ -203,18 +204,17 @@ class CascadeClassifier : public Classifier |
| 203 | 204 | << "\nTotal positive images:" << posImages.size() |
| 204 | 205 | << "\nTotal negative images:" << negImages.size(); |
| 205 | 206 | |
| 206 | - indices = Common::RandSample(posImages.size(), posImages.size(), true); | |
| 207 | + posIndices = Common::RandSample(posImages.size(), posImages.size(), true); | |
| 208 | + negIndices = Common::RandSample(negImages.size(), negImages.size(), true); | |
| 207 | 209 | |
| 208 | 210 | stages.reserve(numStages); |
| 209 | 211 | for (int i = 0; i < numStages; i++) { |
| 210 | - Classifier *next_stage = Classifier::make(stageDescription, NULL); | |
| 211 | - stages.append(next_stage); | |
| 212 | - } | |
| 213 | - | |
| 214 | - for (int i = 0; i < stages.size(); i++) { | |
| 215 | 212 | qDebug() << "===== TRAINING" << i << "stage ====="; |
| 216 | 213 | qDebug() << "<BEGIN"; |
| 217 | 214 | |
| 215 | + Classifier *next_stage = Classifier::make(stageDescription, NULL); | |
| 216 | + stages.append(next_stage); | |
| 217 | + | |
| 218 | 218 | float currFAR = getSamples(); |
| 219 | 219 | |
| 220 | 220 | if (currFAR < maxFAR && !requireAllStages) { |
| ... | ... | @@ -222,7 +222,7 @@ class CascadeClassifier : public Classifier |
| 222 | 222 | return; |
| 223 | 223 | } |
| 224 | 224 | |
| 225 | - stages[i]->train(posSamples + negSamples); | |
| 225 | + stages.last()->train(posSamples + negSamples); | |
| 226 | 226 | |
| 227 | 227 | qDebug() << "END>"; |
| 228 | 228 | } | ... | ... |