diff --git a/openbr/plugins/classification/cascade.cpp b/openbr/plugins/classification/cascade.cpp index 166b027..5c532b6 100644 --- a/openbr/plugins/classification/cascade.cpp +++ b/openbr/plugins/classification/cascade.cpp @@ -1,6 +1,7 @@ #include #include +#include using namespace cv; @@ -25,6 +26,8 @@ struct ImageHandler stepFactor = 0.5F; round = 0; + indices = Common::RandSample(posImages.size(),posImages.size(),true); + return true; } @@ -85,7 +88,7 @@ struct ImageHandler if (posIdx >= posImages.size()) return false; - posImages[posIdx++].copyTo(_img); + posImages[indices[posIdx++]].copyTo(_img); return true; } @@ -100,6 +103,8 @@ struct ImageHandler float stepFactor; size_t round; Size winSize; + + QList indices; }; /*! @@ -139,10 +144,15 @@ class CascadeClassifier : public Classifier for (int i = 0; i < images.size(); i++) labels[i] == 1 ? posImages.append(images[i]) : negImages.append(images[i]); + stages.reserve(numStages); + for (int i = 0; i < numStages; i++) { + Classifier *next_stage = Classifier::make(stageDescription, NULL); + stages.append(next_stage); + } + ImageHandler imgHandler; - imgHandler.create(posImages, negImages, Size(24, 24)); + imgHandler.create(posImages, negImages, windowSize()); - stages.reserve(numStages); for (int i = 0; i < numStages; i++) { qDebug() << "===== TRAINING" << i << "stage ====="; qDebug() << "train(trainingImages, trainingLabels); - stages.append(next_stage); + stages[i]->train(trainingImages, trainingLabels); qDebug() << "END>"; } @@ -193,7 +201,7 @@ class CascadeClassifier : public Classifier return stages.first()->preprocess(image); } - Size windowSize(int *dx, int *dy) const + Size windowSize(int *dx = NULL, int *dy = NULL) const { return stages.first()->windowSize(dx, dy); }