From 3de17d865c779d6346524723ebcc5653f2f745fc Mon Sep 17 00:00:00 2001 From: Scott Klum Date: Fri, 30 Oct 2015 15:35:37 +0000 Subject: [PATCH] Minor cascade optimizations --- openbr/plugins/classification/cascade.cpp | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/openbr/plugins/classification/cascade.cpp b/openbr/plugins/classification/cascade.cpp index 196e400..20787b3 100644 --- a/openbr/plugins/classification/cascade.cpp +++ b/openbr/plugins/classification/cascade.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -118,7 +119,7 @@ class CascadeClassifier : public Classifier TemplateList posImages, negImages; TemplateList posSamples, negSamples; - QList indices; + QList negIndices, posIndices; int negIndex, posIndex, samplingRound; QMutex samplingMutex, miningMutex; @@ -132,7 +133,7 @@ class CascadeClassifier : public Classifier { if (posIndex >= posImages.size()) return false; - img = posImages[indices[posIndex++]]; + img = posImages[posIndices[posIndex++]]; return true; } @@ -144,7 +145,7 @@ class CascadeClassifier : public Classifier // Grab negative from list int count = negImages.size(); for (int i = 0; i < count; i++) { - negative = negImages[negIndex++]; + negative = negImages[negIndices[negIndex++]]; samplingRound += negIndex / count; samplingRound = samplingRound % (size.width * size.height); @@ -203,18 +204,17 @@ class CascadeClassifier : public Classifier << "\nTotal positive images:" << posImages.size() << "\nTotal negative images:" << negImages.size(); - indices = Common::RandSample(posImages.size(), posImages.size(), true); + posIndices = Common::RandSample(posImages.size(), posImages.size(), true); + negIndices = Common::RandSample(negImages.size(), negImages.size(), true); stages.reserve(numStages); for (int i = 0; i < numStages; i++) { - Classifier *next_stage = Classifier::make(stageDescription, NULL); - stages.append(next_stage); - } - - for (int i = 0; i < stages.size(); i++) { qDebug() << "===== TRAINING" << i << "stage ====="; qDebug() << "train(posSamples + negSamples); + stages.last()->train(posSamples + negSamples); qDebug() << "END>"; } -- libgit2 0.21.4