Commit baa5c5eda65c9af74b2ffec15cbe3bb6f6f668e4

Authored by Jordan Cheney
1 parent 39f5a23e

Working frontend cascade with some (very) rough edges

openbr/openbr_plugin.h
... ... @@ -1423,11 +1423,10 @@ public:
1423 1423 static Classifier *make(QString str, QObject *parent); /*!< \brief Make a classifier from a string. */
1424 1424  
1425 1425 virtual void train(const QList<cv::Mat> &images, const QList<float> &labels) = 0;
1426   - // By convention, classify should return a value normalized such that the threshold is 0. Negative values
1427   - // can be interpreted as a negative classification and positive values as a positive classification.
1428   - virtual float classify(const cv::Mat &image) const = 0;
  1426 + virtual float classify(const cv::Mat &image, float &confidence) const = 0;
1429 1427  
1430 1428 // Slots for representation
  1429 + virtual cv::Mat preprocess(const cv::Mat &image) const = 0;
1431 1430 virtual cv::Size windowSize() const = 0;
1432 1431  
1433 1432 // OpenCV compatibility
... ...
openbr/plugins/classification/boostedforest.cpp
... ... @@ -129,12 +129,9 @@ class BoostedForestClassifier : public Classifier
129 129 }
130 130 }
131 131  
132   - float classify(const Mat &_image) const
  132 + float classify(const Mat &image, float &confidence) const
133 133 {
134   - Mat image;
135   - representation->preprocess(_image, image);
136   -
137   - float sum = 0;
  134 + confidence = 0;
138 135 for (int i = 0; i < classifiers.size(); i++) {
139 136 Node *node = classifiers[i];
140 137  
... ... @@ -147,11 +144,10 @@ class BoostedForestClassifier : public Classifier
147 144 node = val <= node->threshold ? node->left : node->right;
148 145 }
149 146 }
150   - qDebug("value: %f", node->value);
151   - sum += node->value;
  147 + confidence += node->value;
152 148 }
153 149  
154   - return sum < threshold - THRESHOLD_EPS ? -std::abs(sum) : std::abs(sum);
  150 + return confidence < threshold - THRESHOLD_EPS ? 0.0f : 1.0f;
155 151 }
156 152  
157 153 int numFeatures() const
... ... @@ -164,6 +160,13 @@ class BoostedForestClassifier : public Classifier
164 160 return representation->maxCatCount();
165 161 }
166 162  
  163 + Mat preprocess(const Mat &image) const
  164 + {
  165 + Mat dst;
  166 + representation->preprocess(image, dst);
  167 + return dst;
  168 + }
  169 +
167 170 Size windowSize() const
168 171 {
169 172 return representation->preWindowSize();
... ...
openbr/plugins/classification/cascade.cpp
... ... @@ -152,18 +152,23 @@ class CascadeClassifier : public Classifier
152 152 }
153 153 }
154 154  
155   - float classify(const Mat &image) const
  155 + float classify(const Mat &image, float &confidence) const
156 156 {
157   - if (stages.empty())
158   - return 1.0f;
  157 + if (stages.empty()) {
  158 + confidence = 0.0f;
  159 + return 1.0f;
  160 + }
159 161  
160   - float val = 0.0f;
161 162 for (int i = 0; i < stages.size(); i++) {
162   - val = stages[i]->classify(image);
163   - if (val < 0.0f)
164   - return stages.size() - i < 4 ? i * val : 0.0f;
  163 + float result = stages[i]->classify(image, confidence);
  164 + if (result == 0.0f) {
  165 + //confidence *= i;
  166 + return i;
  167 + }
165 168 }
166   - return stages.size() * val;
  169 +
  170 + //confidence *= stages.size();
  171 + return stages.size();
167 172 }
168 173  
169 174 int numFeatures() const
... ... @@ -176,7 +181,12 @@ class CascadeClassifier : public Classifier
176 181 return stages.first()->maxCatCount();
177 182 }
178 183  
179   - cv::Size windowSize() const
  184 + Mat preprocess(const Mat &image) const
  185 + {
  186 + return stages.first()->preprocess(image);
  187 + }
  188 +
  189 + Size windowSize() const
180 190 {
181 191 return stages.first()->windowSize();
182 192 }
... ... @@ -207,12 +217,14 @@ private:
207 217 {
208 218 imgHandler.restart();
209 219  
  220 + float confidence = 0.0f; // not used;
  221 +
210 222 while (images.size() < numPos) {
211 223 Mat pos(imgHandler.winSize, CV_8UC1);
212 224 if (!imgHandler.getPos(pos))
213 225 qFatal("Cannot get another positive sample!");
214 226  
215   - if (classify(pos) > 0.0f) {
  227 + if (classify(pos, confidence) > 0.0f) {
216 228 printf("POS current samples: %d\r", images.size());
217 229 images.append(pos);
218 230 labels.append(1.0f);
... ... @@ -228,7 +240,7 @@ private:
228 240 if (!imgHandler.getNeg(neg))
229 241 qFatal("Cannot get another negative sample!");
230 242  
231   - if (classify(neg) > 0.0f) {
  243 + if (classify(neg, confidence) > 0.0f) {
232 244 printf("NEG current samples: %d\r", images.size() - posCount);
233 245 images.append(neg);
234 246 labels.append(0.0f);
... ...
openbr/plugins/imgproc/slidingwindow.cpp
... ... @@ -14,15 +14,10 @@
14 14 * limitations under the License. *
15 15 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
16 16  
17   -#include <fstream>
18   -
19 17 #include <openbr/plugins/openbr_internal.h>
  18 +#include <openbr/core/cascade.h>
20 19 #include <openbr/core/opencvutils.h>
21 20 #include <openbr/core/qtutils.h>
22   -#include <openbr/core/cascade.h>
23   -
24   -#include <opencv2/highgui/highgui.hpp>
25   -#include <opencv2/imgproc/imgproc.hpp>
26 21  
27 22 using namespace cv;
28 23  
... ... @@ -31,22 +26,23 @@ namespace br
31 26  
32 27 /*!
33 28 * \ingroup transforms
34   - * \brief Applies a classifier to a sliding window.
35   - * \author Jordan Cheney \cite JordanCheney
  29 + * \brief Sliding Window Framework
  30 + * \author Jordan Cheney
36 31 */
37   -
38   -class SlidingWindowTransform : public Transform
  32 +class SlidingWindowTransform : public UntrainableMetaTransform
39 33 {
40 34 Q_OBJECT
41 35  
42 36 Q_PROPERTY(br::Classifier *classifier READ get_classifier WRITE set_classifier RESET reset_classifier STORED false)
  37 +
43 38 Q_PROPERTY(int minSize READ get_minSize WRITE set_minSize RESET reset_minSize STORED false)
44 39 Q_PROPERTY(int maxSize READ get_maxSize WRITE set_maxSize RESET reset_maxSize STORED false)
45 40 Q_PROPERTY(float scaleFactor READ get_scaleFactor WRITE set_scaleFactor RESET reset_scaleFactor STORED false)
46 41 Q_PROPERTY(int minNeighbors READ get_minNeighbors WRITE set_minNeighbors RESET reset_minNeighbors STORED false)
47 42 Q_PROPERTY(float eps READ get_eps WRITE set_eps RESET reset_eps STORED false)
48 43  
49   - Q_PROPERTY(QString cascadeDir READ get_cascadeDir WRITE set_cascadeDir RESET reset_cascadeDir STORED false)
  44 + Q_PROPERTY(QString model READ get_model WRITE set_model RESET reset_model STORED false)
  45 +
50 46 BR_PROPERTY(br::Classifier *, classifier, NULL)
51 47 BR_PROPERTY(int, minSize, 20)
52 48 BR_PROPERTY(int, maxSize, -1)
... ... @@ -54,7 +50,13 @@ class SlidingWindowTransform : public Transform
54 50 BR_PROPERTY(int, minNeighbors, 5)
55 51 BR_PROPERTY(float, eps, 0.2)
56 52  
57   - BR_PROPERTY(QString, cascadeDir, "")
  53 + BR_PROPERTY(QString, model, "")
  54 +
  55 + void init()
  56 + {
  57 + QDataStream stream;
  58 + load(stream);
  59 + }
58 60  
59 61 void train(const TemplateList &data)
60 62 {
... ... @@ -70,6 +72,9 @@ class SlidingWindowTransform : public Transform
70 72  
71 73 void project(const TemplateList &src, TemplateList &dst) const
72 74 {
  75 + Size minObjectSize(minSize, minSize);
  76 + Size maxObjectSize;
  77 +
73 78 foreach (const Template &t, src) {
74 79 const bool enrollAll = t.file.getBool("enrollAll");
75 80  
... ... @@ -80,26 +85,23 @@ class SlidingWindowTransform : public Transform
80 85 continue;
81 86 }
82 87  
83   - for (int i = 0; i < t.size(); i++) {
84   - Mat image;
85   - OpenCVUtils::cvtUChar(t[i], image);
86   -
  88 + for (int i=0; i<t.size(); i++) {
  89 + Mat m;
  90 + OpenCVUtils::cvtUChar(t[i], m);
87 91 std::vector<Rect> rects;
88 92 std::vector<int> rejectLevels;
89 93 std::vector<double> levelWeights;
90 94  
91   - Size minObjectSize(minSize, minSize);
92   - Size maxObjectSize(maxSize, maxSize);
93   - if (maxObjectSize.height <= 0 || maxObjectSize.width <= 0)
94   - maxObjectSize = image.size();
  95 + if (maxObjectSize.height == 0 || maxObjectSize.width == 0)
  96 + maxObjectSize = m.size();
95 97  
96   - Mat imageBuffer(image.rows + 1, image.cols + 1, CV_8U);
  98 + Mat imageBuffer(m.rows + 1, m.cols + 1, CV_8U);
97 99  
98 100 for (double factor = 1; ; factor *= scaleFactor) {
99   - Size originalWindowSize = classifier->windowSize();
  101 + Size originalWindowSize(24, 24);
100 102  
101 103 Size windowSize(cvRound(originalWindowSize.width*factor), cvRound(originalWindowSize.height*factor) );
102   - Size scaledImageSize(cvRound(image.cols/factor ), cvRound(image.rows/factor));
  104 + Size scaledImageSize(cvRound(m.cols/factor ), cvRound(m.rows/factor));
103 105 Size processingRectSize(scaledImageSize.width - originalWindowSize.width, scaledImageSize.height - originalWindowSize.height);
104 106  
105 107 if (processingRectSize.width <= 0 || processingRectSize.height <= 0)
... ... @@ -110,22 +112,26 @@ class SlidingWindowTransform : public Transform
110 112 continue;
111 113  
112 114 Mat scaledImage(scaledImageSize, CV_8U, imageBuffer.data);
113   - resize(image, scaledImage, scaledImageSize, 0, 0, CV_INTER_LINEAR);
  115 + resize(m, scaledImage, scaledImageSize, 0, 0, CV_INTER_LINEAR);
  116 +
  117 + Mat repImage = classifier->preprocess(scaledImage);
  118 +
  119 + int step = factor > 2. ? 1 : 2;
  120 + for (int y = 0; y < processingRectSize.height; y += step) {
  121 + for (int x = 0; x < processingRectSize.width; x += step) {
  122 + Mat window = repImage(Rect(Point(x, y), Size(25,25))).clone();
114 123  
115   - int yStep = factor > 2. ? 1 : 2;
116   - for (int y = 0; y < processingRectSize.height; y += yStep) {
117   - for (int x = 0; x < processingRectSize.width; x += yStep) {
118   - Mat window = scaledImage(Rect(Point(x, y), classifier->windowSize())).clone();
  124 + float gypWeight;
  125 + int result = classifier->classify(window, gypWeight);
119 126  
120   - float result = classifier->classify(window);
121   - qDebug("result: %f", result);
122   - if (result > 0) {
  127 + if (12 - result < 4) {
123 128 rects.push_back(Rect(cvRound(x*factor), cvRound(y*factor), windowSize.width, windowSize.height));
124   - rejectLevels.push_back(1);
125   - levelWeights.push_back(result);
  129 + rejectLevels.push_back(result);
  130 + levelWeights.push_back(gypWeight);
126 131 }
  132 +
127 133 if (result == 0)
128   - x = yStep;
  134 + x += step;
129 135 }
130 136 }
131 137 }
... ... @@ -133,57 +139,33 @@ class SlidingWindowTransform : public Transform
133 139 groupRectangles(rects, rejectLevels, levelWeights, minNeighbors, eps);
134 140  
135 141 if (!enrollAll && rects.empty())
136   - rects.push_back(Rect(0, 0, image.cols, image.rows));
  142 + rects.push_back(Rect(0, 0, m.cols, m.rows));
137 143  
138   - for (size_t j = 0; j < rects.size(); j++) {
139   - Template u(t.file, image);
  144 + for (size_t j=0; j<rects.size(); j++) {
  145 + Template u(t.file, m);
140 146 if (rejectLevels.size() > j)
141 147 u.file.set("Confidence", rejectLevels[j]*levelWeights[j]);
142 148 else
143 149 u.file.set("Confidence", 1);
144 150 const QRectF rect = OpenCVUtils::fromRect(rects[j]);
145 151 u.file.appendRect(rect);
146   - u.file.set("Face", rect);
  152 + u.file.set(model, rect);
147 153 dst.append(u);
148 154 }
149 155 }
150 156 }
151   - }
  157 + }
152 158  
153 159 void load(QDataStream &stream)
154 160 {
155   - (void) stream;
  161 + (void)stream;
156 162  
157   - QString filename = Globals->sdkPath + "/share/openbr/models/openbrcascades/" + cascadeDir + "/cascade.xml";
  163 + QString filename = Globals->sdkPath + "/share/openbr/models/openbrcascades/" + model + "/cascade.xml";
158 164 FileStorage fs(filename.toStdString(), FileStorage::READ);
159 165 if (!fs.isOpened())
160 166 return;
161 167  
162 168 classifier->read(fs.getFirstTopLevelNode());
163   -
164   - return;
165   - }
166   -
167   - void store(QDataStream &stream) const
168   - {
169   - (void) stream;
170   -
171   - QString path = Globals->sdkPath + "/share/openbr/models/openbrcascades/" + cascadeDir;
172   - QtUtils::touchDir(QDir(path));
173   -
174   - QString filename = path + "/cascade.xml";
175   - FileStorage fs(filename.toStdString(), FileStorage::WRITE);
176   -
177   - if (!fs.isOpened()) {
178   - qWarning("Unable to open file: %s", qPrintable(filename));
179   - return;
180   - }
181   -
182   - fs << FileStorage::getDefaultObjectName(filename.toStdString()) << "{";
183   -
184   - classifier->write(fs);
185   -
186   - fs << "}";
187 169 }
188 170 };
189 171  
... ...