From e1439a2cc79f6a875536201350a1592bf5802f77 Mon Sep 17 00:00:00 2001 From: Scott Klum Date: Fri, 26 Dec 2014 10:22:14 -0500 Subject: [PATCH] Added tree.cpp with ForestTransform --- openbr/plugins/tree.cpp | 124 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 124 insertions(+), 0 deletions(-) create mode 100644 openbr/plugins/tree.cpp diff --git a/openbr/plugins/tree.cpp b/openbr/plugins/tree.cpp new file mode 100644 index 0000000..fed8960 --- /dev/null +++ b/openbr/plugins/tree.cpp @@ -0,0 +1,124 @@ +#include + +#include "openbr_internal.h" +#include "openbr/core/opencvutils.h" +#include +#include + +using namespace std; +using namespace cv; + +namespace br +{ + +static void storeForest(const CvRTrees &forest, QDataStream &stream) +{ + // Create local file + QTemporaryFile tempFile; + tempFile.open(); + tempFile.close(); + + // Save MLP to local file + forest.save(qPrintable(tempFile.fileName())); + + // Copy local file contents to stream + tempFile.open(); + QByteArray data = tempFile.readAll(); + tempFile.close(); + stream << data; +} + +static void loadForest(CvRTrees &forest, QDataStream &stream) +{ + // Copy local file contents from stream + QByteArray data; + stream >> data; + + // Create local file + QTemporaryFile tempFile(QDir::tempPath()+"/forest"); + tempFile.open(); + tempFile.write(data); + tempFile.close(); + + // Load MLP from local file + forest.load(qPrintable(tempFile.fileName())); +} + +/*! + * \ingroup transforms + * \brief Wraps OpenCV's random trees framework + * \author Scott Klum \cite sklum + * \brief http://docs.opencv.org/modules/ml/doc/random_trees.html + */ +class ForestTransform : public MetaTransform +{ + Q_OBJECT + Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED true) + Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED true) + Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED true) + Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED true) + Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED true) + BR_PROPERTY(bool, classification, true) + BR_PROPERTY(float, splitPercentage, .01) + BR_PROPERTY(int, maxDepth, std::numeric_limits::max()) + BR_PROPERTY(int, maxTrees, 10) + BR_PROPERTY(float, forestAccuracy, .1) + + CvRTrees forest; + + void train(const TemplateList &data) + { + Mat samples = OpenCVUtils::toMat(data.data()); + Mat labels = OpenCVUtils::toMat(File::get(data, "Label")); + + Mat types = Mat(samples.cols + 1, 1, CV_8U ); + types.setTo(Scalar(CV_VAR_NUMERICAL)); + + if (classification) { + types.at(samples.cols, 0) = CV_VAR_CATEGORICAL; + } else { + types.at(samples.cols, 0) = CV_VAR_NUMERICAL; + } + + int minSamplesForSplit = data.size()*splitPercentage; + forest.train( samples, CV_ROW_SAMPLE, labels, Mat(), Mat(), types, Mat(), + CvRTParams(maxDepth, + minSamplesForSplit, + 0, + false, + 2, + 0, // priors + false, + 0, + maxTrees, + forestAccuracy, + CV_TERMCRIT_EPS)); + + qDebug() << "Number of trees:" << forest.get_tree_count(); + } + + void project(const Template &src, Template &dst) const + { + dst = src; + + float response = forest.predict_prob(src.m().reshape(1,1)); + dst.m() = Mat(1, 1, CV_32F); + dst.m().at(0, 0) = response; + } + + void load(QDataStream &stream) + { + loadForest(forest,stream); + } + + void store(QDataStream &stream) const + { + storeForest(forest,stream); + } +}; + +BR_REGISTER(Transform, ForestTransform) + +} // namespace br + +#include "tree.moc" -- libgit2 0.21.4