Commit cb60809425f4390709e750f0d715c479d51a266d
Merge pull request #296 from biometrics/forest
Forest
Showing
1 changed file
with
149 additions
and
0 deletions
openbr/plugins/tree.cpp
0 → 100644
| 1 | +#include <opencv2/ml/ml.hpp> | |
| 2 | + | |
| 3 | +#include "openbr_internal.h" | |
| 4 | +#include "openbr/core/opencvutils.h" | |
| 5 | +#include <QString> | |
| 6 | +#include <QTemporaryFile> | |
| 7 | + | |
| 8 | +using namespace std; | |
| 9 | +using namespace cv; | |
| 10 | + | |
| 11 | +namespace br | |
| 12 | +{ | |
| 13 | + | |
| 14 | +static void storeForest(const CvRTrees &forest, QDataStream &stream) | |
| 15 | +{ | |
| 16 | + // Create local file | |
| 17 | + QTemporaryFile tempFile; | |
| 18 | + tempFile.open(); | |
| 19 | + tempFile.close(); | |
| 20 | + | |
| 21 | + // Save MLP to local file | |
| 22 | + forest.save(qPrintable(tempFile.fileName())); | |
| 23 | + | |
| 24 | + // Copy local file contents to stream | |
| 25 | + tempFile.open(); | |
| 26 | + QByteArray data = tempFile.readAll(); | |
| 27 | + tempFile.close(); | |
| 28 | + stream << data; | |
| 29 | +} | |
| 30 | + | |
| 31 | +static void loadForest(CvRTrees &forest, QDataStream &stream) | |
| 32 | +{ | |
| 33 | + // Copy local file contents from stream | |
| 34 | + QByteArray data; | |
| 35 | + stream >> data; | |
| 36 | + | |
| 37 | + // Create local file | |
| 38 | + QTemporaryFile tempFile(QDir::tempPath()+"/forest"); | |
| 39 | + tempFile.open(); | |
| 40 | + tempFile.write(data); | |
| 41 | + tempFile.close(); | |
| 42 | + | |
| 43 | + // Load MLP from local file | |
| 44 | + forest.load(qPrintable(tempFile.fileName())); | |
| 45 | +} | |
| 46 | + | |
| 47 | +/*! | |
| 48 | + * \ingroup transforms | |
| 49 | + * \brief Wraps OpenCV's random trees framework | |
| 50 | + * \author Scott Klum \cite sklum | |
| 51 | + * \brief http://docs.opencv.org/modules/ml/doc/random_trees.html | |
| 52 | + */ | |
| 53 | +class ForestTransform : public MetaTransform | |
| 54 | +{ | |
| 55 | + Q_OBJECT | |
| 56 | + Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED true) | |
| 57 | + Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED true) | |
| 58 | + Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED true) | |
| 59 | + Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED true) | |
| 60 | + Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED true) | |
| 61 | + Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED true) | |
| 62 | + Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED true) | |
| 63 | + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | |
| 64 | + Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) | |
| 65 | + BR_PROPERTY(bool, classification, true) | |
| 66 | + BR_PROPERTY(float, splitPercentage, .01) | |
| 67 | + BR_PROPERTY(int, maxDepth, std::numeric_limits<int>::max()) | |
| 68 | + BR_PROPERTY(int, maxTrees, 10) | |
| 69 | + BR_PROPERTY(float, forestAccuracy, .1) | |
| 70 | + BR_PROPERTY(bool, returnConfidence, true) | |
| 71 | + BR_PROPERTY(bool, overwriteMat, true) | |
| 72 | + BR_PROPERTY(QString, inputVariable, "Label") | |
| 73 | + BR_PROPERTY(QString, outputVariable, "") | |
| 74 | + | |
| 75 | + CvRTrees forest; | |
| 76 | + | |
| 77 | + void train(const TemplateList &data) | |
| 78 | + { | |
| 79 | + Mat samples = OpenCVUtils::toMat(data.data()); | |
| 80 | + Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable)); | |
| 81 | + | |
| 82 | + Mat types = Mat(samples.cols + 1, 1, CV_8U); | |
| 83 | + types.setTo(Scalar(CV_VAR_NUMERICAL)); | |
| 84 | + | |
| 85 | + if (classification) { | |
| 86 | + types.at<char>(samples.cols, 0) = CV_VAR_CATEGORICAL; | |
| 87 | + } else { | |
| 88 | + types.at<char>(samples.cols, 0) = CV_VAR_NUMERICAL; | |
| 89 | + } | |
| 90 | + | |
| 91 | + int minSamplesForSplit = data.size()*splitPercentage; | |
| 92 | + forest.train( samples, CV_ROW_SAMPLE, labels, Mat(), Mat(), types, Mat(), | |
| 93 | + CvRTParams(maxDepth, | |
| 94 | + minSamplesForSplit, | |
| 95 | + 0, | |
| 96 | + false, | |
| 97 | + 2, | |
| 98 | + 0, // priors | |
| 99 | + false, | |
| 100 | + 0, | |
| 101 | + maxTrees, | |
| 102 | + forestAccuracy, | |
| 103 | + CV_TERMCRIT_EPS)); | |
| 104 | + | |
| 105 | + qDebug() << "Number of trees:" << forest.get_tree_count(); | |
| 106 | + } | |
| 107 | + | |
| 108 | + void project(const Template &src, Template &dst) const | |
| 109 | + { | |
| 110 | + dst = src; | |
| 111 | + | |
| 112 | + float response; | |
| 113 | + if (classification && returnConfidence) { | |
| 114 | + // Fuzzy class label | |
| 115 | + response = forest.predict_prob(src.m().reshape(1,1)); | |
| 116 | + } else { | |
| 117 | + response = forest.predict(src.m().reshape(1,1)); | |
| 118 | + } | |
| 119 | + | |
| 120 | + if (overwriteMat) { | |
| 121 | + dst.m() = Mat(1, 1, CV_32F); | |
| 122 | + dst.m().at<float>(0, 0) = response; | |
| 123 | + } else { | |
| 124 | + dst.file.set(outputVariable, response); | |
| 125 | + } | |
| 126 | + } | |
| 127 | + | |
| 128 | + void load(QDataStream &stream) | |
| 129 | + { | |
| 130 | + loadForest(forest,stream); | |
| 131 | + } | |
| 132 | + | |
| 133 | + void store(QDataStream &stream) const | |
| 134 | + { | |
| 135 | + storeForest(forest,stream); | |
| 136 | + } | |
| 137 | + | |
| 138 | + void init() | |
| 139 | + { | |
| 140 | + if (outputVariable.isEmpty()) | |
| 141 | + outputVariable = inputVariable; | |
| 142 | + } | |
| 143 | +}; | |
| 144 | + | |
| 145 | +BR_REGISTER(Transform, ForestTransform) | |
| 146 | + | |
| 147 | +} // namespace br | |
| 148 | + | |
| 149 | +#include "tree.moc" | ... | ... |