Commit e1439a2cc79f6a875536201350a1592bf5802f77

Authored by Scott Klum
1 parent d93507d9

Added tree.cpp with ForestTransform

Showing 1 changed file with 124 additions and 0 deletions
openbr/plugins/tree.cpp 0 → 100644
  1 +#include <opencv2/ml/ml.hpp>
  2 +
  3 +#include "openbr_internal.h"
  4 +#include "openbr/core/opencvutils.h"
  5 +#include <QString>
  6 +#include <QTemporaryFile>
  7 +
  8 +using namespace std;
  9 +using namespace cv;
  10 +
  11 +namespace br
  12 +{
  13 +
  14 +static void storeForest(const CvRTrees &forest, QDataStream &stream)
  15 +{
  16 + // Create local file
  17 + QTemporaryFile tempFile;
  18 + tempFile.open();
  19 + tempFile.close();
  20 +
  21 + // Save MLP to local file
  22 + forest.save(qPrintable(tempFile.fileName()));
  23 +
  24 + // Copy local file contents to stream
  25 + tempFile.open();
  26 + QByteArray data = tempFile.readAll();
  27 + tempFile.close();
  28 + stream << data;
  29 +}
  30 +
  31 +static void loadForest(CvRTrees &forest, QDataStream &stream)
  32 +{
  33 + // Copy local file contents from stream
  34 + QByteArray data;
  35 + stream >> data;
  36 +
  37 + // Create local file
  38 + QTemporaryFile tempFile(QDir::tempPath()+"/forest");
  39 + tempFile.open();
  40 + tempFile.write(data);
  41 + tempFile.close();
  42 +
  43 + // Load MLP from local file
  44 + forest.load(qPrintable(tempFile.fileName()));
  45 +}
  46 +
  47 +/*!
  48 + * \ingroup transforms
  49 + * \brief Wraps OpenCV's random trees framework
  50 + * \author Scott Klum \cite sklum
  51 + * \brief http://docs.opencv.org/modules/ml/doc/random_trees.html
  52 + */
  53 +class ForestTransform : public MetaTransform
  54 +{
  55 + Q_OBJECT
  56 + Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED true)
  57 + Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED true)
  58 + Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED true)
  59 + Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED true)
  60 + Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED true)
  61 + BR_PROPERTY(bool, classification, true)
  62 + BR_PROPERTY(float, splitPercentage, .01)
  63 + BR_PROPERTY(int, maxDepth, std::numeric_limits<int>::max())
  64 + BR_PROPERTY(int, maxTrees, 10)
  65 + BR_PROPERTY(float, forestAccuracy, .1)
  66 +
  67 + CvRTrees forest;
  68 +
  69 + void train(const TemplateList &data)
  70 + {
  71 + Mat samples = OpenCVUtils::toMat(data.data());
  72 + Mat labels = OpenCVUtils::toMat(File::get<float>(data, "Label"));
  73 +
  74 + Mat types = Mat(samples.cols + 1, 1, CV_8U );
  75 + types.setTo(Scalar(CV_VAR_NUMERICAL));
  76 +
  77 + if (classification) {
  78 + types.at<char>(samples.cols, 0) = CV_VAR_CATEGORICAL;
  79 + } else {
  80 + types.at<char>(samples.cols, 0) = CV_VAR_NUMERICAL;
  81 + }
  82 +
  83 + int minSamplesForSplit = data.size()*splitPercentage;
  84 + forest.train( samples, CV_ROW_SAMPLE, labels, Mat(), Mat(), types, Mat(),
  85 + CvRTParams(maxDepth,
  86 + minSamplesForSplit,
  87 + 0,
  88 + false,
  89 + 2,
  90 + 0, // priors
  91 + false,
  92 + 0,
  93 + maxTrees,
  94 + forestAccuracy,
  95 + CV_TERMCRIT_EPS));
  96 +
  97 + qDebug() << "Number of trees:" << forest.get_tree_count();
  98 + }
  99 +
  100 + void project(const Template &src, Template &dst) const
  101 + {
  102 + dst = src;
  103 +
  104 + float response = forest.predict_prob(src.m().reshape(1,1));
  105 + dst.m() = Mat(1, 1, CV_32F);
  106 + dst.m().at<float>(0, 0) = response;
  107 + }
  108 +
  109 + void load(QDataStream &stream)
  110 + {
  111 + loadForest(forest,stream);
  112 + }
  113 +
  114 + void store(QDataStream &stream) const
  115 + {
  116 + storeForest(forest,stream);
  117 + }
  118 +};
  119 +
  120 +BR_REGISTER(Transform, ForestTransform)
  121 +
  122 +} // namespace br
  123 +
  124 +#include "tree.moc"