Commit ac67461f6850464d7f1dfa3bcbc7226a385b1df6
1 parent
fcb1bcac
removed boostedforest
Showing
1 changed file
with
0 additions
and
234 deletions
openbr/plugins/classification/boostedforest.cpp deleted
| 1 | -#include <openbr/plugins/openbr_internal.h> | ||
| 2 | -#include <openbr/core/boost.h> | ||
| 3 | - | ||
| 4 | -#define THRESHOLD_EPS 1e-5 | ||
| 5 | - | ||
| 6 | -using namespace cv; | ||
| 7 | - | ||
| 8 | -namespace br | ||
| 9 | -{ | ||
| 10 | - | ||
| 11 | -struct Node | ||
| 12 | -{ | ||
| 13 | - float value; // for leaf nodes | ||
| 14 | - | ||
| 15 | - float threshold; // for ordered features | ||
| 16 | - QList<int> subset; // for categorical features | ||
| 17 | - int featureIdx; | ||
| 18 | - | ||
| 19 | - Node *left, *right; | ||
| 20 | -}; | ||
| 21 | - | ||
| 22 | -static void buildTreeRecursive(Node *node, const CvDTreeNode *cv_node, int maxCatCount) | ||
| 23 | -{ | ||
| 24 | - if (!cv_node->left) { | ||
| 25 | - node->value = cv_node->value; | ||
| 26 | - node->left = node->right = NULL; | ||
| 27 | - } else { | ||
| 28 | - if (maxCatCount > 0) | ||
| 29 | - for (int i = 0; i < (maxCatCount + 31)/32; i++) | ||
| 30 | - node->subset.append(cv_node->split->subset[i]); | ||
| 31 | - else | ||
| 32 | - node->threshold = cv_node->split->ord.c; | ||
| 33 | - | ||
| 34 | - node->featureIdx = cv_node->split->var_idx; | ||
| 35 | - | ||
| 36 | - node->left = new Node; node->right = new Node; | ||
| 37 | - buildTreeRecursive(node->left, cv_node->left, maxCatCount); | ||
| 38 | - buildTreeRecursive(node->right, cv_node->right, maxCatCount); | ||
| 39 | - } | ||
| 40 | -} | ||
| 41 | - | ||
| 42 | -static void loadRecursive(QDataStream &stream, Node *node, int maxCatCount) | ||
| 43 | -{ | ||
| 44 | - bool hasChildren; stream >> hasChildren; | ||
| 45 | - | ||
| 46 | - if (!hasChildren) { | ||
| 47 | - stream >> node->value; | ||
| 48 | - node->left = node->right = NULL; | ||
| 49 | - } else { | ||
| 50 | - if (maxCatCount > 0) | ||
| 51 | - for (int i = 0; i < (maxCatCount + 31)/32; i++) { | ||
| 52 | - int s; stream >> s; node->subset.append(s); | ||
| 53 | - } | ||
| 54 | - else | ||
| 55 | - stream >> node->threshold; | ||
| 56 | - | ||
| 57 | - stream >> node->featureIdx; | ||
| 58 | - | ||
| 59 | - node->left = new Node; node->right = new Node; | ||
| 60 | - loadRecursive(stream, node->left, maxCatCount); | ||
| 61 | - loadRecursive(stream, node->right, maxCatCount); | ||
| 62 | - } | ||
| 63 | -} | ||
| 64 | - | ||
| 65 | -static void storeRecursive(QDataStream &stream, const Node *node, int maxCatCount) | ||
| 66 | -{ | ||
| 67 | - bool hasChildren = node->left ? true : false; | ||
| 68 | - stream << hasChildren; | ||
| 69 | - | ||
| 70 | - if (!hasChildren) | ||
| 71 | - stream << node->value; | ||
| 72 | - else { | ||
| 73 | - if (maxCatCount > 0) | ||
| 74 | - for (int i = 0; i < (maxCatCount + 31)/32; i++) | ||
| 75 | - stream << node->subset[i]; | ||
| 76 | - else | ||
| 77 | - stream << node->threshold; | ||
| 78 | - | ||
| 79 | - stream << node->featureIdx; | ||
| 80 | - | ||
| 81 | - storeRecursive(stream, node->left, maxCatCount); | ||
| 82 | - storeRecursive(stream, node->right, maxCatCount); | ||
| 83 | - } | ||
| 84 | -} | ||
| 85 | - | ||
| 86 | -/*! | ||
| 87 | - * \brief A classification wrapper on OpenCV's CvBoost class. It uses CvBoost for training a boosted forest and then performs classification using the trained nodes. | ||
| 88 | - * \author Jordan Cheney \cite jcheney | ||
| 89 | - * \author Scott Klum \cite sklum | ||
| 90 | - * \br_property Representation* representation The Representation describing the features used by the boosted forest | ||
| 91 | - * \br_property float minTAR The minimum true accept rate during training | ||
| 92 | - * \br_property float maxFAR The maximum false accept rate during training | ||
| 93 | - * \br_property float trimRate The trim rate during training | ||
| 94 | - * \br_property int maxDepth The maximum depth for each trained tree | ||
| 95 | - * \br_property int maxWeakCount The maximum number of trees in the forest | ||
| 96 | - * \br_property Type type. The type of boosting to perform. Options are [Discrete, Real, Logit, Gentle]. Gentle is the default. | ||
| 97 | - */ | ||
| 98 | -class BoostedForestClassifier : public Classifier | ||
| 99 | -{ | ||
| 100 | - Q_OBJECT | ||
| 101 | - Q_ENUMS(Type) | ||
| 102 | - | ||
| 103 | - Q_PROPERTY(br::Representation* representation READ get_representation WRITE set_representation RESET reset_representation STORED false) | ||
| 104 | - Q_PROPERTY(float minTAR READ get_minTAR WRITE set_minTAR RESET reset_minTAR STORED false) | ||
| 105 | - Q_PROPERTY(float maxFAR READ get_maxFAR WRITE set_maxFAR RESET reset_maxFAR STORED false) | ||
| 106 | - Q_PROPERTY(float trimRate READ get_trimRate WRITE set_trimRate RESET reset_trimRate STORED false) | ||
| 107 | - Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED false) | ||
| 108 | - Q_PROPERTY(int maxWeakCount READ get_maxWeakCount WRITE set_maxWeakCount RESET reset_maxWeakCount STORED false) | ||
| 109 | - Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false) | ||
| 110 | - Q_PROPERTY(float threshold READ get_threshold WRITE set_threshold RESET reset_threshold STORED false) | ||
| 111 | - | ||
| 112 | -public: | ||
| 113 | - QList<Node*> classifiers; | ||
| 114 | - | ||
| 115 | - enum Type { Discrete = CvBoost::DISCRETE, | ||
| 116 | - Real = CvBoost::REAL, | ||
| 117 | - Logit = CvBoost::LOGIT, | ||
| 118 | - Gentle = CvBoost::GENTLE}; | ||
| 119 | -private: | ||
| 120 | - BR_PROPERTY(br::Representation*, representation, NULL) | ||
| 121 | - BR_PROPERTY(float, minTAR, 0.995) | ||
| 122 | - BR_PROPERTY(float, maxFAR, 0.5) | ||
| 123 | - BR_PROPERTY(float, trimRate, 0.95) | ||
| 124 | - BR_PROPERTY(int, maxDepth, 1) | ||
| 125 | - BR_PROPERTY(int, maxWeakCount, 100) | ||
| 126 | - BR_PROPERTY(Type, type, Gentle) | ||
| 127 | - BR_PROPERTY(float, threshold, 0) | ||
| 128 | - | ||
| 129 | - void train(const TemplateList &data) | ||
| 130 | - { | ||
| 131 | - representation->train(data); | ||
| 132 | - | ||
| 133 | - CascadeBoostParams params(type, minTAR, maxFAR, trimRate, maxDepth, maxWeakCount); | ||
| 134 | - | ||
| 135 | - FeatureEvaluator featureEvaluator; | ||
| 136 | - featureEvaluator.init(representation, data.size()); | ||
| 137 | - | ||
| 138 | - for (int i = 0; i < data.size(); i++) | ||
| 139 | - featureEvaluator.setImage(data[i], data[i].file.get<float>("Label"), i); | ||
| 140 | - | ||
| 141 | - CascadeBoost boost; | ||
| 142 | - boost.train(&featureEvaluator, data.size(), 1024, 1024, representation->numChannels(), params); | ||
| 143 | - | ||
| 144 | - threshold = boost.getThreshold(); | ||
| 145 | - | ||
| 146 | - foreach (const CvBoostTree *classifier, boost.getClassifers()) { | ||
| 147 | - Node *root = new Node; | ||
| 148 | - buildTreeRecursive(root, classifier->get_root(), representation->maxCatCount()); | ||
| 149 | - classifiers.append(root); | ||
| 150 | - } | ||
| 151 | - } | ||
| 152 | - | ||
| 153 | - float classifyPreprocessed(const Template &t, float *confidence) const | ||
| 154 | - { | ||
| 155 | - const bool categorical = representation->maxCatCount() > 0; | ||
| 156 | - | ||
| 157 | - float sum = 0; | ||
| 158 | - for (int i = 0; i < classifiers.size(); i++) { | ||
| 159 | - const Node *node = classifiers[i]; | ||
| 160 | - | ||
| 161 | - while (node->left) { | ||
| 162 | - const float val = representation->evaluate(t, node->featureIdx); | ||
| 163 | - if (categorical) { | ||
| 164 | - const int c = (int)val; | ||
| 165 | - node = (node->subset[c >> 5] & (1 << (c & 31))) ? node->left : node->right; | ||
| 166 | - } else { | ||
| 167 | - node = val <= node->threshold ? node->left : node->right; | ||
| 168 | - } | ||
| 169 | - } | ||
| 170 | - | ||
| 171 | - sum += node->value; | ||
| 172 | - } | ||
| 173 | - | ||
| 174 | - if (confidence) | ||
| 175 | - *confidence = sum; | ||
| 176 | - return sum < threshold - THRESHOLD_EPS ? 0.0f : 1.0f; | ||
| 177 | - } | ||
| 178 | - | ||
| 179 | - float classify(const Template &src, bool process, float *confidence) const | ||
| 180 | - { | ||
| 181 | - // This code is written in a way to avoid an unnecessary copy construction and destruction of `src` when `process` is false. | ||
| 182 | - return process ? classifyPreprocessed(preprocess(src), confidence) : classifyPreprocessed(src, confidence); | ||
| 183 | - } | ||
| 184 | - | ||
| 185 | - int numFeatures() const | ||
| 186 | - { | ||
| 187 | - return representation->numFeatures(); | ||
| 188 | - } | ||
| 189 | - | ||
| 190 | - Template preprocess(const Template &src) const | ||
| 191 | - { | ||
| 192 | - return representation->preprocess(src); | ||
| 193 | - } | ||
| 194 | - | ||
| 195 | - Size windowSize(int *dx, int *dy) const | ||
| 196 | - { | ||
| 197 | - return representation->windowSize(dx, dy); | ||
| 198 | - } | ||
| 199 | - | ||
| 200 | - void load(QDataStream &stream) | ||
| 201 | - { | ||
| 202 | - representation->load(stream); | ||
| 203 | - | ||
| 204 | - stream >> threshold; | ||
| 205 | - int numClassifiers; stream >> numClassifiers; | ||
| 206 | - for (int i = 0; i < numClassifiers; i++) { | ||
| 207 | - Node *classifier = new Node; | ||
| 208 | - loadRecursive(stream, classifier, representation->maxCatCount()); | ||
| 209 | - classifiers.append(classifier); | ||
| 210 | - } | ||
| 211 | - } | ||
| 212 | - | ||
| 213 | - void store(QDataStream &stream) const | ||
| 214 | - { | ||
| 215 | - representation->store(stream); | ||
| 216 | - | ||
| 217 | - stream << threshold; | ||
| 218 | - stream << classifiers.size(); | ||
| 219 | - foreach (const Node *classifier, classifiers) | ||
| 220 | - storeRecursive(stream, classifier, representation->maxCatCount()); | ||
| 221 | - } | ||
| 222 | -}; | ||
| 223 | - | ||
| 224 | -QList<Node*> getClassifers(Classifier *classifier) | ||
| 225 | -{ | ||
| 226 | - BoostedForestClassifier *boostedForest = static_cast<BoostedForestClassifier*>(classifier); | ||
| 227 | - return boostedForest->classifiers; | ||
| 228 | -} | ||
| 229 | - | ||
| 230 | -BR_REGISTER(Classifier, BoostedForestClassifier) | ||
| 231 | - | ||
| 232 | -} // namespace br | ||
| 233 | - | ||
| 234 | -#include "classification/boostedforest.moc" |