Commit 9ac4e2b34def5b9a2d720c30a64280c23f06fb79
1 parent
8f903c52
Added AdaBoost wrapper
Showing
1 changed file
with
123 additions
and
16 deletions
openbr/plugins/tree.cpp
| @@ -11,7 +11,7 @@ using namespace cv; | @@ -11,7 +11,7 @@ using namespace cv; | ||
| 11 | namespace br | 11 | namespace br |
| 12 | { | 12 | { |
| 13 | 13 | ||
| 14 | -static void storeForest(const CvRTrees &forest, QDataStream &stream) | 14 | +static void storeModel(const CvStatModel &model, QDataStream &stream) |
| 15 | { | 15 | { |
| 16 | // Create local file | 16 | // Create local file |
| 17 | QTemporaryFile tempFile; | 17 | QTemporaryFile tempFile; |
| @@ -19,7 +19,7 @@ static void storeForest(const CvRTrees &forest, QDataStream &stream) | @@ -19,7 +19,7 @@ static void storeForest(const CvRTrees &forest, QDataStream &stream) | ||
| 19 | tempFile.close(); | 19 | tempFile.close(); |
| 20 | 20 | ||
| 21 | // Save MLP to local file | 21 | // Save MLP to local file |
| 22 | - forest.save(qPrintable(tempFile.fileName())); | 22 | + model.save(qPrintable(tempFile.fileName())); |
| 23 | 23 | ||
| 24 | // Copy local file contents to stream | 24 | // Copy local file contents to stream |
| 25 | tempFile.open(); | 25 | tempFile.open(); |
| @@ -28,20 +28,20 @@ static void storeForest(const CvRTrees &forest, QDataStream &stream) | @@ -28,20 +28,20 @@ static void storeForest(const CvRTrees &forest, QDataStream &stream) | ||
| 28 | stream << data; | 28 | stream << data; |
| 29 | } | 29 | } |
| 30 | 30 | ||
| 31 | -static void loadForest(CvRTrees &forest, QDataStream &stream) | 31 | +static void loadModel(CvStatModel &model, QDataStream &stream) |
| 32 | { | 32 | { |
| 33 | // Copy local file contents from stream | 33 | // Copy local file contents from stream |
| 34 | QByteArray data; | 34 | QByteArray data; |
| 35 | stream >> data; | 35 | stream >> data; |
| 36 | 36 | ||
| 37 | // Create local file | 37 | // Create local file |
| 38 | - QTemporaryFile tempFile(QDir::tempPath()+"/forest"); | 38 | + QTemporaryFile tempFile(QDir::tempPath()+"/model"); |
| 39 | tempFile.open(); | 39 | tempFile.open(); |
| 40 | tempFile.write(data); | 40 | tempFile.write(data); |
| 41 | tempFile.close(); | 41 | tempFile.close(); |
| 42 | 42 | ||
| 43 | // Load MLP from local file | 43 | // Load MLP from local file |
| 44 | - forest.load(qPrintable(tempFile.fileName())); | 44 | + model.load(qPrintable(tempFile.fileName())); |
| 45 | } | 45 | } |
| 46 | 46 | ||
| 47 | /*! | 47 | /*! |
| @@ -50,16 +50,16 @@ static void loadForest(CvRTrees &forest, QDataStream &stream) | @@ -50,16 +50,16 @@ static void loadForest(CvRTrees &forest, QDataStream &stream) | ||
| 50 | * \author Scott Klum \cite sklum | 50 | * \author Scott Klum \cite sklum |
| 51 | * \brief http://docs.opencv.org/modules/ml/doc/random_trees.html | 51 | * \brief http://docs.opencv.org/modules/ml/doc/random_trees.html |
| 52 | */ | 52 | */ |
| 53 | -class ForestTransform : public MetaTransform | 53 | +class ForestTransform : public Transform |
| 54 | { | 54 | { |
| 55 | Q_OBJECT | 55 | Q_OBJECT |
| 56 | - Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED true) | ||
| 57 | - Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED true) | ||
| 58 | - Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED true) | ||
| 59 | - Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED true) | ||
| 60 | - Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED true) | ||
| 61 | - Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED true) | ||
| 62 | - Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED true) | 56 | + Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED false) |
| 57 | + Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED false) | ||
| 58 | + Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED false) | ||
| 59 | + Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED false) | ||
| 60 | + Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED false) | ||
| 61 | + Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED false) | ||
| 62 | + Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false) | ||
| 63 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | 63 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) |
| 64 | Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) | 64 | Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) |
| 65 | BR_PROPERTY(bool, classification, true) | 65 | BR_PROPERTY(bool, classification, true) |
| @@ -100,7 +100,7 @@ class ForestTransform : public MetaTransform | @@ -100,7 +100,7 @@ class ForestTransform : public MetaTransform | ||
| 100 | 0, | 100 | 0, |
| 101 | maxTrees, | 101 | maxTrees, |
| 102 | forestAccuracy, | 102 | forestAccuracy, |
| 103 | - CV_TERMCRIT_EPS)); | 103 | + CV_TERMCRIT_ITER | CV_TERMCRIT_EPS)); |
| 104 | 104 | ||
| 105 | qDebug() << "Number of trees:" << forest.get_tree_count(); | 105 | qDebug() << "Number of trees:" << forest.get_tree_count(); |
| 106 | } | 106 | } |
| @@ -127,12 +127,12 @@ class ForestTransform : public MetaTransform | @@ -127,12 +127,12 @@ class ForestTransform : public MetaTransform | ||
| 127 | 127 | ||
| 128 | void load(QDataStream &stream) | 128 | void load(QDataStream &stream) |
| 129 | { | 129 | { |
| 130 | - loadForest(forest,stream); | 130 | + loadModel(forest,stream); |
| 131 | } | 131 | } |
| 132 | 132 | ||
| 133 | void store(QDataStream &stream) const | 133 | void store(QDataStream &stream) const |
| 134 | { | 134 | { |
| 135 | - storeForest(forest,stream); | 135 | + storeModel(forest,stream); |
| 136 | } | 136 | } |
| 137 | 137 | ||
| 138 | void init() | 138 | void init() |
| @@ -144,6 +144,113 @@ class ForestTransform : public MetaTransform | @@ -144,6 +144,113 @@ class ForestTransform : public MetaTransform | ||
| 144 | 144 | ||
| 145 | BR_REGISTER(Transform, ForestTransform) | 145 | BR_REGISTER(Transform, ForestTransform) |
| 146 | 146 | ||
| 147 | +/*! | ||
| 148 | + * \ingroup transforms | ||
| 149 | + * \brief Wraps OpenCV's Ada Boost framework | ||
| 150 | + * \author Scott Klum \cite sklum | ||
| 151 | + * \brief http://docs.opencv.org/modules/ml/doc/boosting.html | ||
| 152 | + */ | ||
| 153 | +class AdaBoostTransform : public Transform | ||
| 154 | +{ | ||
| 155 | + Q_OBJECT | ||
| 156 | + Q_ENUMS(Type) | ||
| 157 | + Q_ENUMS(SplitCriteria) | ||
| 158 | + | ||
| 159 | + Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false) | ||
| 160 | + Q_PROPERTY(SplitCriteria splitCriteria READ get_splitCriteria WRITE set_splitCriteria RESET reset_splitCriteria STORED false) | ||
| 161 | + Q_PROPERTY(int weakCount READ get_weakCount WRITE set_weakCount RESET reset_weakCount STORED false) | ||
| 162 | + Q_PROPERTY(float trimRate READ get_trimRate WRITE set_trimRate RESET reset_trimRate STORED false) | ||
| 163 | + Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) | ||
| 164 | + Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED false) | ||
| 165 | + Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED false) | ||
| 166 | + Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false) | ||
| 167 | + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | ||
| 168 | + Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) | ||
| 169 | + | ||
| 170 | +public: | ||
| 171 | + enum Type { Discrete = CvBoost::DISCRETE, | ||
| 172 | + Real = CvBoost::REAL, | ||
| 173 | + Logit = CvBoost::LOGIT, | ||
| 174 | + Gentle = CvBoost::GENTLE}; | ||
| 175 | + | ||
| 176 | + enum SplitCriteria { Default = CvBoost::DEFAULT, | ||
| 177 | + Gini = CvBoost::GINI, | ||
| 178 | + Misclass = CvBoost::MISCLASS, | ||
| 179 | + Sqerr = CvBoost::SQERR}; | ||
| 180 | + | ||
| 181 | +private: | ||
| 182 | + BR_PROPERTY(Type, type, Real) | ||
| 183 | + BR_PROPERTY(SplitCriteria, splitCriteria, Default) | ||
| 184 | + BR_PROPERTY(int, weakCount, 100) | ||
| 185 | + BR_PROPERTY(float, trimRate, .95) | ||
| 186 | + BR_PROPERTY(int, folds, 0) | ||
| 187 | + BR_PROPERTY(int, maxDepth, 1) | ||
| 188 | + BR_PROPERTY(bool, returnConfidence, true) | ||
| 189 | + BR_PROPERTY(bool, overwriteMat, true) | ||
| 190 | + BR_PROPERTY(QString, inputVariable, "Label") | ||
| 191 | + BR_PROPERTY(QString, outputVariable, "") | ||
| 192 | + | ||
| 193 | + CvBoost boost; | ||
| 194 | + | ||
| 195 | + void train(const TemplateList &data) | ||
| 196 | + { | ||
| 197 | + Mat samples = OpenCVUtils::toMat(data.data()); | ||
| 198 | + Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable)); | ||
| 199 | + | ||
| 200 | + Mat types = Mat(samples.cols + 1, 1, CV_8U); | ||
| 201 | + types.setTo(Scalar(CV_VAR_NUMERICAL)); | ||
| 202 | + types.at<char>(samples.cols, 0) = CV_VAR_CATEGORICAL; | ||
| 203 | + | ||
| 204 | + CvBoostParams params; | ||
| 205 | + params.boost_type = type; | ||
| 206 | + params.split_criteria = splitCriteria; | ||
| 207 | + params.weak_count = weakCount; | ||
| 208 | + params.weight_trim_rate = trimRate; | ||
| 209 | + params.cv_folds = folds; | ||
| 210 | + params.max_depth = maxDepth; | ||
| 211 | + | ||
| 212 | + boost.train( samples, CV_ROW_SAMPLE, labels, Mat(), Mat(), types, Mat(), | ||
| 213 | + params); | ||
| 214 | + } | ||
| 215 | + | ||
| 216 | + void project(const Template &src, Template &dst) const | ||
| 217 | + { | ||
| 218 | + dst = src; | ||
| 219 | + float response; | ||
| 220 | + if (returnConfidence) { | ||
| 221 | + response = boost.predict(src.m().reshape(1,1),Mat(),Range::all(),false,true)/weakCount; | ||
| 222 | + } else { | ||
| 223 | + response = boost.predict(src.m().reshape(1,1)); | ||
| 224 | + } | ||
| 225 | + | ||
| 226 | + if (overwriteMat) { | ||
| 227 | + dst.m() = Mat(1, 1, CV_32F); | ||
| 228 | + dst.m().at<float>(0, 0) = response; | ||
| 229 | + } else { | ||
| 230 | + dst.file.set(outputVariable, response); | ||
| 231 | + } | ||
| 232 | + } | ||
| 233 | + | ||
| 234 | + void load(QDataStream &stream) | ||
| 235 | + { | ||
| 236 | + loadModel(boost,stream); | ||
| 237 | + } | ||
| 238 | + | ||
| 239 | + void store(QDataStream &stream) const | ||
| 240 | + { | ||
| 241 | + storeModel(boost,stream); | ||
| 242 | + } | ||
| 243 | + | ||
| 244 | + | ||
| 245 | + void init() | ||
| 246 | + { | ||
| 247 | + if (outputVariable.isEmpty()) | ||
| 248 | + outputVariable = inputVariable; | ||
| 249 | + } | ||
| 250 | +}; | ||
| 251 | + | ||
| 252 | +BR_REGISTER(Transform, AdaBoostTransform) | ||
| 253 | + | ||
| 147 | } // namespace br | 254 | } // namespace br |
| 148 | 255 | ||
| 149 | #include "tree.moc" | 256 | #include "tree.moc" |