Commit 373069f680c4ce3b7cf41497db2a1cb16e954bf8

Authored by Scott Klum
1 parent decda7bc

Refactored random forests and forest induction

Showing 1 changed file with 96 additions and 60 deletions
openbr/plugins/tree.cpp
... ... @@ -35,7 +35,7 @@ static void loadModel(CvStatModel &model, QDataStream &stream)
35 35 stream >> data;
36 36  
37 37 // Create local file
38   - QTemporaryFile tempFile(QDir::tempPath()+"/model");
  38 + QTemporaryFile tempFile(QDir::tempPath()+"/"+QString::number(rand()));
39 39 tempFile.open();
40 40 tempFile.write(data);
41 41 tempFile.close();
... ... @@ -53,24 +53,6 @@ static void loadModel(CvStatModel &model, QDataStream &stream)
53 53 class ForestTransform : public Transform
54 54 {
55 55 Q_OBJECT
56   - Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED false)
57   - Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED false)
58   - Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED false)
59   - Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED false)
60   - Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED false)
61   - Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED false)
62   - Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false)
63   - Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
64   - Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)
65   - BR_PROPERTY(bool, classification, true)
66   - BR_PROPERTY(float, splitPercentage, .01)
67   - BR_PROPERTY(int, maxDepth, std::numeric_limits<int>::max())
68   - BR_PROPERTY(int, maxTrees, 10)
69   - BR_PROPERTY(float, forestAccuracy, .1)
70   - BR_PROPERTY(bool, returnConfidence, true)
71   - BR_PROPERTY(bool, overwriteMat, true)
72   - BR_PROPERTY(QString, inputVariable, "Label")
73   - BR_PROPERTY(QString, outputVariable, "")
74 56  
75 57 void train(const TemplateList &data)
76 58 {
... ... @@ -114,6 +96,27 @@ class ForestTransform : public Transform
114 96 }
115 97  
116 98 protected:
  99 + Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED false)
  100 + Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED false)
  101 + Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED false)
  102 + Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED false)
  103 + Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED false)
  104 + Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED false)
  105 + Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false)
  106 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
  107 + Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)
  108 + Q_PROPERTY(bool weight READ get_weight WRITE set_weight RESET reset_weight STORED false)
  109 + BR_PROPERTY(bool, classification, true)
  110 + BR_PROPERTY(float, splitPercentage, .01)
  111 + BR_PROPERTY(int, maxDepth, std::numeric_limits<int>::max())
  112 + BR_PROPERTY(int, maxTrees, 10)
  113 + BR_PROPERTY(float, forestAccuracy, .1)
  114 + BR_PROPERTY(bool, returnConfidence, true)
  115 + BR_PROPERTY(bool, overwriteMat, true)
  116 + BR_PROPERTY(QString, inputVariable, "Label")
  117 + BR_PROPERTY(QString, outputVariable, "")
  118 + BR_PROPERTY(bool, weight, false)
  119 +
117 120 CvRTrees forest;
118 121  
119 122 void trainForest(const TemplateList &data)
... ... @@ -130,6 +133,15 @@ protected:
130 133 types.at<char>(samples.cols, 0) = CV_VAR_NUMERICAL;
131 134 }
132 135  
  136 + bool usePrior = classification && weight;
  137 + float priors[2];
  138 + if (usePrior) {
  139 + int nonZero = countNonZero(labels);
  140 + priors[0] = 1;
  141 + priors[1] = (float)(samples.rows-nonZero)/nonZero;
  142 + qDebug() << priors[0] << priors[1] << (samples.rows-nonZero)/nonZero;
  143 + }
  144 +
133 145 int minSamplesForSplit = data.size()*splitPercentage;
134 146 forest.train( samples, CV_ROW_SAMPLE, labels, Mat(), Mat(), types, Mat(),
135 147 CvRTParams(maxDepth,
... ... @@ -137,14 +149,37 @@ protected:
137 149 0,
138 150 false,
139 151 2,
140   - 0,
  152 + usePrior ? priors : 0, //priors
141 153 false,
142 154 0,
143 155 maxTrees,
144 156 forestAccuracy,
145   - CV_TERMCRIT_ITER | CV_TERMCRIT_EPS));
  157 + CV_TERMCRIT_ITER));
  158 +
  159 + if (Globals->verbose) {
  160 + qDebug() << "Number of trees:" << forest.get_tree_count();
  161 +
  162 + if (classification) {
  163 + QTime timer;
  164 + timer.start();
  165 + int correctClassification = 0;
  166 + float regressionError = 0;
  167 + for (int i=0; i<samples.rows; i++) {
  168 + float prediction = forest.predict_prob(samples.row(i));
  169 + int label = forest.predict(samples.row(i));
  170 + if (label == labels.at<float>(i,0)) {
  171 + correctClassification++;
  172 + }
  173 + regressionError += fabs(prediction-labels.at<float>(i,0));
  174 + }
146 175  
147   - qDebug() << "Number of trees:" << forest.get_tree_count();
  176 + qDebug("Time to classify %d samples: %d ms\n \
  177 + Classification Accuracy: %f\n \
  178 + MAE: %f\n \
  179 + Sample dimensionality: %d",
  180 + samples.rows,timer.elapsed(),(float)correctClassification/samples.rows,regressionError/samples.rows,samples.cols);
  181 + }
  182 + }
148 183 }
149 184 };
150 185  
... ... @@ -159,14 +194,14 @@ BR_REGISTER(Transform, ForestTransform)
159 194 class ForestInductionTransform : public ForestTransform
160 195 {
161 196 Q_OBJECT
  197 + Q_PROPERTY(bool useRegressionValue READ get_useRegressionValue WRITE set_useRegressionValue RESET reset_useRegressionValue STORED false)
  198 + BR_PROPERTY(bool, useRegressionValue, false)
162 199  
163 200 int totalSize;
164 201 QList< QList<const CvDTreeNode*> > nodes;
165 202  
166   - void train(const TemplateList &data)
  203 + void fillNodes()
167 204 {
168   - trainForest(data);
169   -
170 205 for (int i=0; i<forest.get_tree_count(); i++) {
171 206 nodes.append(QList<const CvDTreeNode*>());
172 207 const CvDTreeNode* node = forest.get_tree(i)->get_root();
... ... @@ -198,17 +233,31 @@ class ForestInductionTransform : public ForestTransform
198 233 }
199 234 }
200 235  
  236 + void train(const TemplateList &data)
  237 + {
  238 + trainForest(data);
  239 + if (!useRegressionValue) fillNodes();
  240 + }
  241 +
201 242 void project(const Template &src, Template &dst) const
202 243 {
203 244 dst = src;
204 245  
205   - Mat responses = Mat::zeros(totalSize,1,CV_32F);
  246 + Mat responses;
206 247  
207   - int offset = 0;
208   - for (int i=0; i<nodes.size(); i++) {
209   - int index = nodes[i].indexOf(forest.get_tree(i)->predict(src.m().reshape(1,1)));
210   - responses.at<float>(offset+index,0) = 1;
211   - offset += nodes[i].size();
  248 + if (useRegressionValue) {
  249 + responses = Mat::zeros(forest.get_tree_count(),1,CV_32F);
  250 + for (int i=0; i<forest.get_tree_count(); i++) {
  251 + responses.at<float>(i,0) = forest.get_tree(i)->predict(src.m().reshape(1,1))->value;
  252 + }
  253 + } else {
  254 + responses = Mat::zeros(totalSize,1,CV_32F);
  255 + int offset = 0;
  256 + for (int i=0; i<nodes.size(); i++) {
  257 + int index = nodes[i].indexOf(forest.get_tree(i)->predict(src.m().reshape(1,1)));
  258 + responses.at<float>(offset+index,0) = 1;
  259 + offset += nodes[i].size();
  260 + }
212 261 }
213 262  
214 263 dst.m() = responses;
... ... @@ -217,35 +266,7 @@ class ForestInductionTransform : public ForestTransform
217 266 void load(QDataStream &stream)
218 267 {
219 268 loadModel(forest,stream);
220   - for (int i=0; i<forest.get_tree_count(); i++) {
221   - nodes.append(QList<const CvDTreeNode*>());
222   - const CvDTreeNode* node = forest.get_tree(i)->get_root();
223   -
224   - // traverse the tree and save all the nodes in depth-first order
225   - for(;;)
226   - {
227   - CvDTreeNode* parent;
228   - for(;;)
229   - {
230   - if( !node->left )
231   - break;
232   - node = node->left;
233   - }
234   -
235   - nodes.last().append(node);
236   -
237   - for( parent = node->parent; parent && parent->right == node;
238   - node = parent, parent = parent->parent )
239   - ;
240   -
241   - if( !parent )
242   - break;
243   -
244   - node = parent->right;
245   - }
246   -
247   - totalSize += nodes.last().size();
248   - }
  269 + if (!useRegressionValue) fillNodes();
249 270 }
250 271  
251 272 void store(QDataStream &stream) const
... ... @@ -309,6 +330,10 @@ private:
309 330 Mat samples = OpenCVUtils::toMat(data.data());
310 331 Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable));
311 332  
  333 + for (int i=0; i<labels.rows; i++) {
  334 + if (labels.at<float>(i,0) != 1) labels.at<float>(i,0) = 0;
  335 + }
  336 +
312 337 Mat types = Mat(samples.cols + 1, 1, CV_8U);
313 338 types.setTo(Scalar(CV_VAR_NUMERICAL));
314 339 types.at<char>(samples.cols, 0) = CV_VAR_CATEGORICAL;
... ... @@ -323,6 +348,17 @@ private:
323 348  
324 349 boost.train( samples, CV_ROW_SAMPLE, labels, Mat(), Mat(), types, Mat(),
325 350 params);
  351 +
  352 + QTime timer;
  353 + timer.start();
  354 + int correct = 0;
  355 + for (int i=0; i<samples.rows; i++) {
  356 + float prediction = boost.predict(samples.row(i));
  357 + if (prediction == labels.at<float>(i,0))
  358 + correct++;
  359 + }
  360 +
  361 + qDebug("Time to classify %d samples: %d ms\nAccuracy: %f\nSample dimensionality: %d",samples.rows,timer.elapsed(),(float)correct/samples.rows,samples.cols);
326 362 }
327 363  
328 364 void project(const Template &src, Template &dst) const
... ...