Commit 373069f680c4ce3b7cf41497db2a1cb16e954bf8

Authored by Scott Klum
1 parent decda7bc

Refactored random forests and forest induction

Showing 1 changed file with 96 additions and 60 deletions
openbr/plugins/tree.cpp
@@ -35,7 +35,7 @@ static void loadModel(CvStatModel &model, QDataStream &stream) @@ -35,7 +35,7 @@ static void loadModel(CvStatModel &model, QDataStream &stream)
35 stream >> data; 35 stream >> data;
36 36
37 // Create local file 37 // Create local file
38 - QTemporaryFile tempFile(QDir::tempPath()+"/model"); 38 + QTemporaryFile tempFile(QDir::tempPath()+"/"+QString::number(rand()));
39 tempFile.open(); 39 tempFile.open();
40 tempFile.write(data); 40 tempFile.write(data);
41 tempFile.close(); 41 tempFile.close();
@@ -53,24 +53,6 @@ static void loadModel(CvStatModel &model, QDataStream &stream) @@ -53,24 +53,6 @@ static void loadModel(CvStatModel &model, QDataStream &stream)
53 class ForestTransform : public Transform 53 class ForestTransform : public Transform
54 { 54 {
55 Q_OBJECT 55 Q_OBJECT
56 - Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED false)  
57 - Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED false)  
58 - Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED false)  
59 - Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED false)  
60 - Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED false)  
61 - Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED false)  
62 - Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false)  
63 - Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)  
64 - Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)  
65 - BR_PROPERTY(bool, classification, true)  
66 - BR_PROPERTY(float, splitPercentage, .01)  
67 - BR_PROPERTY(int, maxDepth, std::numeric_limits<int>::max())  
68 - BR_PROPERTY(int, maxTrees, 10)  
69 - BR_PROPERTY(float, forestAccuracy, .1)  
70 - BR_PROPERTY(bool, returnConfidence, true)  
71 - BR_PROPERTY(bool, overwriteMat, true)  
72 - BR_PROPERTY(QString, inputVariable, "Label")  
73 - BR_PROPERTY(QString, outputVariable, "")  
74 56
75 void train(const TemplateList &data) 57 void train(const TemplateList &data)
76 { 58 {
@@ -114,6 +96,27 @@ class ForestTransform : public Transform @@ -114,6 +96,27 @@ class ForestTransform : public Transform
114 } 96 }
115 97
116 protected: 98 protected:
  99 + Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED false)
  100 + Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED false)
  101 + Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED false)
  102 + Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED false)
  103 + Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED false)
  104 + Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED false)
  105 + Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false)
  106 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
  107 + Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)
  108 + Q_PROPERTY(bool weight READ get_weight WRITE set_weight RESET reset_weight STORED false)
  109 + BR_PROPERTY(bool, classification, true)
  110 + BR_PROPERTY(float, splitPercentage, .01)
  111 + BR_PROPERTY(int, maxDepth, std::numeric_limits<int>::max())
  112 + BR_PROPERTY(int, maxTrees, 10)
  113 + BR_PROPERTY(float, forestAccuracy, .1)
  114 + BR_PROPERTY(bool, returnConfidence, true)
  115 + BR_PROPERTY(bool, overwriteMat, true)
  116 + BR_PROPERTY(QString, inputVariable, "Label")
  117 + BR_PROPERTY(QString, outputVariable, "")
  118 + BR_PROPERTY(bool, weight, false)
  119 +
117 CvRTrees forest; 120 CvRTrees forest;
118 121
119 void trainForest(const TemplateList &data) 122 void trainForest(const TemplateList &data)
@@ -130,6 +133,15 @@ protected: @@ -130,6 +133,15 @@ protected:
130 types.at<char>(samples.cols, 0) = CV_VAR_NUMERICAL; 133 types.at<char>(samples.cols, 0) = CV_VAR_NUMERICAL;
131 } 134 }
132 135
  136 + bool usePrior = classification && weight;
  137 + float priors[2];
  138 + if (usePrior) {
  139 + int nonZero = countNonZero(labels);
  140 + priors[0] = 1;
  141 + priors[1] = (float)(samples.rows-nonZero)/nonZero;
  142 + qDebug() << priors[0] << priors[1] << (samples.rows-nonZero)/nonZero;
  143 + }
  144 +
133 int minSamplesForSplit = data.size()*splitPercentage; 145 int minSamplesForSplit = data.size()*splitPercentage;
134 forest.train( samples, CV_ROW_SAMPLE, labels, Mat(), Mat(), types, Mat(), 146 forest.train( samples, CV_ROW_SAMPLE, labels, Mat(), Mat(), types, Mat(),
135 CvRTParams(maxDepth, 147 CvRTParams(maxDepth,
@@ -137,14 +149,37 @@ protected: @@ -137,14 +149,37 @@ protected:
137 0, 149 0,
138 false, 150 false,
139 2, 151 2,
140 - 0, 152 + usePrior ? priors : 0, //priors
141 false, 153 false,
142 0, 154 0,
143 maxTrees, 155 maxTrees,
144 forestAccuracy, 156 forestAccuracy,
145 - CV_TERMCRIT_ITER | CV_TERMCRIT_EPS)); 157 + CV_TERMCRIT_ITER));
  158 +
  159 + if (Globals->verbose) {
  160 + qDebug() << "Number of trees:" << forest.get_tree_count();
  161 +
  162 + if (classification) {
  163 + QTime timer;
  164 + timer.start();
  165 + int correctClassification = 0;
  166 + float regressionError = 0;
  167 + for (int i=0; i<samples.rows; i++) {
  168 + float prediction = forest.predict_prob(samples.row(i));
  169 + int label = forest.predict(samples.row(i));
  170 + if (label == labels.at<float>(i,0)) {
  171 + correctClassification++;
  172 + }
  173 + regressionError += fabs(prediction-labels.at<float>(i,0));
  174 + }
146 175
147 - qDebug() << "Number of trees:" << forest.get_tree_count(); 176 + qDebug("Time to classify %d samples: %d ms\n \
  177 + Classification Accuracy: %f\n \
  178 + MAE: %f\n \
  179 + Sample dimensionality: %d",
  180 + samples.rows,timer.elapsed(),(float)correctClassification/samples.rows,regressionError/samples.rows,samples.cols);
  181 + }
  182 + }
148 } 183 }
149 }; 184 };
150 185
@@ -159,14 +194,14 @@ BR_REGISTER(Transform, ForestTransform) @@ -159,14 +194,14 @@ BR_REGISTER(Transform, ForestTransform)
159 class ForestInductionTransform : public ForestTransform 194 class ForestInductionTransform : public ForestTransform
160 { 195 {
161 Q_OBJECT 196 Q_OBJECT
  197 + Q_PROPERTY(bool useRegressionValue READ get_useRegressionValue WRITE set_useRegressionValue RESET reset_useRegressionValue STORED false)
  198 + BR_PROPERTY(bool, useRegressionValue, false)
162 199
163 int totalSize; 200 int totalSize;
164 QList< QList<const CvDTreeNode*> > nodes; 201 QList< QList<const CvDTreeNode*> > nodes;
165 202
166 - void train(const TemplateList &data) 203 + void fillNodes()
167 { 204 {
168 - trainForest(data);  
169 -  
170 for (int i=0; i<forest.get_tree_count(); i++) { 205 for (int i=0; i<forest.get_tree_count(); i++) {
171 nodes.append(QList<const CvDTreeNode*>()); 206 nodes.append(QList<const CvDTreeNode*>());
172 const CvDTreeNode* node = forest.get_tree(i)->get_root(); 207 const CvDTreeNode* node = forest.get_tree(i)->get_root();
@@ -198,17 +233,31 @@ class ForestInductionTransform : public ForestTransform @@ -198,17 +233,31 @@ class ForestInductionTransform : public ForestTransform
198 } 233 }
199 } 234 }
200 235
  236 + void train(const TemplateList &data)
  237 + {
  238 + trainForest(data);
  239 + if (!useRegressionValue) fillNodes();
  240 + }
  241 +
201 void project(const Template &src, Template &dst) const 242 void project(const Template &src, Template &dst) const
202 { 243 {
203 dst = src; 244 dst = src;
204 245
205 - Mat responses = Mat::zeros(totalSize,1,CV_32F); 246 + Mat responses;
206 247
207 - int offset = 0;  
208 - for (int i=0; i<nodes.size(); i++) {  
209 - int index = nodes[i].indexOf(forest.get_tree(i)->predict(src.m().reshape(1,1)));  
210 - responses.at<float>(offset+index,0) = 1;  
211 - offset += nodes[i].size(); 248 + if (useRegressionValue) {
  249 + responses = Mat::zeros(forest.get_tree_count(),1,CV_32F);
  250 + for (int i=0; i<forest.get_tree_count(); i++) {
  251 + responses.at<float>(i,0) = forest.get_tree(i)->predict(src.m().reshape(1,1))->value;
  252 + }
  253 + } else {
  254 + responses = Mat::zeros(totalSize,1,CV_32F);
  255 + int offset = 0;
  256 + for (int i=0; i<nodes.size(); i++) {
  257 + int index = nodes[i].indexOf(forest.get_tree(i)->predict(src.m().reshape(1,1)));
  258 + responses.at<float>(offset+index,0) = 1;
  259 + offset += nodes[i].size();
  260 + }
212 } 261 }
213 262
214 dst.m() = responses; 263 dst.m() = responses;
@@ -217,35 +266,7 @@ class ForestInductionTransform : public ForestTransform @@ -217,35 +266,7 @@ class ForestInductionTransform : public ForestTransform
217 void load(QDataStream &stream) 266 void load(QDataStream &stream)
218 { 267 {
219 loadModel(forest,stream); 268 loadModel(forest,stream);
220 - for (int i=0; i<forest.get_tree_count(); i++) {  
221 - nodes.append(QList<const CvDTreeNode*>());  
222 - const CvDTreeNode* node = forest.get_tree(i)->get_root();  
223 -  
224 - // traverse the tree and save all the nodes in depth-first order  
225 - for(;;)  
226 - {  
227 - CvDTreeNode* parent;  
228 - for(;;)  
229 - {  
230 - if( !node->left )  
231 - break;  
232 - node = node->left;  
233 - }  
234 -  
235 - nodes.last().append(node);  
236 -  
237 - for( parent = node->parent; parent && parent->right == node;  
238 - node = parent, parent = parent->parent )  
239 - ;  
240 -  
241 - if( !parent )  
242 - break;  
243 -  
244 - node = parent->right;  
245 - }  
246 -  
247 - totalSize += nodes.last().size();  
248 - } 269 + if (!useRegressionValue) fillNodes();
249 } 270 }
250 271
251 void store(QDataStream &stream) const 272 void store(QDataStream &stream) const
@@ -309,6 +330,10 @@ private: @@ -309,6 +330,10 @@ private:
309 Mat samples = OpenCVUtils::toMat(data.data()); 330 Mat samples = OpenCVUtils::toMat(data.data());
310 Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable)); 331 Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable));
311 332
  333 + for (int i=0; i<labels.rows; i++) {
  334 + if (labels.at<float>(i,0) != 1) labels.at<float>(i,0) = 0;
  335 + }
  336 +
312 Mat types = Mat(samples.cols + 1, 1, CV_8U); 337 Mat types = Mat(samples.cols + 1, 1, CV_8U);
313 types.setTo(Scalar(CV_VAR_NUMERICAL)); 338 types.setTo(Scalar(CV_VAR_NUMERICAL));
314 types.at<char>(samples.cols, 0) = CV_VAR_CATEGORICAL; 339 types.at<char>(samples.cols, 0) = CV_VAR_CATEGORICAL;
@@ -323,6 +348,17 @@ private: @@ -323,6 +348,17 @@ private:
323 348
324 boost.train( samples, CV_ROW_SAMPLE, labels, Mat(), Mat(), types, Mat(), 349 boost.train( samples, CV_ROW_SAMPLE, labels, Mat(), Mat(), types, Mat(),
325 params); 350 params);
  351 +
  352 + QTime timer;
  353 + timer.start();
  354 + int correct = 0;
  355 + for (int i=0; i<samples.rows; i++) {
  356 + float prediction = boost.predict(samples.row(i));
  357 + if (prediction == labels.at<float>(i,0))
  358 + correct++;
  359 + }
  360 +
  361 + qDebug("Time to classify %d samples: %d ms\nAccuracy: %f\nSample dimensionality: %d",samples.rows,timer.elapsed(),(float)correct/samples.rows,samples.cols);
326 } 362 }
327 363
328 void project(const Template &src, Template &dst) const 364 void project(const Template &src, Template &dst) const