Commit 9b430a4f829a7e5aecd163738865cc37e6cbeb59

Authored by Scott Klum
1 parent 9e9e1a9b

Added forest induction

Showing 1 changed file with 67 additions and 36 deletions
openbr/plugins/tree.cpp
... ... @@ -72,11 +72,51 @@ class ForestTransform : public Transform
72 72 BR_PROPERTY(QString, inputVariable, "Label")
73 73 BR_PROPERTY(QString, outputVariable, "")
74 74  
  75 + void train(const TemplateList &data)
  76 + {
  77 + trainForest(data);
  78 + }
  79 +
  80 + void project(const Template &src, Template &dst) const
  81 + {
  82 + dst = src;
  83 +
  84 + float response;
  85 + if (classification && returnConfidence) {
  86 + // Fuzzy class label
  87 + response = forest.predict_prob(src.m().reshape(1,1));
  88 + } else {
  89 + response = forest.predict(src.m().reshape(1,1));
  90 + }
  91 +
  92 + if (overwriteMat) {
  93 + dst.m() = Mat(1, 1, CV_32F);
  94 + dst.m().at<float>(0, 0) = response;
  95 + } else {
  96 + dst.file.set(outputVariable, response);
  97 + }
  98 + }
  99 +
  100 + void load(QDataStream &stream)
  101 + {
  102 + loadModel(forest,stream);
  103 + }
  104 +
  105 + void store(QDataStream &stream) const
  106 + {
  107 + storeModel(forest,stream);
  108 + }
  109 +
  110 + void init()
  111 + {
  112 + if (outputVariable.isEmpty())
  113 + outputVariable = inputVariable;
  114 + }
  115 +
  116 +protected:
75 117 CvRTrees forest;
76   - int totalSize;
77   - QList< QList<const CvDTreeNode*> > nodes;
78 118  
79   - void train(const TemplateList &data)
  119 + void trainForest(const TemplateList &data)
80 120 {
81 121 Mat samples = OpenCVUtils::toMat(data.data());
82 122 Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable));
... ... @@ -102,9 +142,30 @@ class ForestTransform : public Transform
102 142 0,
103 143 maxTrees,
104 144 forestAccuracy,
105   - CV_TERMCRIT_ITER));
  145 + CV_TERMCRIT_ITER | CV_TERMCRIT_EPS));
106 146  
107 147 qDebug() << "Number of trees:" << forest.get_tree_count();
  148 + }
  149 +};
  150 +
  151 +BR_REGISTER(Transform, ForestTransform)
  152 +
  153 +/*!
  154 + * \ingroup transforms
  155 + * \brief Wraps OpenCV's random trees framework to induce features
  156 + * \author Scott Klum \cite sklum
  157 + * \brief https://lirias.kuleuven.be/bitstream/123456789/316661/1/icdm11-camready.pdf
  158 + */
  159 +class ForestInductionTransform : public ForestTransform
  160 +{
  161 + Q_OBJECT
  162 +
  163 + int totalSize;
  164 + QList< QList<const CvDTreeNode*> > nodes;
  165 +
  166 + void train(const TemplateList &data)
  167 + {
  168 + trainForest(data);
108 169  
109 170 for (int i=0; i<forest.get_tree_count(); i++) {
110 171 nodes.append(QList<const CvDTreeNode*>());
... ... @@ -141,20 +202,6 @@ class ForestTransform : public Transform
141 202 {
142 203 dst = src;
143 204  
144   - /*
145   - float response;
146   - if (classification && returnConfidence) {
147   - // Fuzzy class label
148   - response = forest.predict_prob(src.m().reshape(1,1));
149   - } else {
150   - response = forest.predict(src.m().reshape(1,1));
151   - }*/
152   -
153   - // QTime timer;
154   - // timer.start();
155   -
156   - //qDebug() << forest.get_tree(0)->get_var_count();
157   -
158 205 Mat responses = Mat::zeros(totalSize,1,CV_32F);
159 206  
160 207 int offset = 0;
... ... @@ -164,15 +211,7 @@ class ForestTransform : public Transform
164 211 offset += nodes[i].size();
165 212 }
166 213  
167   - if (overwriteMat) {
168   - dst.m() = responses;
169   - //dst.m() = Mat(1, 1, CV_32F);
170   - //dst.m().at<float>(0, 0) = response;
171   - } else {
172   - //dst.file.set(outputVariable, response);
173   - }
174   -
175   - //qDebug() << timer.elapsed();
  214 + dst.m() = responses;
176 215 }
177 216  
178 217 void load(QDataStream &stream)
... ... @@ -213,17 +252,9 @@ class ForestTransform : public Transform
213 252 {
214 253 storeModel(forest,stream);
215 254 }
216   -
217   - void init()
218   - {
219   - totalSize = 0;
220   -
221   - if (outputVariable.isEmpty())
222   - outputVariable = inputVariable;
223   - }
224 255 };
225 256  
226   -BR_REGISTER(Transform, ForestTransform)
  257 +BR_REGISTER(Transform, ForestInductionTransform)
227 258  
228 259 /*!
229 260 * \ingroup transforms
... ...