From 9b430a4f829a7e5aecd163738865cc37e6cbeb59 Mon Sep 17 00:00:00 2001 From: Scott Klum Date: Tue, 3 Feb 2015 13:47:13 -0500 Subject: [PATCH] Added forest induction --- openbr/plugins/tree.cpp | 103 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------------------------------------ 1 file changed, 67 insertions(+), 36 deletions(-) diff --git a/openbr/plugins/tree.cpp b/openbr/plugins/tree.cpp index 47875e1..6768ec7 100644 --- a/openbr/plugins/tree.cpp +++ b/openbr/plugins/tree.cpp @@ -72,11 +72,51 @@ class ForestTransform : public Transform BR_PROPERTY(QString, inputVariable, "Label") BR_PROPERTY(QString, outputVariable, "") + void train(const TemplateList &data) + { + trainForest(data); + } + + void project(const Template &src, Template &dst) const + { + dst = src; + + float response; + if (classification && returnConfidence) { + // Fuzzy class label + response = forest.predict_prob(src.m().reshape(1,1)); + } else { + response = forest.predict(src.m().reshape(1,1)); + } + + if (overwriteMat) { + dst.m() = Mat(1, 1, CV_32F); + dst.m().at(0, 0) = response; + } else { + dst.file.set(outputVariable, response); + } + } + + void load(QDataStream &stream) + { + loadModel(forest,stream); + } + + void store(QDataStream &stream) const + { + storeModel(forest,stream); + } + + void init() + { + if (outputVariable.isEmpty()) + outputVariable = inputVariable; + } + +protected: CvRTrees forest; - int totalSize; - QList< QList > nodes; - void train(const TemplateList &data) + void trainForest(const TemplateList &data) { Mat samples = OpenCVUtils::toMat(data.data()); Mat labels = OpenCVUtils::toMat(File::get(data, inputVariable)); @@ -102,9 +142,30 @@ class ForestTransform : public Transform 0, maxTrees, forestAccuracy, - CV_TERMCRIT_ITER)); + CV_TERMCRIT_ITER | CV_TERMCRIT_EPS)); qDebug() << "Number of trees:" << forest.get_tree_count(); + } +}; + +BR_REGISTER(Transform, ForestTransform) + +/*! + * \ingroup transforms + * \brief Wraps OpenCV's random trees framework to induce features + * \author Scott Klum \cite sklum + * \brief https://lirias.kuleuven.be/bitstream/123456789/316661/1/icdm11-camready.pdf + */ +class ForestInductionTransform : public ForestTransform +{ + Q_OBJECT + + int totalSize; + QList< QList > nodes; + + void train(const TemplateList &data) + { + trainForest(data); for (int i=0; i()); @@ -141,20 +202,6 @@ class ForestTransform : public Transform { dst = src; - /* - float response; - if (classification && returnConfidence) { - // Fuzzy class label - response = forest.predict_prob(src.m().reshape(1,1)); - } else { - response = forest.predict(src.m().reshape(1,1)); - }*/ - - // QTime timer; - // timer.start(); - - //qDebug() << forest.get_tree(0)->get_var_count(); - Mat responses = Mat::zeros(totalSize,1,CV_32F); int offset = 0; @@ -164,15 +211,7 @@ class ForestTransform : public Transform offset += nodes[i].size(); } - if (overwriteMat) { - dst.m() = responses; - //dst.m() = Mat(1, 1, CV_32F); - //dst.m().at(0, 0) = response; - } else { - //dst.file.set(outputVariable, response); - } - - //qDebug() << timer.elapsed(); + dst.m() = responses; } void load(QDataStream &stream) @@ -213,17 +252,9 @@ class ForestTransform : public Transform { storeModel(forest,stream); } - - void init() - { - totalSize = 0; - - if (outputVariable.isEmpty()) - outputVariable = inputVariable; - } }; -BR_REGISTER(Transform, ForestTransform) +BR_REGISTER(Transform, ForestInductionTransform) /*! * \ingroup transforms -- libgit2 0.21.4