Commit 9ac4e2b34def5b9a2d720c30a64280c23f06fb79

Authored by Scott Klum
1 parent 8f903c52

Added AdaBoost wrapper

Showing 1 changed file with 123 additions and 16 deletions
openbr/plugins/tree.cpp
... ... @@ -11,7 +11,7 @@ using namespace cv;
11 11 namespace br
12 12 {
13 13  
14   -static void storeForest(const CvRTrees &forest, QDataStream &stream)
  14 +static void storeModel(const CvStatModel &model, QDataStream &stream)
15 15 {
16 16 // Create local file
17 17 QTemporaryFile tempFile;
... ... @@ -19,7 +19,7 @@ static void storeForest(const CvRTrees &forest, QDataStream &stream)
19 19 tempFile.close();
20 20  
21 21 // Save MLP to local file
22   - forest.save(qPrintable(tempFile.fileName()));
  22 + model.save(qPrintable(tempFile.fileName()));
23 23  
24 24 // Copy local file contents to stream
25 25 tempFile.open();
... ... @@ -28,20 +28,20 @@ static void storeForest(const CvRTrees &forest, QDataStream &stream)
28 28 stream << data;
29 29 }
30 30  
31   -static void loadForest(CvRTrees &forest, QDataStream &stream)
  31 +static void loadModel(CvStatModel &model, QDataStream &stream)
32 32 {
33 33 // Copy local file contents from stream
34 34 QByteArray data;
35 35 stream >> data;
36 36  
37 37 // Create local file
38   - QTemporaryFile tempFile(QDir::tempPath()+"/forest");
  38 + QTemporaryFile tempFile(QDir::tempPath()+"/model");
39 39 tempFile.open();
40 40 tempFile.write(data);
41 41 tempFile.close();
42 42  
43 43 // Load MLP from local file
44   - forest.load(qPrintable(tempFile.fileName()));
  44 + model.load(qPrintable(tempFile.fileName()));
45 45 }
46 46  
47 47 /*!
... ... @@ -50,16 +50,16 @@ static void loadForest(CvRTrees &amp;forest, QDataStream &amp;stream)
50 50 * \author Scott Klum \cite sklum
51 51 * \brief http://docs.opencv.org/modules/ml/doc/random_trees.html
52 52 */
53   -class ForestTransform : public MetaTransform
  53 +class ForestTransform : public Transform
54 54 {
55 55 Q_OBJECT
56   - Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED true)
57   - Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED true)
58   - Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED true)
59   - Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED true)
60   - Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED true)
61   - Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED true)
62   - Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED true)
  56 + Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED false)
  57 + Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED false)
  58 + Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED false)
  59 + Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED false)
  60 + Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED false)
  61 + Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED false)
  62 + Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false)
63 63 Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
64 64 Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)
65 65 BR_PROPERTY(bool, classification, true)
... ... @@ -100,7 +100,7 @@ class ForestTransform : public MetaTransform
100 100 0,
101 101 maxTrees,
102 102 forestAccuracy,
103   - CV_TERMCRIT_EPS));
  103 + CV_TERMCRIT_ITER | CV_TERMCRIT_EPS));
104 104  
105 105 qDebug() << "Number of trees:" << forest.get_tree_count();
106 106 }
... ... @@ -127,12 +127,12 @@ class ForestTransform : public MetaTransform
127 127  
128 128 void load(QDataStream &stream)
129 129 {
130   - loadForest(forest,stream);
  130 + loadModel(forest,stream);
131 131 }
132 132  
133 133 void store(QDataStream &stream) const
134 134 {
135   - storeForest(forest,stream);
  135 + storeModel(forest,stream);
136 136 }
137 137  
138 138 void init()
... ... @@ -144,6 +144,113 @@ class ForestTransform : public MetaTransform
144 144  
145 145 BR_REGISTER(Transform, ForestTransform)
146 146  
  147 +/*!
  148 + * \ingroup transforms
  149 + * \brief Wraps OpenCV's Ada Boost framework
  150 + * \author Scott Klum \cite sklum
  151 + * \brief http://docs.opencv.org/modules/ml/doc/boosting.html
  152 + */
  153 +class AdaBoostTransform : public Transform
  154 +{
  155 + Q_OBJECT
  156 + Q_ENUMS(Type)
  157 + Q_ENUMS(SplitCriteria)
  158 +
  159 + Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false)
  160 + Q_PROPERTY(SplitCriteria splitCriteria READ get_splitCriteria WRITE set_splitCriteria RESET reset_splitCriteria STORED false)
  161 + Q_PROPERTY(int weakCount READ get_weakCount WRITE set_weakCount RESET reset_weakCount STORED false)
  162 + Q_PROPERTY(float trimRate READ get_trimRate WRITE set_trimRate RESET reset_trimRate STORED false)
  163 + Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false)
  164 + Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED false)
  165 + Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED false)
  166 + Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false)
  167 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
  168 + Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)
  169 +
  170 +public:
  171 + enum Type { Discrete = CvBoost::DISCRETE,
  172 + Real = CvBoost::REAL,
  173 + Logit = CvBoost::LOGIT,
  174 + Gentle = CvBoost::GENTLE};
  175 +
  176 + enum SplitCriteria { Default = CvBoost::DEFAULT,
  177 + Gini = CvBoost::GINI,
  178 + Misclass = CvBoost::MISCLASS,
  179 + Sqerr = CvBoost::SQERR};
  180 +
  181 +private:
  182 + BR_PROPERTY(Type, type, Real)
  183 + BR_PROPERTY(SplitCriteria, splitCriteria, Default)
  184 + BR_PROPERTY(int, weakCount, 100)
  185 + BR_PROPERTY(float, trimRate, .95)
  186 + BR_PROPERTY(int, folds, 0)
  187 + BR_PROPERTY(int, maxDepth, 1)
  188 + BR_PROPERTY(bool, returnConfidence, true)
  189 + BR_PROPERTY(bool, overwriteMat, true)
  190 + BR_PROPERTY(QString, inputVariable, "Label")
  191 + BR_PROPERTY(QString, outputVariable, "")
  192 +
  193 + CvBoost boost;
  194 +
  195 + void train(const TemplateList &data)
  196 + {
  197 + Mat samples = OpenCVUtils::toMat(data.data());
  198 + Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable));
  199 +
  200 + Mat types = Mat(samples.cols + 1, 1, CV_8U);
  201 + types.setTo(Scalar(CV_VAR_NUMERICAL));
  202 + types.at<char>(samples.cols, 0) = CV_VAR_CATEGORICAL;
  203 +
  204 + CvBoostParams params;
  205 + params.boost_type = type;
  206 + params.split_criteria = splitCriteria;
  207 + params.weak_count = weakCount;
  208 + params.weight_trim_rate = trimRate;
  209 + params.cv_folds = folds;
  210 + params.max_depth = maxDepth;
  211 +
  212 + boost.train( samples, CV_ROW_SAMPLE, labels, Mat(), Mat(), types, Mat(),
  213 + params);
  214 + }
  215 +
  216 + void project(const Template &src, Template &dst) const
  217 + {
  218 + dst = src;
  219 + float response;
  220 + if (returnConfidence) {
  221 + response = boost.predict(src.m().reshape(1,1),Mat(),Range::all(),false,true)/weakCount;
  222 + } else {
  223 + response = boost.predict(src.m().reshape(1,1));
  224 + }
  225 +
  226 + if (overwriteMat) {
  227 + dst.m() = Mat(1, 1, CV_32F);
  228 + dst.m().at<float>(0, 0) = response;
  229 + } else {
  230 + dst.file.set(outputVariable, response);
  231 + }
  232 + }
  233 +
  234 + void load(QDataStream &stream)
  235 + {
  236 + loadModel(boost,stream);
  237 + }
  238 +
  239 + void store(QDataStream &stream) const
  240 + {
  241 + storeModel(boost,stream);
  242 + }
  243 +
  244 +
  245 + void init()
  246 + {
  247 + if (outputVariable.isEmpty())
  248 + outputVariable = inputVariable;
  249 + }
  250 +};
  251 +
  252 +BR_REGISTER(Transform, AdaBoostTransform)
  253 +
147 254 } // namespace br
148 255  
149 256 #include "tree.moc"
... ...