Commit ac67461f6850464d7f1dfa3bcbc7226a385b1df6

Authored by Josh Klontz
1 parent fcb1bcac

removed boostedforest

openbr/plugins/classification/boostedforest.cpp deleted
1   -#include <openbr/plugins/openbr_internal.h>
2   -#include <openbr/core/boost.h>
3   -
4   -#define THRESHOLD_EPS 1e-5
5   -
6   -using namespace cv;
7   -
8   -namespace br
9   -{
10   -
11   -struct Node
12   -{
13   - float value; // for leaf nodes
14   -
15   - float threshold; // for ordered features
16   - QList<int> subset; // for categorical features
17   - int featureIdx;
18   -
19   - Node *left, *right;
20   -};
21   -
22   -static void buildTreeRecursive(Node *node, const CvDTreeNode *cv_node, int maxCatCount)
23   -{
24   - if (!cv_node->left) {
25   - node->value = cv_node->value;
26   - node->left = node->right = NULL;
27   - } else {
28   - if (maxCatCount > 0)
29   - for (int i = 0; i < (maxCatCount + 31)/32; i++)
30   - node->subset.append(cv_node->split->subset[i]);
31   - else
32   - node->threshold = cv_node->split->ord.c;
33   -
34   - node->featureIdx = cv_node->split->var_idx;
35   -
36   - node->left = new Node; node->right = new Node;
37   - buildTreeRecursive(node->left, cv_node->left, maxCatCount);
38   - buildTreeRecursive(node->right, cv_node->right, maxCatCount);
39   - }
40   -}
41   -
42   -static void loadRecursive(QDataStream &stream, Node *node, int maxCatCount)
43   -{
44   - bool hasChildren; stream >> hasChildren;
45   -
46   - if (!hasChildren) {
47   - stream >> node->value;
48   - node->left = node->right = NULL;
49   - } else {
50   - if (maxCatCount > 0)
51   - for (int i = 0; i < (maxCatCount + 31)/32; i++) {
52   - int s; stream >> s; node->subset.append(s);
53   - }
54   - else
55   - stream >> node->threshold;
56   -
57   - stream >> node->featureIdx;
58   -
59   - node->left = new Node; node->right = new Node;
60   - loadRecursive(stream, node->left, maxCatCount);
61   - loadRecursive(stream, node->right, maxCatCount);
62   - }
63   -}
64   -
65   -static void storeRecursive(QDataStream &stream, const Node *node, int maxCatCount)
66   -{
67   - bool hasChildren = node->left ? true : false;
68   - stream << hasChildren;
69   -
70   - if (!hasChildren)
71   - stream << node->value;
72   - else {
73   - if (maxCatCount > 0)
74   - for (int i = 0; i < (maxCatCount + 31)/32; i++)
75   - stream << node->subset[i];
76   - else
77   - stream << node->threshold;
78   -
79   - stream << node->featureIdx;
80   -
81   - storeRecursive(stream, node->left, maxCatCount);
82   - storeRecursive(stream, node->right, maxCatCount);
83   - }
84   -}
85   -
86   -/*!
87   - * \brief A classification wrapper on OpenCV's CvBoost class. It uses CvBoost for training a boosted forest and then performs classification using the trained nodes.
88   - * \author Jordan Cheney \cite jcheney
89   - * \author Scott Klum \cite sklum
90   - * \br_property Representation* representation The Representation describing the features used by the boosted forest
91   - * \br_property float minTAR The minimum true accept rate during training
92   - * \br_property float maxFAR The maximum false accept rate during training
93   - * \br_property float trimRate The trim rate during training
94   - * \br_property int maxDepth The maximum depth for each trained tree
95   - * \br_property int maxWeakCount The maximum number of trees in the forest
96   - * \br_property Type type. The type of boosting to perform. Options are [Discrete, Real, Logit, Gentle]. Gentle is the default.
97   - */
98   -class BoostedForestClassifier : public Classifier
99   -{
100   - Q_OBJECT
101   - Q_ENUMS(Type)
102   -
103   - Q_PROPERTY(br::Representation* representation READ get_representation WRITE set_representation RESET reset_representation STORED false)
104   - Q_PROPERTY(float minTAR READ get_minTAR WRITE set_minTAR RESET reset_minTAR STORED false)
105   - Q_PROPERTY(float maxFAR READ get_maxFAR WRITE set_maxFAR RESET reset_maxFAR STORED false)
106   - Q_PROPERTY(float trimRate READ get_trimRate WRITE set_trimRate RESET reset_trimRate STORED false)
107   - Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED false)
108   - Q_PROPERTY(int maxWeakCount READ get_maxWeakCount WRITE set_maxWeakCount RESET reset_maxWeakCount STORED false)
109   - Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false)
110   - Q_PROPERTY(float threshold READ get_threshold WRITE set_threshold RESET reset_threshold STORED false)
111   -
112   -public:
113   - QList<Node*> classifiers;
114   -
115   - enum Type { Discrete = CvBoost::DISCRETE,
116   - Real = CvBoost::REAL,
117   - Logit = CvBoost::LOGIT,
118   - Gentle = CvBoost::GENTLE};
119   -private:
120   - BR_PROPERTY(br::Representation*, representation, NULL)
121   - BR_PROPERTY(float, minTAR, 0.995)
122   - BR_PROPERTY(float, maxFAR, 0.5)
123   - BR_PROPERTY(float, trimRate, 0.95)
124   - BR_PROPERTY(int, maxDepth, 1)
125   - BR_PROPERTY(int, maxWeakCount, 100)
126   - BR_PROPERTY(Type, type, Gentle)
127   - BR_PROPERTY(float, threshold, 0)
128   -
129   - void train(const TemplateList &data)
130   - {
131   - representation->train(data);
132   -
133   - CascadeBoostParams params(type, minTAR, maxFAR, trimRate, maxDepth, maxWeakCount);
134   -
135   - FeatureEvaluator featureEvaluator;
136   - featureEvaluator.init(representation, data.size());
137   -
138   - for (int i = 0; i < data.size(); i++)
139   - featureEvaluator.setImage(data[i], data[i].file.get<float>("Label"), i);
140   -
141   - CascadeBoost boost;
142   - boost.train(&featureEvaluator, data.size(), 1024, 1024, representation->numChannels(), params);
143   -
144   - threshold = boost.getThreshold();
145   -
146   - foreach (const CvBoostTree *classifier, boost.getClassifers()) {
147   - Node *root = new Node;
148   - buildTreeRecursive(root, classifier->get_root(), representation->maxCatCount());
149   - classifiers.append(root);
150   - }
151   - }
152   -
153   - float classifyPreprocessed(const Template &t, float *confidence) const
154   - {
155   - const bool categorical = representation->maxCatCount() > 0;
156   -
157   - float sum = 0;
158   - for (int i = 0; i < classifiers.size(); i++) {
159   - const Node *node = classifiers[i];
160   -
161   - while (node->left) {
162   - const float val = representation->evaluate(t, node->featureIdx);
163   - if (categorical) {
164   - const int c = (int)val;
165   - node = (node->subset[c >> 5] & (1 << (c & 31))) ? node->left : node->right;
166   - } else {
167   - node = val <= node->threshold ? node->left : node->right;
168   - }
169   - }
170   -
171   - sum += node->value;
172   - }
173   -
174   - if (confidence)
175   - *confidence = sum;
176   - return sum < threshold - THRESHOLD_EPS ? 0.0f : 1.0f;
177   - }
178   -
179   - float classify(const Template &src, bool process, float *confidence) const
180   - {
181   - // This code is written in a way to avoid an unnecessary copy construction and destruction of `src` when `process` is false.
182   - return process ? classifyPreprocessed(preprocess(src), confidence) : classifyPreprocessed(src, confidence);
183   - }
184   -
185   - int numFeatures() const
186   - {
187   - return representation->numFeatures();
188   - }
189   -
190   - Template preprocess(const Template &src) const
191   - {
192   - return representation->preprocess(src);
193   - }
194   -
195   - Size windowSize(int *dx, int *dy) const
196   - {
197   - return representation->windowSize(dx, dy);
198   - }
199   -
200   - void load(QDataStream &stream)
201   - {
202   - representation->load(stream);
203   -
204   - stream >> threshold;
205   - int numClassifiers; stream >> numClassifiers;
206   - for (int i = 0; i < numClassifiers; i++) {
207   - Node *classifier = new Node;
208   - loadRecursive(stream, classifier, representation->maxCatCount());
209   - classifiers.append(classifier);
210   - }
211   - }
212   -
213   - void store(QDataStream &stream) const
214   - {
215   - representation->store(stream);
216   -
217   - stream << threshold;
218   - stream << classifiers.size();
219   - foreach (const Node *classifier, classifiers)
220   - storeRecursive(stream, classifier, representation->maxCatCount());
221   - }
222   -};
223   -
224   -QList<Node*> getClassifers(Classifier *classifier)
225   -{
226   - BoostedForestClassifier *boostedForest = static_cast<BoostedForestClassifier*>(classifier);
227   - return boostedForest->classifiers;
228   -}
229   -
230   -BR_REGISTER(Classifier, BoostedForestClassifier)
231   -
232   -} // namespace br
233   -
234   -#include "classification/boostedforest.moc"