Commit ac67461f6850464d7f1dfa3bcbc7226a385b1df6

Authored by Josh Klontz
1 parent fcb1bcac

removed boostedforest

openbr/plugins/classification/boostedforest.cpp deleted
1 -#include <openbr/plugins/openbr_internal.h>  
2 -#include <openbr/core/boost.h>  
3 -  
4 -#define THRESHOLD_EPS 1e-5  
5 -  
6 -using namespace cv;  
7 -  
8 -namespace br  
9 -{  
10 -  
11 -struct Node  
12 -{  
13 - float value; // for leaf nodes  
14 -  
15 - float threshold; // for ordered features  
16 - QList<int> subset; // for categorical features  
17 - int featureIdx;  
18 -  
19 - Node *left, *right;  
20 -};  
21 -  
22 -static void buildTreeRecursive(Node *node, const CvDTreeNode *cv_node, int maxCatCount)  
23 -{  
24 - if (!cv_node->left) {  
25 - node->value = cv_node->value;  
26 - node->left = node->right = NULL;  
27 - } else {  
28 - if (maxCatCount > 0)  
29 - for (int i = 0; i < (maxCatCount + 31)/32; i++)  
30 - node->subset.append(cv_node->split->subset[i]);  
31 - else  
32 - node->threshold = cv_node->split->ord.c;  
33 -  
34 - node->featureIdx = cv_node->split->var_idx;  
35 -  
36 - node->left = new Node; node->right = new Node;  
37 - buildTreeRecursive(node->left, cv_node->left, maxCatCount);  
38 - buildTreeRecursive(node->right, cv_node->right, maxCatCount);  
39 - }  
40 -}  
41 -  
42 -static void loadRecursive(QDataStream &stream, Node *node, int maxCatCount)  
43 -{  
44 - bool hasChildren; stream >> hasChildren;  
45 -  
46 - if (!hasChildren) {  
47 - stream >> node->value;  
48 - node->left = node->right = NULL;  
49 - } else {  
50 - if (maxCatCount > 0)  
51 - for (int i = 0; i < (maxCatCount + 31)/32; i++) {  
52 - int s; stream >> s; node->subset.append(s);  
53 - }  
54 - else  
55 - stream >> node->threshold;  
56 -  
57 - stream >> node->featureIdx;  
58 -  
59 - node->left = new Node; node->right = new Node;  
60 - loadRecursive(stream, node->left, maxCatCount);  
61 - loadRecursive(stream, node->right, maxCatCount);  
62 - }  
63 -}  
64 -  
65 -static void storeRecursive(QDataStream &stream, const Node *node, int maxCatCount)  
66 -{  
67 - bool hasChildren = node->left ? true : false;  
68 - stream << hasChildren;  
69 -  
70 - if (!hasChildren)  
71 - stream << node->value;  
72 - else {  
73 - if (maxCatCount > 0)  
74 - for (int i = 0; i < (maxCatCount + 31)/32; i++)  
75 - stream << node->subset[i];  
76 - else  
77 - stream << node->threshold;  
78 -  
79 - stream << node->featureIdx;  
80 -  
81 - storeRecursive(stream, node->left, maxCatCount);  
82 - storeRecursive(stream, node->right, maxCatCount);  
83 - }  
84 -}  
85 -  
86 -/*!  
87 - * \brief A classification wrapper on OpenCV's CvBoost class. It uses CvBoost for training a boosted forest and then performs classification using the trained nodes.  
88 - * \author Jordan Cheney \cite jcheney  
89 - * \author Scott Klum \cite sklum  
90 - * \br_property Representation* representation The Representation describing the features used by the boosted forest  
91 - * \br_property float minTAR The minimum true accept rate during training  
92 - * \br_property float maxFAR The maximum false accept rate during training  
93 - * \br_property float trimRate The trim rate during training  
94 - * \br_property int maxDepth The maximum depth for each trained tree  
95 - * \br_property int maxWeakCount The maximum number of trees in the forest  
96 - * \br_property Type type. The type of boosting to perform. Options are [Discrete, Real, Logit, Gentle]. Gentle is the default.  
97 - */  
98 -class BoostedForestClassifier : public Classifier  
99 -{  
100 - Q_OBJECT  
101 - Q_ENUMS(Type)  
102 -  
103 - Q_PROPERTY(br::Representation* representation READ get_representation WRITE set_representation RESET reset_representation STORED false)  
104 - Q_PROPERTY(float minTAR READ get_minTAR WRITE set_minTAR RESET reset_minTAR STORED false)  
105 - Q_PROPERTY(float maxFAR READ get_maxFAR WRITE set_maxFAR RESET reset_maxFAR STORED false)  
106 - Q_PROPERTY(float trimRate READ get_trimRate WRITE set_trimRate RESET reset_trimRate STORED false)  
107 - Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED false)  
108 - Q_PROPERTY(int maxWeakCount READ get_maxWeakCount WRITE set_maxWeakCount RESET reset_maxWeakCount STORED false)  
109 - Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false)  
110 - Q_PROPERTY(float threshold READ get_threshold WRITE set_threshold RESET reset_threshold STORED false)  
111 -  
112 -public:  
113 - QList<Node*> classifiers;  
114 -  
115 - enum Type { Discrete = CvBoost::DISCRETE,  
116 - Real = CvBoost::REAL,  
117 - Logit = CvBoost::LOGIT,  
118 - Gentle = CvBoost::GENTLE};  
119 -private:  
120 - BR_PROPERTY(br::Representation*, representation, NULL)  
121 - BR_PROPERTY(float, minTAR, 0.995)  
122 - BR_PROPERTY(float, maxFAR, 0.5)  
123 - BR_PROPERTY(float, trimRate, 0.95)  
124 - BR_PROPERTY(int, maxDepth, 1)  
125 - BR_PROPERTY(int, maxWeakCount, 100)  
126 - BR_PROPERTY(Type, type, Gentle)  
127 - BR_PROPERTY(float, threshold, 0)  
128 -  
129 - void train(const TemplateList &data)  
130 - {  
131 - representation->train(data);  
132 -  
133 - CascadeBoostParams params(type, minTAR, maxFAR, trimRate, maxDepth, maxWeakCount);  
134 -  
135 - FeatureEvaluator featureEvaluator;  
136 - featureEvaluator.init(representation, data.size());  
137 -  
138 - for (int i = 0; i < data.size(); i++)  
139 - featureEvaluator.setImage(data[i], data[i].file.get<float>("Label"), i);  
140 -  
141 - CascadeBoost boost;  
142 - boost.train(&featureEvaluator, data.size(), 1024, 1024, representation->numChannels(), params);  
143 -  
144 - threshold = boost.getThreshold();  
145 -  
146 - foreach (const CvBoostTree *classifier, boost.getClassifers()) {  
147 - Node *root = new Node;  
148 - buildTreeRecursive(root, classifier->get_root(), representation->maxCatCount());  
149 - classifiers.append(root);  
150 - }  
151 - }  
152 -  
153 - float classifyPreprocessed(const Template &t, float *confidence) const  
154 - {  
155 - const bool categorical = representation->maxCatCount() > 0;  
156 -  
157 - float sum = 0;  
158 - for (int i = 0; i < classifiers.size(); i++) {  
159 - const Node *node = classifiers[i];  
160 -  
161 - while (node->left) {  
162 - const float val = representation->evaluate(t, node->featureIdx);  
163 - if (categorical) {  
164 - const int c = (int)val;  
165 - node = (node->subset[c >> 5] & (1 << (c & 31))) ? node->left : node->right;  
166 - } else {  
167 - node = val <= node->threshold ? node->left : node->right;  
168 - }  
169 - }  
170 -  
171 - sum += node->value;  
172 - }  
173 -  
174 - if (confidence)  
175 - *confidence = sum;  
176 - return sum < threshold - THRESHOLD_EPS ? 0.0f : 1.0f;  
177 - }  
178 -  
179 - float classify(const Template &src, bool process, float *confidence) const  
180 - {  
181 - // This code is written in a way to avoid an unnecessary copy construction and destruction of `src` when `process` is false.  
182 - return process ? classifyPreprocessed(preprocess(src), confidence) : classifyPreprocessed(src, confidence);  
183 - }  
184 -  
185 - int numFeatures() const  
186 - {  
187 - return representation->numFeatures();  
188 - }  
189 -  
190 - Template preprocess(const Template &src) const  
191 - {  
192 - return representation->preprocess(src);  
193 - }  
194 -  
195 - Size windowSize(int *dx, int *dy) const  
196 - {  
197 - return representation->windowSize(dx, dy);  
198 - }  
199 -  
200 - void load(QDataStream &stream)  
201 - {  
202 - representation->load(stream);  
203 -  
204 - stream >> threshold;  
205 - int numClassifiers; stream >> numClassifiers;  
206 - for (int i = 0; i < numClassifiers; i++) {  
207 - Node *classifier = new Node;  
208 - loadRecursive(stream, classifier, representation->maxCatCount());  
209 - classifiers.append(classifier);  
210 - }  
211 - }  
212 -  
213 - void store(QDataStream &stream) const  
214 - {  
215 - representation->store(stream);  
216 -  
217 - stream << threshold;  
218 - stream << classifiers.size();  
219 - foreach (const Node *classifier, classifiers)  
220 - storeRecursive(stream, classifier, representation->maxCatCount());  
221 - }  
222 -};  
223 -  
224 -QList<Node*> getClassifers(Classifier *classifier)  
225 -{  
226 - BoostedForestClassifier *boostedForest = static_cast<BoostedForestClassifier*>(classifier);  
227 - return boostedForest->classifiers;  
228 -}  
229 -  
230 -BR_REGISTER(Classifier, BoostedForestClassifier)  
231 -  
232 -} // namespace br  
233 -  
234 -#include "classification/boostedforest.moc"