tree.cpp
3.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#include <opencv2/ml/ml.hpp>
#include "openbr_internal.h"
#include "openbr/core/opencvutils.h"
#include <QString>
#include <QTemporaryFile>
using namespace std;
using namespace cv;
namespace br
{
static void storeForest(const CvRTrees &forest, QDataStream &stream)
{
// Create local file
QTemporaryFile tempFile;
tempFile.open();
tempFile.close();
// Save MLP to local file
forest.save(qPrintable(tempFile.fileName()));
// Copy local file contents to stream
tempFile.open();
QByteArray data = tempFile.readAll();
tempFile.close();
stream << data;
}
/*!
 * \brief Restores a random forest from a data stream.
 *
 * Reads one QByteArray from \a stream, writes it to a temporary file and
 * loads the forest from that file by name.
 *
 * If the byte array is empty, or the temporary file cannot be created or
 * fully written, a warning is emitted and forest.load() is not called.
 *
 * \param forest Forest to populate.
 * \param stream Source stream; one QByteArray is always consumed from it.
 */
static void loadForest(CvRTrees &forest, QDataStream &stream)
{
    // Copy serialized forest from stream
    QByteArray data;
    stream >> data;
    if (data.isEmpty()) {
        qWarning("loadForest: stream contained no forest data, skipping load.");
        return;
    }

    // Create local file
    QTemporaryFile tempFile(QDir::tempPath()+"/forest");
    if (!tempFile.open()) {
        qWarning("loadForest: unable to create a temporary file, skipping load.");
        return;
    }

    // A short write would leave a truncated model on disk, so verify the byte count.
    const qint64 written = tempFile.write(data);
    tempFile.close();
    if (written != data.size()) {
        qWarning("loadForest: failed to write forest data to %s, skipping load.",
                 qPrintable(tempFile.fileName()));
        return;
    }

    // Load forest from local file
    forest.load(qPrintable(tempFile.fileName()));
}
/*!
* \ingroup transforms
* \brief Wraps OpenCV's random trees framework
* \author Scott Klum \cite sklum
* \brief http://docs.opencv.org/modules/ml/doc/random_trees.html
*/
class ForestTransform : public MetaTransform
{
Q_OBJECT
Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED true)
Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED true)
Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED true)
Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED true)
Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED true)
BR_PROPERTY(bool, classification, true)
BR_PROPERTY(float, splitPercentage, .01)
BR_PROPERTY(int, maxDepth, std::numeric_limits<int>::max())
BR_PROPERTY(int, maxTrees, 10)
BR_PROPERTY(float, forestAccuracy, .1)
CvRTrees forest;
void train(const TemplateList &data)
{
Mat samples = OpenCVUtils::toMat(data.data());
Mat labels = OpenCVUtils::toMat(File::get<float>(data, "Label"));
Mat types = Mat(samples.cols + 1, 1, CV_8U );
types.setTo(Scalar(CV_VAR_NUMERICAL));
if (classification) {
types.at<char>(samples.cols, 0) = CV_VAR_CATEGORICAL;
} else {
types.at<char>(samples.cols, 0) = CV_VAR_NUMERICAL;
}
int minSamplesForSplit = data.size()*splitPercentage;
forest.train( samples, CV_ROW_SAMPLE, labels, Mat(), Mat(), types, Mat(),
CvRTParams(maxDepth,
minSamplesForSplit,
0,
false,
2,
0, // priors
false,
0,
maxTrees,
forestAccuracy,
CV_TERMCRIT_EPS));
qDebug() << "Number of trees:" << forest.get_tree_count();
}
void project(const Template &src, Template &dst) const
{
dst = src;
float response = forest.predict_prob(src.m().reshape(1,1));
dst.m() = Mat(1, 1, CV_32F);
dst.m().at<float>(0, 0) = response;
}
void load(QDataStream &stream)
{
loadForest(forest,stream);
}
void store(QDataStream &stream) const
{
storeForest(forest,stream);
}
};
// Registers ForestTransform via BR_REGISTER, which is defined outside this file.
// NOTE(review): presumably this exposes the class to the Transform plugin factory - confirm in openbr_internal.h.
BR_REGISTER(Transform, ForestTransform)
} // namespace br
#include "tree.moc"