diff --git a/openbr/plugins/tree.cpp b/openbr/plugins/tree.cpp index 9cba347..47875e1 100644 --- a/openbr/plugins/tree.cpp +++ b/openbr/plugins/tree.cpp @@ -73,6 +73,8 @@ class ForestTransform : public Transform BR_PROPERTY(QString, outputVariable, "") CvRTrees forest; + int totalSize; + QList< QList > nodes; void train(const TemplateList &data) { @@ -95,39 +97,116 @@ class ForestTransform : public Transform 0, false, 2, - 0, // priors + 0, false, 0, maxTrees, forestAccuracy, - CV_TERMCRIT_ITER | CV_TERMCRIT_EPS)); + CV_TERMCRIT_ITER)); qDebug() << "Number of trees:" << forest.get_tree_count(); + + for (int i=0; i()); + const CvDTreeNode* node = forest.get_tree(i)->get_root(); + + // traverse the tree and save all the nodes in depth-first order + for(;;) + { + CvDTreeNode* parent; + for(;;) + { + if( !node->left ) + break; + node = node->left; + } + + nodes.last().append(node); + + for( parent = node->parent; parent && parent->right == node; + node = parent, parent = parent->parent ) + ; + + if( !parent ) + break; + + node = parent->right; + } + + totalSize += nodes.last().size(); + } } void project(const Template &src, Template &dst) const { dst = src; + /* float response; if (classification && returnConfidence) { // Fuzzy class label response = forest.predict_prob(src.m().reshape(1,1)); } else { response = forest.predict(src.m().reshape(1,1)); + }*/ + + // QTime timer; + // timer.start(); + + //qDebug() << forest.get_tree(0)->get_var_count(); + + Mat responses = Mat::zeros(totalSize,1,CV_32F); + + int offset = 0; + for (int i=0; ipredict(src.m().reshape(1,1))); + responses.at(offset+index,0) = 1; + offset += nodes[i].size(); } if (overwriteMat) { - dst.m() = Mat(1, 1, CV_32F); - dst.m().at(0, 0) = response; + dst.m() = responses; + //dst.m() = Mat(1, 1, CV_32F); + //dst.m().at(0, 0) = response; } else { - dst.file.set(outputVariable, response); + //dst.file.set(outputVariable, response); } + + //qDebug() << timer.elapsed(); } void load(QDataStream &stream) { loadModel(forest,stream); + for (int i=0; i()); + const CvDTreeNode* node = forest.get_tree(i)->get_root(); + + // traverse the tree and save all the nodes in depth-first order + for(;;) + { + CvDTreeNode* parent; + for(;;) + { + if( !node->left ) + break; + node = node->left; + } + + nodes.last().append(node); + + for( parent = node->parent; parent && parent->right == node; + node = parent, parent = parent->parent ) + ; + + if( !parent ) + break; + + node = parent->right; + } + + totalSize += nodes.last().size(); + } } void store(QDataStream &stream) const @@ -137,6 +216,8 @@ class ForestTransform : public Transform void init() { + totalSize = 0; + if (outputVariable.isEmpty()) outputVariable = inputVariable; }