Commit cb60809425f4390709e750f0d715c479d51a266d
Merge pull request #296 from biometrics/forest
Forest
Showing
1 changed file
with
149 additions
and
0 deletions
openbr/plugins/tree.cpp
0 → 100644
| 1 | +#include <opencv2/ml/ml.hpp> | ||
| 2 | + | ||
| 3 | +#include "openbr_internal.h" | ||
| 4 | +#include "openbr/core/opencvutils.h" | ||
| 5 | +#include <QString> | ||
| 6 | +#include <QTemporaryFile> | ||
| 7 | + | ||
| 8 | +using namespace std; | ||
| 9 | +using namespace cv; | ||
| 10 | + | ||
| 11 | +namespace br | ||
| 12 | +{ | ||
| 13 | + | ||
| 14 | +static void storeForest(const CvRTrees &forest, QDataStream &stream) | ||
| 15 | +{ | ||
| 16 | + // Create local file | ||
| 17 | + QTemporaryFile tempFile; | ||
| 18 | + tempFile.open(); | ||
| 19 | + tempFile.close(); | ||
| 20 | + | ||
| 21 | + // Save MLP to local file | ||
| 22 | + forest.save(qPrintable(tempFile.fileName())); | ||
| 23 | + | ||
| 24 | + // Copy local file contents to stream | ||
| 25 | + tempFile.open(); | ||
| 26 | + QByteArray data = tempFile.readAll(); | ||
| 27 | + tempFile.close(); | ||
| 28 | + stream << data; | ||
| 29 | +} | ||
| 30 | + | ||
| 31 | +static void loadForest(CvRTrees &forest, QDataStream &stream) | ||
| 32 | +{ | ||
| 33 | + // Copy local file contents from stream | ||
| 34 | + QByteArray data; | ||
| 35 | + stream >> data; | ||
| 36 | + | ||
| 37 | + // Create local file | ||
| 38 | + QTemporaryFile tempFile(QDir::tempPath()+"/forest"); | ||
| 39 | + tempFile.open(); | ||
| 40 | + tempFile.write(data); | ||
| 41 | + tempFile.close(); | ||
| 42 | + | ||
| 43 | + // Load MLP from local file | ||
| 44 | + forest.load(qPrintable(tempFile.fileName())); | ||
| 45 | +} | ||
| 46 | + | ||
| 47 | +/*! | ||
| 48 | + * \ingroup transforms | ||
| 49 | + * \brief Wraps OpenCV's random trees framework | ||
| 50 | + * \author Scott Klum \cite sklum | ||
| 51 | + * \brief http://docs.opencv.org/modules/ml/doc/random_trees.html | ||
| 52 | + */ | ||
| 53 | +class ForestTransform : public MetaTransform | ||
| 54 | +{ | ||
| 55 | + Q_OBJECT | ||
| 56 | + Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED true) | ||
| 57 | + Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED true) | ||
| 58 | + Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED true) | ||
| 59 | + Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED true) | ||
| 60 | + Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED true) | ||
| 61 | + Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED true) | ||
| 62 | + Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED true) | ||
| 63 | + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | ||
| 64 | + Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) | ||
| 65 | + BR_PROPERTY(bool, classification, true) | ||
| 66 | + BR_PROPERTY(float, splitPercentage, .01) | ||
| 67 | + BR_PROPERTY(int, maxDepth, std::numeric_limits<int>::max()) | ||
| 68 | + BR_PROPERTY(int, maxTrees, 10) | ||
| 69 | + BR_PROPERTY(float, forestAccuracy, .1) | ||
| 70 | + BR_PROPERTY(bool, returnConfidence, true) | ||
| 71 | + BR_PROPERTY(bool, overwriteMat, true) | ||
| 72 | + BR_PROPERTY(QString, inputVariable, "Label") | ||
| 73 | + BR_PROPERTY(QString, outputVariable, "") | ||
| 74 | + | ||
| 75 | + CvRTrees forest; | ||
| 76 | + | ||
| 77 | + void train(const TemplateList &data) | ||
| 78 | + { | ||
| 79 | + Mat samples = OpenCVUtils::toMat(data.data()); | ||
| 80 | + Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable)); | ||
| 81 | + | ||
| 82 | + Mat types = Mat(samples.cols + 1, 1, CV_8U); | ||
| 83 | + types.setTo(Scalar(CV_VAR_NUMERICAL)); | ||
| 84 | + | ||
| 85 | + if (classification) { | ||
| 86 | + types.at<char>(samples.cols, 0) = CV_VAR_CATEGORICAL; | ||
| 87 | + } else { | ||
| 88 | + types.at<char>(samples.cols, 0) = CV_VAR_NUMERICAL; | ||
| 89 | + } | ||
| 90 | + | ||
| 91 | + int minSamplesForSplit = data.size()*splitPercentage; | ||
| 92 | + forest.train( samples, CV_ROW_SAMPLE, labels, Mat(), Mat(), types, Mat(), | ||
| 93 | + CvRTParams(maxDepth, | ||
| 94 | + minSamplesForSplit, | ||
| 95 | + 0, | ||
| 96 | + false, | ||
| 97 | + 2, | ||
| 98 | + 0, // priors | ||
| 99 | + false, | ||
| 100 | + 0, | ||
| 101 | + maxTrees, | ||
| 102 | + forestAccuracy, | ||
| 103 | + CV_TERMCRIT_EPS)); | ||
| 104 | + | ||
| 105 | + qDebug() << "Number of trees:" << forest.get_tree_count(); | ||
| 106 | + } | ||
| 107 | + | ||
| 108 | + void project(const Template &src, Template &dst) const | ||
| 109 | + { | ||
| 110 | + dst = src; | ||
| 111 | + | ||
| 112 | + float response; | ||
| 113 | + if (classification && returnConfidence) { | ||
| 114 | + // Fuzzy class label | ||
| 115 | + response = forest.predict_prob(src.m().reshape(1,1)); | ||
| 116 | + } else { | ||
| 117 | + response = forest.predict(src.m().reshape(1,1)); | ||
| 118 | + } | ||
| 119 | + | ||
| 120 | + if (overwriteMat) { | ||
| 121 | + dst.m() = Mat(1, 1, CV_32F); | ||
| 122 | + dst.m().at<float>(0, 0) = response; | ||
| 123 | + } else { | ||
| 124 | + dst.file.set(outputVariable, response); | ||
| 125 | + } | ||
| 126 | + } | ||
| 127 | + | ||
| 128 | + void load(QDataStream &stream) | ||
| 129 | + { | ||
| 130 | + loadForest(forest,stream); | ||
| 131 | + } | ||
| 132 | + | ||
| 133 | + void store(QDataStream &stream) const | ||
| 134 | + { | ||
| 135 | + storeForest(forest,stream); | ||
| 136 | + } | ||
| 137 | + | ||
| 138 | + void init() | ||
| 139 | + { | ||
| 140 | + if (outputVariable.isEmpty()) | ||
| 141 | + outputVariable = inputVariable; | ||
| 142 | + } | ||
| 143 | +}; | ||
| 144 | + | ||
| 145 | +BR_REGISTER(Transform, ForestTransform) | ||
| 146 | + | ||
| 147 | +} // namespace br | ||
| 148 | + | ||
| 149 | +#include "tree.moc" |