Commit cb60809425f4390709e750f0d715c479d51a266d

Authored by Scott Klum
2 parents d93507d9 c1648735

Merge pull request #296 from biometrics/forest

Forest
Showing 1 changed file with 149 additions and 0 deletions
openbr/plugins/tree.cpp 0 → 100644
  1 +#include <opencv2/ml/ml.hpp>
  2 +
  3 +#include "openbr_internal.h"
  4 +#include "openbr/core/opencvutils.h"
  5 +#include <QString>
  6 +#include <QTemporaryFile>
  7 +
  8 +using namespace std;
  9 +using namespace cv;
  10 +
  11 +namespace br
  12 +{
  13 +
  14 +static void storeForest(const CvRTrees &forest, QDataStream &stream)
  15 +{
  16 + // Create local file
  17 + QTemporaryFile tempFile;
  18 + tempFile.open();
  19 + tempFile.close();
  20 +
  21 + // Save MLP to local file
  22 + forest.save(qPrintable(tempFile.fileName()));
  23 +
  24 + // Copy local file contents to stream
  25 + tempFile.open();
  26 + QByteArray data = tempFile.readAll();
  27 + tempFile.close();
  28 + stream << data;
  29 +}
  30 +
  31 +static void loadForest(CvRTrees &forest, QDataStream &stream)
  32 +{
  33 + // Copy local file contents from stream
  34 + QByteArray data;
  35 + stream >> data;
  36 +
  37 + // Create local file
  38 + QTemporaryFile tempFile(QDir::tempPath()+"/forest");
  39 + tempFile.open();
  40 + tempFile.write(data);
  41 + tempFile.close();
  42 +
  43 + // Load MLP from local file
  44 + forest.load(qPrintable(tempFile.fileName()));
  45 +}
  46 +
  47 +/*!
  48 + * \ingroup transforms
  49 + * \brief Wraps OpenCV's random trees framework
  50 + * \author Scott Klum \cite sklum
  51 + * \brief http://docs.opencv.org/modules/ml/doc/random_trees.html
  52 + */
  53 +class ForestTransform : public MetaTransform
  54 +{
  55 + Q_OBJECT
  56 + Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED true)
  57 + Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED true)
  58 + Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED true)
  59 + Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED true)
  60 + Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED true)
  61 + Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED true)
  62 + Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED true)
  63 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
  64 + Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)
  65 + BR_PROPERTY(bool, classification, true)
  66 + BR_PROPERTY(float, splitPercentage, .01)
  67 + BR_PROPERTY(int, maxDepth, std::numeric_limits<int>::max())
  68 + BR_PROPERTY(int, maxTrees, 10)
  69 + BR_PROPERTY(float, forestAccuracy, .1)
  70 + BR_PROPERTY(bool, returnConfidence, true)
  71 + BR_PROPERTY(bool, overwriteMat, true)
  72 + BR_PROPERTY(QString, inputVariable, "Label")
  73 + BR_PROPERTY(QString, outputVariable, "")
  74 +
  75 + CvRTrees forest;
  76 +
  77 + void train(const TemplateList &data)
  78 + {
  79 + Mat samples = OpenCVUtils::toMat(data.data());
  80 + Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable));
  81 +
  82 + Mat types = Mat(samples.cols + 1, 1, CV_8U);
  83 + types.setTo(Scalar(CV_VAR_NUMERICAL));
  84 +
  85 + if (classification) {
  86 + types.at<char>(samples.cols, 0) = CV_VAR_CATEGORICAL;
  87 + } else {
  88 + types.at<char>(samples.cols, 0) = CV_VAR_NUMERICAL;
  89 + }
  90 +
  91 + int minSamplesForSplit = data.size()*splitPercentage;
  92 + forest.train( samples, CV_ROW_SAMPLE, labels, Mat(), Mat(), types, Mat(),
  93 + CvRTParams(maxDepth,
  94 + minSamplesForSplit,
  95 + 0,
  96 + false,
  97 + 2,
  98 + 0, // priors
  99 + false,
  100 + 0,
  101 + maxTrees,
  102 + forestAccuracy,
  103 + CV_TERMCRIT_EPS));
  104 +
  105 + qDebug() << "Number of trees:" << forest.get_tree_count();
  106 + }
  107 +
  108 + void project(const Template &src, Template &dst) const
  109 + {
  110 + dst = src;
  111 +
  112 + float response;
  113 + if (classification && returnConfidence) {
  114 + // Fuzzy class label
  115 + response = forest.predict_prob(src.m().reshape(1,1));
  116 + } else {
  117 + response = forest.predict(src.m().reshape(1,1));
  118 + }
  119 +
  120 + if (overwriteMat) {
  121 + dst.m() = Mat(1, 1, CV_32F);
  122 + dst.m().at<float>(0, 0) = response;
  123 + } else {
  124 + dst.file.set(outputVariable, response);
  125 + }
  126 + }
  127 +
  128 + void load(QDataStream &stream)
  129 + {
  130 + loadForest(forest,stream);
  131 + }
  132 +
  133 + void store(QDataStream &stream) const
  134 + {
  135 + storeForest(forest,stream);
  136 + }
  137 +
  138 + void init()
  139 + {
  140 + if (outputVariable.isEmpty())
  141 + outputVariable = inputVariable;
  142 + }
  143 +};
  144 +
  145 +BR_REGISTER(Transform, ForestTransform)
  146 +
  147 +} // namespace br
  148 +
  149 +#include "tree.moc"