From 373069f680c4ce3b7cf41497db2a1cb16e954bf8 Mon Sep 17 00:00:00 2001 From: Scott Klum Date: Mon, 9 Feb 2015 13:07:27 -0500 Subject: [PATCH] Refactored random forests and forest induction --- openbr/plugins/tree.cpp | 156 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------------------------------------------------------------ 1 file changed, 96 insertions(+), 60 deletions(-) diff --git a/openbr/plugins/tree.cpp b/openbr/plugins/tree.cpp index 6768ec7..3e073ac 100644 --- a/openbr/plugins/tree.cpp +++ b/openbr/plugins/tree.cpp @@ -35,7 +35,7 @@ static void loadModel(CvStatModel &model, QDataStream &stream) stream >> data; // Create local file - QTemporaryFile tempFile(QDir::tempPath()+"/model"); + QTemporaryFile tempFile(QDir::tempPath()+"/"+QString::number(rand())); tempFile.open(); tempFile.write(data); tempFile.close(); @@ -53,24 +53,6 @@ static void loadModel(CvStatModel &model, QDataStream &stream) class ForestTransform : public Transform { Q_OBJECT - Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED false) - Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED false) - Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED false) - Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED false) - Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED false) - Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED false) - Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false) - Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) - Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) - BR_PROPERTY(bool, classification, true) - BR_PROPERTY(float, splitPercentage, .01) - BR_PROPERTY(int, maxDepth, std::numeric_limits::max()) - BR_PROPERTY(int, maxTrees, 10) - BR_PROPERTY(float, forestAccuracy, .1) - BR_PROPERTY(bool, returnConfidence, true) - BR_PROPERTY(bool, overwriteMat, true) - BR_PROPERTY(QString, inputVariable, "Label") - BR_PROPERTY(QString, outputVariable, "") void train(const TemplateList &data) { @@ -114,6 +96,27 @@ class ForestTransform : public Transform } protected: + Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED false) + Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED false) + Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED false) + Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED false) + Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED false) + Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED false) + Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false) + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) + Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) + Q_PROPERTY(bool weight READ get_weight WRITE set_weight RESET reset_weight STORED false) + BR_PROPERTY(bool, classification, true) + BR_PROPERTY(float, splitPercentage, .01) + BR_PROPERTY(int, maxDepth, std::numeric_limits::max()) + BR_PROPERTY(int, maxTrees, 10) + BR_PROPERTY(float, forestAccuracy, .1) + BR_PROPERTY(bool, returnConfidence, true) + BR_PROPERTY(bool, overwriteMat, true) + BR_PROPERTY(QString, inputVariable, "Label") + BR_PROPERTY(QString, outputVariable, "") + BR_PROPERTY(bool, weight, false) + CvRTrees forest; void trainForest(const TemplateList &data) @@ -130,6 +133,15 @@ protected: types.at(samples.cols, 0) = CV_VAR_NUMERICAL; } + bool usePrior = classification && weight; + float priors[2]; + if (usePrior) { + int nonZero = countNonZero(labels); + priors[0] = 1; + priors[1] = (float)(samples.rows-nonZero)/nonZero; + qDebug() << priors[0] << priors[1] << (samples.rows-nonZero)/nonZero; + } + int minSamplesForSplit = data.size()*splitPercentage; forest.train( samples, CV_ROW_SAMPLE, labels, Mat(), Mat(), types, Mat(), CvRTParams(maxDepth, @@ -137,14 +149,37 @@ protected: 0, false, 2, - 0, + usePrior ? priors : 0, //priors false, 0, maxTrees, forestAccuracy, - CV_TERMCRIT_ITER | CV_TERMCRIT_EPS)); + CV_TERMCRIT_ITER)); + + if (Globals->verbose) { + qDebug() << "Number of trees:" << forest.get_tree_count(); + + if (classification) { + QTime timer; + timer.start(); + int correctClassification = 0; + float regressionError = 0; + for (int i=0; i(i,0)) { + correctClassification++; + } + regressionError += fabs(prediction-labels.at(i,0)); + } - qDebug() << "Number of trees:" << forest.get_tree_count(); + qDebug("Time to classify %d samples: %d ms\n \ + Classification Accuracy: %f\n \ + MAE: %f\n \ + Sample dimensionality: %d", + samples.rows,timer.elapsed(),(float)correctClassification/samples.rows,regressionError/samples.rows,samples.cols); + } + } } }; @@ -159,14 +194,14 @@ BR_REGISTER(Transform, ForestTransform) class ForestInductionTransform : public ForestTransform { Q_OBJECT + Q_PROPERTY(bool useRegressionValue READ get_useRegressionValue WRITE set_useRegressionValue RESET reset_useRegressionValue STORED false) + BR_PROPERTY(bool, useRegressionValue, false) int totalSize; QList< QList > nodes; - void train(const TemplateList &data) + void fillNodes() { - trainForest(data); - for (int i=0; i()); const CvDTreeNode* node = forest.get_tree(i)->get_root(); @@ -198,17 +233,31 @@ class ForestInductionTransform : public ForestTransform } } + void train(const TemplateList &data) + { + trainForest(data); + if (!useRegressionValue) fillNodes(); + } + void project(const Template &src, Template &dst) const { dst = src; - Mat responses = Mat::zeros(totalSize,1,CV_32F); + Mat responses; - int offset = 0; - for (int i=0; ipredict(src.m().reshape(1,1))); - responses.at(offset+index,0) = 1; - offset += nodes[i].size(); + if (useRegressionValue) { + responses = Mat::zeros(forest.get_tree_count(),1,CV_32F); + for (int i=0; i(i,0) = forest.get_tree(i)->predict(src.m().reshape(1,1))->value; + } + } else { + responses = Mat::zeros(totalSize,1,CV_32F); + int offset = 0; + for (int i=0; ipredict(src.m().reshape(1,1))); + responses.at(offset+index,0) = 1; + offset += nodes[i].size(); + } } dst.m() = responses; @@ -217,35 +266,7 @@ class ForestInductionTransform : public ForestTransform void load(QDataStream &stream) { loadModel(forest,stream); - for (int i=0; i()); - const CvDTreeNode* node = forest.get_tree(i)->get_root(); - - // traverse the tree and save all the nodes in depth-first order - for(;;) - { - CvDTreeNode* parent; - for(;;) - { - if( !node->left ) - break; - node = node->left; - } - - nodes.last().append(node); - - for( parent = node->parent; parent && parent->right == node; - node = parent, parent = parent->parent ) - ; - - if( !parent ) - break; - - node = parent->right; - } - - totalSize += nodes.last().size(); - } + if (!useRegressionValue) fillNodes(); } void store(QDataStream &stream) const @@ -309,6 +330,10 @@ private: Mat samples = OpenCVUtils::toMat(data.data()); Mat labels = OpenCVUtils::toMat(File::get(data, inputVariable)); + for (int i=0; i(i,0) != 1) labels.at(i,0) = 0; + } + Mat types = Mat(samples.cols + 1, 1, CV_8U); types.setTo(Scalar(CV_VAR_NUMERICAL)); types.at(samples.cols, 0) = CV_VAR_CATEGORICAL; @@ -323,6 +348,17 @@ private: boost.train( samples, CV_ROW_SAMPLE, labels, Mat(), Mat(), types, Mat(), params); + + QTime timer; + timer.start(); + int correct = 0; + for (int i=0; i(i,0)) + correct++; + } + + qDebug("Time to classify %d samples: %d ms\nAccuracy: %f\nSample dimensionality: %d",samples.rows,timer.elapsed(),(float)correct/samples.rows,samples.cols); } void project(const Template &src, Template &dst) const -- libgit2 0.21.4