tree.cpp
8.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
#include <opencv2/ml/ml.hpp>
#include "openbr_internal.h"
#include "openbr/core/opencvutils.h"
#include <QString>
#include <QTemporaryFile>
using namespace std;
using namespace cv;
namespace br
{
static void storeModel(const CvStatModel &model, QDataStream &stream)
{
// Create local file
QTemporaryFile tempFile;
tempFile.open();
tempFile.close();
// Save MLP to local file
model.save(qPrintable(tempFile.fileName()));
// Copy local file contents to stream
tempFile.open();
QByteArray data = tempFile.readAll();
tempFile.close();
stream << data;
}
/*!
 * \brief Restores an OpenCV statistical model from a Qt data stream.
 *
 * Reads one QByteArray from \p stream, writes it to a temporary file and
 * asks the model to load itself from that file via CvStatModel::load().
 *
 * \param model  Model to populate.
 * \param stream Source stream; one QByteArray is consumed from it.
 */
static void loadModel(CvStatModel &model, QDataStream &stream)
{
    // Pull the serialized model bytes off the stream
    QByteArray serialized;
    stream >> serialized;

    // Spill the bytes to a temporary file whose name is derived from the
    // "model" template inside the system temporary directory
    QTemporaryFile scratch(QDir::tempPath()+"/model");
    scratch.open();
    scratch.write(serialized);
    scratch.close(); // closed before the model reads the file back by name

    // Have the model read itself back from that file
    model.load(qPrintable(scratch.fileName()));
}
/*!
* \ingroup transforms
* \brief Wraps OpenCV's random trees framework
* \author Scott Klum \cite sklum
* \see http://docs.opencv.org/modules/ml/doc/random_trees.html
*/
class ForestTransform : public Transform
{
Q_OBJECT
Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED false)
Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED false)
Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED false)
Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED false)
Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED false)
Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED false)
Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false)
Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)
BR_PROPERTY(bool, classification, true)
BR_PROPERTY(float, splitPercentage, .01)
BR_PROPERTY(int, maxDepth, std::numeric_limits<int>::max())
BR_PROPERTY(int, maxTrees, 10)
BR_PROPERTY(float, forestAccuracy, .1)
BR_PROPERTY(bool, returnConfidence, true)
BR_PROPERTY(bool, overwriteMat, true)
BR_PROPERTY(QString, inputVariable, "Label")
BR_PROPERTY(QString, outputVariable, "")
CvRTrees forest;
void train(const TemplateList &data)
{
Mat samples = OpenCVUtils::toMat(data.data());
Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable));
Mat types = Mat(samples.cols + 1, 1, CV_8U);
types.setTo(Scalar(CV_VAR_NUMERICAL));
if (classification) {
types.at<char>(samples.cols, 0) = CV_VAR_CATEGORICAL;
} else {
types.at<char>(samples.cols, 0) = CV_VAR_NUMERICAL;
}
int minSamplesForSplit = data.size()*splitPercentage;
forest.train( samples, CV_ROW_SAMPLE, labels, Mat(), Mat(), types, Mat(),
CvRTParams(maxDepth,
minSamplesForSplit,
0,
false,
2,
0, // priors
false,
0,
maxTrees,
forestAccuracy,
CV_TERMCRIT_ITER | CV_TERMCRIT_EPS));
qDebug() << "Number of trees:" << forest.get_tree_count();
}
void project(const Template &src, Template &dst) const
{
dst = src;
float response;
if (classification && returnConfidence) {
// Fuzzy class label
response = forest.predict_prob(src.m().reshape(1,1));
} else {
response = forest.predict(src.m().reshape(1,1));
}
if (overwriteMat) {
dst.m() = Mat(1, 1, CV_32F);
dst.m().at<float>(0, 0) = response;
} else {
dst.file.set(outputVariable, response);
}
}
void load(QDataStream &stream)
{
loadModel(forest,stream);
}
void store(QDataStream &stream) const
{
storeModel(forest,stream);
}
void init()
{
if (outputVariable.isEmpty())
outputVariable = inputVariable;
}
};
BR_REGISTER(Transform, ForestTransform)
/*!
* \ingroup transforms
* \brief Wraps OpenCV's AdaBoost framework
* \author Scott Klum \cite sklum
* \see http://docs.opencv.org/modules/ml/doc/boosting.html
*/
class AdaBoostTransform : public Transform
{
    Q_OBJECT
    Q_ENUMS(Type)
    Q_ENUMS(SplitCriteria)

    Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false)
    Q_PROPERTY(SplitCriteria splitCriteria READ get_splitCriteria WRITE set_splitCriteria RESET reset_splitCriteria STORED false)
    Q_PROPERTY(int weakCount READ get_weakCount WRITE set_weakCount RESET reset_weakCount STORED false)
    Q_PROPERTY(float trimRate READ get_trimRate WRITE set_trimRate RESET reset_trimRate STORED false)
    Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false)
    Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED false)
    Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED false)
    Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false)
    Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
    Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)

public:
    // Boosting variants; the values mirror the CvBoost constants so they can
    // be assigned to CvBoostParams::boost_type directly.
    enum Type { Discrete = CvBoost::DISCRETE,
                Real = CvBoost::REAL,
                Logit = CvBoost::LOGIT,
                Gentle = CvBoost::GENTLE};

    // Split criteria; the values mirror the CvBoost constants so they can be
    // assigned to CvBoostParams::split_criteria directly.
    enum SplitCriteria { Default = CvBoost::DEFAULT,
                         Gini = CvBoost::GINI,
                         Misclass = CvBoost::MISCLASS,
                         Sqerr = CvBoost::SQERR};

private:
    // type             - copied to CvBoostParams::boost_type.
    // splitCriteria    - copied to CvBoostParams::split_criteria.
    // weakCount        - copied to CvBoostParams::weak_count; also the divisor
    //                    applied to the confidence in project().
    // trimRate         - copied to CvBoostParams::weight_trim_rate.
    // folds            - copied to CvBoostParams::cv_folds.
    // maxDepth         - copied to CvBoostParams::max_depth.
    // returnConfidence - project() returns the scaled summed response instead
    //                    of the plain prediction.
    // overwriteMat     - true: the response replaces the template matrix;
    //                    false: it is stored in the metadata under outputVariable.
    // inputVariable    - metadata key the training responses are read from.
    // outputVariable   - metadata key for the response; init() falls back to
    //                    inputVariable when this is empty.
    BR_PROPERTY(Type, type, Real)
    BR_PROPERTY(SplitCriteria, splitCriteria, Default)
    BR_PROPERTY(int, weakCount, 100)
    BR_PROPERTY(float, trimRate, .95)
    BR_PROPERTY(int, folds, 0)
    BR_PROPERTY(int, maxDepth, 1)
    BR_PROPERTY(bool, returnConfidence, true)
    BR_PROPERTY(bool, overwriteMat, true)
    BR_PROPERTY(QString, inputVariable, "Label")
    BR_PROPERTY(QString, outputVariable, "")

    CvBoost boost;

    // Trains the boosted classifier on the template matrices, using the float
    // metadata value stored under inputVariable as the response of each template.
    void train(const TemplateList &data)
    {
        Mat features = OpenCVUtils::toMat(data.data());
        Mat responses = OpenCVUtils::toMat(File::get<float>(data, inputVariable));

        // One type flag per feature column plus a trailing flag for the
        // response: features are numerical, the response is categorical.
        Mat varTypes = Mat(features.cols + 1, 1, CV_8U);
        varTypes.setTo(Scalar(CV_VAR_NUMERICAL));
        varTypes.at<char>(features.cols, 0) = CV_VAR_CATEGORICAL;

        // Start from OpenCV's defaults and override with this transform's properties
        CvBoostParams settings;
        settings.boost_type = type;
        settings.split_criteria = splitCriteria;
        settings.weak_count = weakCount;
        settings.weight_trim_rate = trimRate;
        settings.cv_folds = folds;
        settings.max_depth = maxDepth;

        boost.train(features, CV_ROW_SAMPLE, responses, Mat(), Mat(), varTypes, Mat(), settings);
    }

    // Copies src to dst, runs the classifier on the source matrix flattened
    // to a single row, and writes the response either into dst's matrix or
    // into dst's metadata, depending on overwriteMat.
    void project(const Template &src, Template &dst) const
    {
        dst = src;

        // The classifier is queried with a single-channel, single-row sample
        const Mat sample = src.m().reshape(1,1);

        // For a confidence, predict() is called with its last two flags set to
        // (false, true) and the result is divided by the configured weakCount.
        // NOTE(review): per the OpenCV 2.x docs the final flag is returnSum - confirm.
        const float response = returnConfidence
                             ? boost.predict(sample,Mat(),Range::all(),false,true)/weakCount
                             : boost.predict(sample);

        if (!overwriteMat) {
            dst.file.set(outputVariable, response);
            return;
        }

        dst.m() = Mat(1, 1, CV_32F);
        dst.m().at<float>(0, 0) = response;
    }

    // Restores the trained classifier from the stream
    void load(QDataStream &stream)
    {
        loadModel(boost,stream);
    }

    // Writes the trained classifier to the stream
    void store(QDataStream &stream) const
    {
        storeModel(boost,stream);
    }

    // Defaults the output metadata key to the input key when none was given
    void init()
    {
        if (outputVariable.isEmpty())
            outputVariable = inputVariable;
    }
};

BR_REGISTER(Transform, AdaBoostTransform)
} // namespace br
#include "tree.moc"