Commit 9b430a4f829a7e5aecd163738865cc37e6cbeb59
1 parent
9e9e1a9b
Added forest induction
Showing
1 changed file
with
67 additions
and
36 deletions
openbr/plugins/tree.cpp
| ... | ... | @@ -72,11 +72,51 @@ class ForestTransform : public Transform |
| 72 | 72 | BR_PROPERTY(QString, inputVariable, "Label") |
| 73 | 73 | BR_PROPERTY(QString, outputVariable, "") |
| 74 | 74 | |
| 75 | + void train(const TemplateList &data) | |
| 76 | + { | |
| 77 | + trainForest(data); | |
| 78 | + } | |
| 79 | + | |
| 80 | + void project(const Template &src, Template &dst) const | |
| 81 | + { | |
| 82 | + dst = src; | |
| 83 | + | |
| 84 | + float response; | |
| 85 | + if (classification && returnConfidence) { | |
| 86 | + // Fuzzy class label | |
| 87 | + response = forest.predict_prob(src.m().reshape(1,1)); | |
| 88 | + } else { | |
| 89 | + response = forest.predict(src.m().reshape(1,1)); | |
| 90 | + } | |
| 91 | + | |
| 92 | + if (overwriteMat) { | |
| 93 | + dst.m() = Mat(1, 1, CV_32F); | |
| 94 | + dst.m().at<float>(0, 0) = response; | |
| 95 | + } else { | |
| 96 | + dst.file.set(outputVariable, response); | |
| 97 | + } | |
| 98 | + } | |
| 99 | + | |
| 100 | + void load(QDataStream &stream) | |
| 101 | + { | |
| 102 | + loadModel(forest,stream); | |
| 103 | + } | |
| 104 | + | |
| 105 | + void store(QDataStream &stream) const | |
| 106 | + { | |
| 107 | + storeModel(forest,stream); | |
| 108 | + } | |
| 109 | + | |
| 110 | + void init() | |
| 111 | + { | |
| 112 | + if (outputVariable.isEmpty()) | |
| 113 | + outputVariable = inputVariable; | |
| 114 | + } | |
| 115 | + | |
| 116 | +protected: | |
| 75 | 117 | CvRTrees forest; |
| 76 | - int totalSize; | |
| 77 | - QList< QList<const CvDTreeNode*> > nodes; | |
| 78 | 118 | |
| 79 | - void train(const TemplateList &data) | |
| 119 | + void trainForest(const TemplateList &data) | |
| 80 | 120 | { |
| 81 | 121 | Mat samples = OpenCVUtils::toMat(data.data()); |
| 82 | 122 | Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable)); |
| ... | ... | @@ -102,9 +142,30 @@ class ForestTransform : public Transform |
| 102 | 142 | 0, |
| 103 | 143 | maxTrees, |
| 104 | 144 | forestAccuracy, |
| 105 | - CV_TERMCRIT_ITER)); | |
| 145 | + CV_TERMCRIT_ITER | CV_TERMCRIT_EPS)); | |
| 106 | 146 | |
| 107 | 147 | qDebug() << "Number of trees:" << forest.get_tree_count(); |
| 148 | + } | |
| 149 | +}; | |
| 150 | + | |
| 151 | +BR_REGISTER(Transform, ForestTransform) | |
| 152 | + | |
| 153 | +/*! | |
| 154 | + * \ingroup transforms | |
| 155 | + * \brief Wraps OpenCV's random trees framework to induce features | |
| 156 | + * \author Scott Klum \cite sklum | |
| 157 | + * \brief https://lirias.kuleuven.be/bitstream/123456789/316661/1/icdm11-camready.pdf | |
| 158 | + */ | |
| 159 | +class ForestInductionTransform : public ForestTransform | |
| 160 | +{ | |
| 161 | + Q_OBJECT | |
| 162 | + | |
| 163 | + int totalSize; | |
| 164 | + QList< QList<const CvDTreeNode*> > nodes; | |
| 165 | + | |
| 166 | + void train(const TemplateList &data) | |
| 167 | + { | |
| 168 | + trainForest(data); | |
| 108 | 169 | |
| 109 | 170 | for (int i=0; i<forest.get_tree_count(); i++) { |
| 110 | 171 | nodes.append(QList<const CvDTreeNode*>()); |
| ... | ... | @@ -141,20 +202,6 @@ class ForestTransform : public Transform |
| 141 | 202 | { |
| 142 | 203 | dst = src; |
| 143 | 204 | |
| 144 | - /* | |
| 145 | - float response; | |
| 146 | - if (classification && returnConfidence) { | |
| 147 | - // Fuzzy class label | |
| 148 | - response = forest.predict_prob(src.m().reshape(1,1)); | |
| 149 | - } else { | |
| 150 | - response = forest.predict(src.m().reshape(1,1)); | |
| 151 | - }*/ | |
| 152 | - | |
| 153 | - // QTime timer; | |
| 154 | - // timer.start(); | |
| 155 | - | |
| 156 | - //qDebug() << forest.get_tree(0)->get_var_count(); | |
| 157 | - | |
| 158 | 205 | Mat responses = Mat::zeros(totalSize,1,CV_32F); |
| 159 | 206 | |
| 160 | 207 | int offset = 0; |
| ... | ... | @@ -164,15 +211,7 @@ class ForestTransform : public Transform |
| 164 | 211 | offset += nodes[i].size(); |
| 165 | 212 | } |
| 166 | 213 | |
| 167 | - if (overwriteMat) { | |
| 168 | - dst.m() = responses; | |
| 169 | - //dst.m() = Mat(1, 1, CV_32F); | |
| 170 | - //dst.m().at<float>(0, 0) = response; | |
| 171 | - } else { | |
| 172 | - //dst.file.set(outputVariable, response); | |
| 173 | - } | |
| 174 | - | |
| 175 | - //qDebug() << timer.elapsed(); | |
| 214 | + dst.m() = responses; | |
| 176 | 215 | } |
| 177 | 216 | |
| 178 | 217 | void load(QDataStream &stream) |
| ... | ... | @@ -213,17 +252,9 @@ class ForestTransform : public Transform |
| 213 | 252 | { |
| 214 | 253 | storeModel(forest,stream); |
| 215 | 254 | } |
| 216 | - | |
| 217 | - void init() | |
| 218 | - { | |
| 219 | - totalSize = 0; | |
| 220 | - | |
| 221 | - if (outputVariable.isEmpty()) | |
| 222 | - outputVariable = inputVariable; | |
| 223 | - } | |
| 224 | 255 | }; |
| 225 | 256 | |
| 226 | -BR_REGISTER(Transform, ForestTransform) | |
| 257 | +BR_REGISTER(Transform, ForestInductionTransform) | |
| 227 | 258 | |
| 228 | 259 | /*! |
| 229 | 260 | * \ingroup transforms | ... | ... |