Commit 9ac4e2b34def5b9a2d720c30a64280c23f06fb79
1 parent
8f903c52
Added AdaBoost wrapper
Showing
1 changed file
with
123 additions
and
16 deletions
openbr/plugins/tree.cpp
| ... | ... | @@ -11,7 +11,7 @@ using namespace cv; |
| 11 | 11 | namespace br |
| 12 | 12 | { |
| 13 | 13 | |
| 14 | -static void storeForest(const CvRTrees &forest, QDataStream &stream) | |
| 14 | +static void storeModel(const CvStatModel &model, QDataStream &stream) | |
| 15 | 15 | { |
| 16 | 16 | // Create local file |
| 17 | 17 | QTemporaryFile tempFile; |
| ... | ... | @@ -19,7 +19,7 @@ static void storeForest(const CvRTrees &forest, QDataStream &stream) |
| 19 | 19 | tempFile.close(); |
| 20 | 20 | |
| 21 | 21 | // Save MLP to local file |
| 22 | - forest.save(qPrintable(tempFile.fileName())); | |
| 22 | + model.save(qPrintable(tempFile.fileName())); | |
| 23 | 23 | |
| 24 | 24 | // Copy local file contents to stream |
| 25 | 25 | tempFile.open(); |
| ... | ... | @@ -28,20 +28,20 @@ static void storeForest(const CvRTrees &forest, QDataStream &stream) |
| 28 | 28 | stream << data; |
| 29 | 29 | } |
| 30 | 30 | |
| 31 | -static void loadForest(CvRTrees &forest, QDataStream &stream) | |
| 31 | +static void loadModel(CvStatModel &model, QDataStream &stream) | |
| 32 | 32 | { |
| 33 | 33 | // Copy local file contents from stream |
| 34 | 34 | QByteArray data; |
| 35 | 35 | stream >> data; |
| 36 | 36 | |
| 37 | 37 | // Create local file |
| 38 | - QTemporaryFile tempFile(QDir::tempPath()+"/forest"); | |
| 38 | + QTemporaryFile tempFile(QDir::tempPath()+"/model"); | |
| 39 | 39 | tempFile.open(); |
| 40 | 40 | tempFile.write(data); |
| 41 | 41 | tempFile.close(); |
| 42 | 42 | |
| 43 | 43 | // Load MLP from local file |
| 44 | - forest.load(qPrintable(tempFile.fileName())); | |
| 44 | + model.load(qPrintable(tempFile.fileName())); | |
| 45 | 45 | } |
| 46 | 46 | |
| 47 | 47 | /*! |
| ... | ... | @@ -50,16 +50,16 @@ static void loadForest(CvRTrees &forest, QDataStream &stream) |
| 50 | 50 | * \author Scott Klum \cite sklum |
| 51 | 51 | * \brief http://docs.opencv.org/modules/ml/doc/random_trees.html |
| 52 | 52 | */ |
| 53 | -class ForestTransform : public MetaTransform | |
| 53 | +class ForestTransform : public Transform | |
| 54 | 54 | { |
| 55 | 55 | Q_OBJECT |
| 56 | - Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED true) | |
| 57 | - Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED true) | |
| 58 | - Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED true) | |
| 59 | - Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED true) | |
| 60 | - Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED true) | |
| 61 | - Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED true) | |
| 62 | - Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED true) | |
| 56 | + Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED false) | |
| 57 | + Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED false) | |
| 58 | + Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED false) | |
| 59 | + Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED false) | |
| 60 | + Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED false) | |
| 61 | + Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED false) | |
| 62 | + Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false) | |
| 63 | 63 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) |
| 64 | 64 | Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) |
| 65 | 65 | BR_PROPERTY(bool, classification, true) |
| ... | ... | @@ -100,7 +100,7 @@ class ForestTransform : public MetaTransform |
| 100 | 100 | 0, |
| 101 | 101 | maxTrees, |
| 102 | 102 | forestAccuracy, |
| 103 | - CV_TERMCRIT_EPS)); | |
| 103 | + CV_TERMCRIT_ITER | CV_TERMCRIT_EPS)); | |
| 104 | 104 | |
| 105 | 105 | qDebug() << "Number of trees:" << forest.get_tree_count(); |
| 106 | 106 | } |
| ... | ... | @@ -127,12 +127,12 @@ class ForestTransform : public MetaTransform |
| 127 | 127 | |
| 128 | 128 | void load(QDataStream &stream) |
| 129 | 129 | { |
| 130 | - loadForest(forest,stream); | |
| 130 | + loadModel(forest,stream); | |
| 131 | 131 | } |
| 132 | 132 | |
| 133 | 133 | void store(QDataStream &stream) const |
| 134 | 134 | { |
| 135 | - storeForest(forest,stream); | |
| 135 | + storeModel(forest,stream); | |
| 136 | 136 | } |
| 137 | 137 | |
| 138 | 138 | void init() |
| ... | ... | @@ -144,6 +144,113 @@ class ForestTransform : public MetaTransform |
| 144 | 144 | |
| 145 | 145 | BR_REGISTER(Transform, ForestTransform) |
| 146 | 146 | |
| 147 | +/*! | |
| 148 | + * \ingroup transforms | |
| 149 | + * \brief Wraps OpenCV's Ada Boost framework | |
| 150 | + * \author Scott Klum \cite sklum | |
| 151 | + * \brief http://docs.opencv.org/modules/ml/doc/boosting.html | |
| 152 | + */ | |
| 153 | +class AdaBoostTransform : public Transform | |
| 154 | +{ | |
| 155 | + Q_OBJECT | |
| 156 | + Q_ENUMS(Type) | |
| 157 | + Q_ENUMS(SplitCriteria) | |
| 158 | + | |
| 159 | + Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false) | |
| 160 | + Q_PROPERTY(SplitCriteria splitCriteria READ get_splitCriteria WRITE set_splitCriteria RESET reset_splitCriteria STORED false) | |
| 161 | + Q_PROPERTY(int weakCount READ get_weakCount WRITE set_weakCount RESET reset_weakCount STORED false) | |
| 162 | + Q_PROPERTY(float trimRate READ get_trimRate WRITE set_trimRate RESET reset_trimRate STORED false) | |
| 163 | + Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) | |
| 164 | + Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED false) | |
| 165 | + Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED false) | |
| 166 | + Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false) | |
| 167 | + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | |
| 168 | + Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) | |
| 169 | + | |
| 170 | +public: | |
| 171 | + enum Type { Discrete = CvBoost::DISCRETE, | |
| 172 | + Real = CvBoost::REAL, | |
| 173 | + Logit = CvBoost::LOGIT, | |
| 174 | + Gentle = CvBoost::GENTLE}; | |
| 175 | + | |
| 176 | + enum SplitCriteria { Default = CvBoost::DEFAULT, | |
| 177 | + Gini = CvBoost::GINI, | |
| 178 | + Misclass = CvBoost::MISCLASS, | |
| 179 | + Sqerr = CvBoost::SQERR}; | |
| 180 | + | |
| 181 | +private: | |
| 182 | + BR_PROPERTY(Type, type, Real) | |
| 183 | + BR_PROPERTY(SplitCriteria, splitCriteria, Default) | |
| 184 | + BR_PROPERTY(int, weakCount, 100) | |
| 185 | + BR_PROPERTY(float, trimRate, .95) | |
| 186 | + BR_PROPERTY(int, folds, 0) | |
| 187 | + BR_PROPERTY(int, maxDepth, 1) | |
| 188 | + BR_PROPERTY(bool, returnConfidence, true) | |
| 189 | + BR_PROPERTY(bool, overwriteMat, true) | |
| 190 | + BR_PROPERTY(QString, inputVariable, "Label") | |
| 191 | + BR_PROPERTY(QString, outputVariable, "") | |
| 192 | + | |
| 193 | + CvBoost boost; | |
| 194 | + | |
| 195 | + void train(const TemplateList &data) | |
| 196 | + { | |
| 197 | + Mat samples = OpenCVUtils::toMat(data.data()); | |
| 198 | + Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable)); | |
| 199 | + | |
| 200 | + Mat types = Mat(samples.cols + 1, 1, CV_8U); | |
| 201 | + types.setTo(Scalar(CV_VAR_NUMERICAL)); | |
| 202 | + types.at<char>(samples.cols, 0) = CV_VAR_CATEGORICAL; | |
| 203 | + | |
| 204 | + CvBoostParams params; | |
| 205 | + params.boost_type = type; | |
| 206 | + params.split_criteria = splitCriteria; | |
| 207 | + params.weak_count = weakCount; | |
| 208 | + params.weight_trim_rate = trimRate; | |
| 209 | + params.cv_folds = folds; | |
| 210 | + params.max_depth = maxDepth; | |
| 211 | + | |
| 212 | + boost.train( samples, CV_ROW_SAMPLE, labels, Mat(), Mat(), types, Mat(), | |
| 213 | + params); | |
| 214 | + } | |
| 215 | + | |
| 216 | + void project(const Template &src, Template &dst) const | |
| 217 | + { | |
| 218 | + dst = src; | |
| 219 | + float response; | |
| 220 | + if (returnConfidence) { | |
| 221 | + response = boost.predict(src.m().reshape(1,1),Mat(),Range::all(),false,true)/weakCount; | |
| 222 | + } else { | |
| 223 | + response = boost.predict(src.m().reshape(1,1)); | |
| 224 | + } | |
| 225 | + | |
| 226 | + if (overwriteMat) { | |
| 227 | + dst.m() = Mat(1, 1, CV_32F); | |
| 228 | + dst.m().at<float>(0, 0) = response; | |
| 229 | + } else { | |
| 230 | + dst.file.set(outputVariable, response); | |
| 231 | + } | |
| 232 | + } | |
| 233 | + | |
| 234 | + void load(QDataStream &stream) | |
| 235 | + { | |
| 236 | + loadModel(boost,stream); | |
| 237 | + } | |
| 238 | + | |
| 239 | + void store(QDataStream &stream) const | |
| 240 | + { | |
| 241 | + storeModel(boost,stream); | |
| 242 | + } | |
| 243 | + | |
| 244 | + | |
| 245 | + void init() | |
| 246 | + { | |
| 247 | + if (outputVariable.isEmpty()) | |
| 248 | + outputVariable = inputVariable; | |
| 249 | + } | |
| 250 | +}; | |
| 251 | + | |
| 252 | +BR_REGISTER(Transform, AdaBoostTransform) | |
| 253 | + | |
| 147 | 254 | } // namespace br |
| 148 | 255 | |
| 149 | 256 | #include "tree.moc" | ... | ... |