diff --git a/openbr/openbr_plugin.h b/openbr/openbr_plugin.h index 88b1a89..22b960b 100644 --- a/openbr/openbr_plugin.h +++ b/openbr/openbr_plugin.h @@ -1423,11 +1423,10 @@ public: static Classifier *make(QString str, QObject *parent); /*!< \brief Make a classifier from a string. */ virtual void train(const QList &images, const QList &labels) = 0; - // By convention, classify should return a value normalized such that the threshold is 0. Negative values - // can be interpreted as a negative classification and positive values as a positive classification. - virtual float classify(const cv::Mat &image) const = 0; + virtual float classify(const cv::Mat &image, float &confidence) const = 0; // Slots for representation + virtual cv::Mat preprocess(const cv::Mat &image) const = 0; virtual cv::Size windowSize() const = 0; // OpenCV compatibility diff --git a/openbr/plugins/classification/boostedforest.cpp b/openbr/plugins/classification/boostedforest.cpp index 5dc4dd3..e022777 100644 --- a/openbr/plugins/classification/boostedforest.cpp +++ b/openbr/plugins/classification/boostedforest.cpp @@ -129,12 +129,9 @@ class BoostedForestClassifier : public Classifier } } - float classify(const Mat &_image) const + float classify(const Mat &image, float &confidence) const { - Mat image; - representation->preprocess(_image, image); - - float sum = 0; + confidence = 0; for (int i = 0; i < classifiers.size(); i++) { Node *node = classifiers[i]; @@ -147,11 +144,10 @@ class BoostedForestClassifier : public Classifier node = val <= node->threshold ? node->left : node->right; } } - qDebug("value: %f", node->value); - sum += node->value; + confidence += node->value; } - return sum < threshold - THRESHOLD_EPS ? -std::abs(sum) : std::abs(sum); + return confidence < threshold - THRESHOLD_EPS ? 0.0f : 1.0f; } int numFeatures() const @@ -164,6 +160,13 @@ class BoostedForestClassifier : public Classifier return representation->maxCatCount(); } + Mat preprocess(const Mat &image) const + { + Mat dst; + representation->preprocess(image, dst); + return dst; + } + Size windowSize() const { return representation->preWindowSize(); diff --git a/openbr/plugins/classification/cascade.cpp b/openbr/plugins/classification/cascade.cpp index d663519..89cb6fa 100644 --- a/openbr/plugins/classification/cascade.cpp +++ b/openbr/plugins/classification/cascade.cpp @@ -152,18 +152,23 @@ class CascadeClassifier : public Classifier } } - float classify(const Mat &image) const + float classify(const Mat &image, float &confidence) const { - if (stages.empty()) - return 1.0f; + if (stages.empty()) { + confidence = 0.0f; + return 1.0f; + } - float val = 0.0f; for (int i = 0; i < stages.size(); i++) { - val = stages[i]->classify(image); - if (val < 0.0f) - return stages.size() - i < 4 ? i * val : 0.0f; + float result = stages[i]->classify(image, confidence); + if (result == 0.0f) { + //confidence *= i; + return i; + } } - return stages.size() * val; + + //confidence *= stages.size(); + return stages.size(); } int numFeatures() const @@ -176,7 +181,12 @@ class CascadeClassifier : public Classifier return stages.first()->maxCatCount(); } - cv::Size windowSize() const + Mat preprocess(const Mat &image) const + { + return stages.first()->preprocess(image); + } + + Size windowSize() const { return stages.first()->windowSize(); } @@ -207,12 +217,14 @@ private: { imgHandler.restart(); + float confidence = 0.0f; // not used; + while (images.size() < numPos) { Mat pos(imgHandler.winSize, CV_8UC1); if (!imgHandler.getPos(pos)) qFatal("Cannot get another positive sample!"); - if (classify(pos) > 0.0f) { + if (classify(pos, confidence) > 0.0f) { printf("POS current samples: %d\r", images.size()); images.append(pos); labels.append(1.0f); @@ -228,7 +240,7 @@ private: if (!imgHandler.getNeg(neg)) qFatal("Cannot get another negative sample!"); - if (classify(neg) > 0.0f) { + if (classify(neg, confidence) > 0.0f) { printf("NEG current samples: %d\r", images.size() - posCount); images.append(neg); labels.append(0.0f); diff --git a/openbr/plugins/imgproc/slidingwindow.cpp b/openbr/plugins/imgproc/slidingwindow.cpp index ab2f96e..b15b390 100644 --- a/openbr/plugins/imgproc/slidingwindow.cpp +++ b/openbr/plugins/imgproc/slidingwindow.cpp @@ -14,15 +14,10 @@ * limitations under the License. * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ -#include - #include +#include #include #include -#include - -#include -#include using namespace cv; @@ -31,22 +26,23 @@ namespace br /*! * \ingroup transforms - * \brief Applies a classifier to a sliding window. - * \author Jordan Cheney \cite JordanCheney + * \brief Sliding Window Framework + * \author Jordan Cheney */ - -class SlidingWindowTransform : public Transform +class SlidingWindowTransform : public UntrainableMetaTransform { Q_OBJECT Q_PROPERTY(br::Classifier *classifier READ get_classifier WRITE set_classifier RESET reset_classifier STORED false) + Q_PROPERTY(int minSize READ get_minSize WRITE set_minSize RESET reset_minSize STORED false) Q_PROPERTY(int maxSize READ get_maxSize WRITE set_maxSize RESET reset_maxSize STORED false) Q_PROPERTY(float scaleFactor READ get_scaleFactor WRITE set_scaleFactor RESET reset_scaleFactor STORED false) Q_PROPERTY(int minNeighbors READ get_minNeighbors WRITE set_minNeighbors RESET reset_minNeighbors STORED false) Q_PROPERTY(float eps READ get_eps WRITE set_eps RESET reset_eps STORED false) - Q_PROPERTY(QString cascadeDir READ get_cascadeDir WRITE set_cascadeDir RESET reset_cascadeDir STORED false) + Q_PROPERTY(QString model READ get_model WRITE set_model RESET reset_model STORED false) + BR_PROPERTY(br::Classifier *, classifier, NULL) BR_PROPERTY(int, minSize, 20) BR_PROPERTY(int, maxSize, -1) @@ -54,7 +50,13 @@ class SlidingWindowTransform : public Transform BR_PROPERTY(int, minNeighbors, 5) BR_PROPERTY(float, eps, 0.2) - BR_PROPERTY(QString, cascadeDir, "") + BR_PROPERTY(QString, model, "") + + void init() + { + QDataStream stream; + load(stream); + } void train(const TemplateList &data) { @@ -70,6 +72,9 @@ class SlidingWindowTransform : public Transform void project(const TemplateList &src, TemplateList &dst) const { + Size minObjectSize(minSize, minSize); + Size maxObjectSize; + foreach (const Template &t, src) { const bool enrollAll = t.file.getBool("enrollAll"); @@ -80,26 +85,23 @@ class SlidingWindowTransform : public Transform continue; } - for (int i = 0; i < t.size(); i++) { - Mat image; - OpenCVUtils::cvtUChar(t[i], image); - + for (int i=0; i rects; std::vector rejectLevels; std::vector levelWeights; - Size minObjectSize(minSize, minSize); - Size maxObjectSize(maxSize, maxSize); - if (maxObjectSize.height <= 0 || maxObjectSize.width <= 0) - maxObjectSize = image.size(); + if (maxObjectSize.height == 0 || maxObjectSize.width == 0) + maxObjectSize = m.size(); - Mat imageBuffer(image.rows + 1, image.cols + 1, CV_8U); + Mat imageBuffer(m.rows + 1, m.cols + 1, CV_8U); for (double factor = 1; ; factor *= scaleFactor) { - Size originalWindowSize = classifier->windowSize(); + Size originalWindowSize(24, 24); Size windowSize(cvRound(originalWindowSize.width*factor), cvRound(originalWindowSize.height*factor) ); - Size scaledImageSize(cvRound(image.cols/factor ), cvRound(image.rows/factor)); + Size scaledImageSize(cvRound(m.cols/factor ), cvRound(m.rows/factor)); Size processingRectSize(scaledImageSize.width - originalWindowSize.width, scaledImageSize.height - originalWindowSize.height); if (processingRectSize.width <= 0 || processingRectSize.height <= 0) @@ -110,22 +112,26 @@ class SlidingWindowTransform : public Transform continue; Mat scaledImage(scaledImageSize, CV_8U, imageBuffer.data); - resize(image, scaledImage, scaledImageSize, 0, 0, CV_INTER_LINEAR); + resize(m, scaledImage, scaledImageSize, 0, 0, CV_INTER_LINEAR); + + Mat repImage = classifier->preprocess(scaledImage); + + int step = factor > 2. ? 1 : 2; + for (int y = 0; y < processingRectSize.height; y += step) { + for (int x = 0; x < processingRectSize.width; x += step) { + Mat window = repImage(Rect(Point(x, y), Size(25,25))).clone(); - int yStep = factor > 2. ? 1 : 2; - for (int y = 0; y < processingRectSize.height; y += yStep) { - for (int x = 0; x < processingRectSize.width; x += yStep) { - Mat window = scaledImage(Rect(Point(x, y), classifier->windowSize())).clone(); + float gypWeight; + int result = classifier->classify(window, gypWeight); - float result = classifier->classify(window); - qDebug("result: %f", result); - if (result > 0) { + if (12 - result < 4) { rects.push_back(Rect(cvRound(x*factor), cvRound(y*factor), windowSize.width, windowSize.height)); - rejectLevels.push_back(1); - levelWeights.push_back(result); + rejectLevels.push_back(result); + levelWeights.push_back(gypWeight); } + if (result == 0) - x = yStep; + x += step; } } } @@ -133,57 +139,33 @@ class SlidingWindowTransform : public Transform groupRectangles(rects, rejectLevels, levelWeights, minNeighbors, eps); if (!enrollAll && rects.empty()) - rects.push_back(Rect(0, 0, image.cols, image.rows)); + rects.push_back(Rect(0, 0, m.cols, m.rows)); - for (size_t j = 0; j < rects.size(); j++) { - Template u(t.file, image); + for (size_t j=0; j j) u.file.set("Confidence", rejectLevels[j]*levelWeights[j]); else u.file.set("Confidence", 1); const QRectF rect = OpenCVUtils::fromRect(rects[j]); u.file.appendRect(rect); - u.file.set("Face", rect); + u.file.set(model, rect); dst.append(u); } } } - } + } void load(QDataStream &stream) { - (void) stream; + (void)stream; - QString filename = Globals->sdkPath + "/share/openbr/models/openbrcascades/" + cascadeDir + "/cascade.xml"; + QString filename = Globals->sdkPath + "/share/openbr/models/openbrcascades/" + model + "/cascade.xml"; FileStorage fs(filename.toStdString(), FileStorage::READ); if (!fs.isOpened()) return; classifier->read(fs.getFirstTopLevelNode()); - - return; - } - - void store(QDataStream &stream) const - { - (void) stream; - - QString path = Globals->sdkPath + "/share/openbr/models/openbrcascades/" + cascadeDir; - QtUtils::touchDir(QDir(path)); - - QString filename = path + "/cascade.xml"; - FileStorage fs(filename.toStdString(), FileStorage::WRITE); - - if (!fs.isOpened()) { - qWarning("Unable to open file: %s", qPrintable(filename)); - return; - } - - fs << FileStorage::getDefaultObjectName(filename.toStdString()) << "{"; - - classifier->write(fs); - - fs << "}"; } };