Commit e76f4ef80e1eee2a3d2c7c7c258e321ddba11e4e
1 parent
65672355
Starting tree rep
Showing
1 changed file
with
86 additions
and
5 deletions
openbr/plugins/tree.cpp
| @@ -73,6 +73,8 @@ class ForestTransform : public Transform | @@ -73,6 +73,8 @@ class ForestTransform : public Transform | ||
| 73 | BR_PROPERTY(QString, outputVariable, "") | 73 | BR_PROPERTY(QString, outputVariable, "") |
| 74 | 74 | ||
| 75 | CvRTrees forest; | 75 | CvRTrees forest; |
| 76 | + int totalSize; | ||
| 77 | + QList< QList<const CvDTreeNode*> > nodes; | ||
| 76 | 78 | ||
| 77 | void train(const TemplateList &data) | 79 | void train(const TemplateList &data) |
| 78 | { | 80 | { |
| @@ -95,39 +97,116 @@ class ForestTransform : public Transform | @@ -95,39 +97,116 @@ class ForestTransform : public Transform | ||
| 95 | 0, | 97 | 0, |
| 96 | false, | 98 | false, |
| 97 | 2, | 99 | 2, |
| 98 | - 0, // priors | 100 | + 0, |
| 99 | false, | 101 | false, |
| 100 | 0, | 102 | 0, |
| 101 | maxTrees, | 103 | maxTrees, |
| 102 | forestAccuracy, | 104 | forestAccuracy, |
| 103 | - CV_TERMCRIT_ITER | CV_TERMCRIT_EPS)); | 105 | + CV_TERMCRIT_ITER)); |
| 104 | 106 | ||
| 105 | qDebug() << "Number of trees:" << forest.get_tree_count(); | 107 | qDebug() << "Number of trees:" << forest.get_tree_count(); |
| 108 | + | ||
| 109 | + for (int i=0; i<forest.get_tree_count(); i++) { | ||
| 110 | + nodes.append(QList<const CvDTreeNode*>()); | ||
| 111 | + const CvDTreeNode* node = forest.get_tree(i)->get_root(); | ||
| 112 | + | ||
| 113 | + // traverse the tree and save all the nodes in depth-first order | ||
| 114 | + for(;;) | ||
| 115 | + { | ||
| 116 | + CvDTreeNode* parent; | ||
| 117 | + for(;;) | ||
| 118 | + { | ||
| 119 | + if( !node->left ) | ||
| 120 | + break; | ||
| 121 | + node = node->left; | ||
| 122 | + } | ||
| 123 | + | ||
| 124 | + nodes.last().append(node); | ||
| 125 | + | ||
| 126 | + for( parent = node->parent; parent && parent->right == node; | ||
| 127 | + node = parent, parent = parent->parent ) | ||
| 128 | + ; | ||
| 129 | + | ||
| 130 | + if( !parent ) | ||
| 131 | + break; | ||
| 132 | + | ||
| 133 | + node = parent->right; | ||
| 134 | + } | ||
| 135 | + | ||
| 136 | + totalSize += nodes.last().size(); | ||
| 137 | + } | ||
| 106 | } | 138 | } |
| 107 | 139 | ||
| 108 | void project(const Template &src, Template &dst) const | 140 | void project(const Template &src, Template &dst) const |
| 109 | { | 141 | { |
| 110 | dst = src; | 142 | dst = src; |
| 111 | 143 | ||
| 144 | + /* | ||
| 112 | float response; | 145 | float response; |
| 113 | if (classification && returnConfidence) { | 146 | if (classification && returnConfidence) { |
| 114 | // Fuzzy class label | 147 | // Fuzzy class label |
| 115 | response = forest.predict_prob(src.m().reshape(1,1)); | 148 | response = forest.predict_prob(src.m().reshape(1,1)); |
| 116 | } else { | 149 | } else { |
| 117 | response = forest.predict(src.m().reshape(1,1)); | 150 | response = forest.predict(src.m().reshape(1,1)); |
| 151 | + }*/ | ||
| 152 | + | ||
| 153 | + // QTime timer; | ||
| 154 | + // timer.start(); | ||
| 155 | + | ||
| 156 | + //qDebug() << forest.get_tree(0)->get_var_count(); | ||
| 157 | + | ||
| 158 | + Mat responses = Mat::zeros(totalSize,1,CV_32F); | ||
| 159 | + | ||
| 160 | + int offset = 0; | ||
| 161 | + for (int i=0; i<nodes.size(); i++) { | ||
| 162 | + int index = nodes[i].indexOf(forest.get_tree(i)->predict(src.m().reshape(1,1))); | ||
| 163 | + responses.at<float>(offset+index,0) = 1; | ||
| 164 | + offset += nodes[i].size(); | ||
| 118 | } | 165 | } |
| 119 | 166 | ||
| 120 | if (overwriteMat) { | 167 | if (overwriteMat) { |
| 121 | - dst.m() = Mat(1, 1, CV_32F); | ||
| 122 | - dst.m().at<float>(0, 0) = response; | 168 | + dst.m() = responses; |
| 169 | + //dst.m() = Mat(1, 1, CV_32F); | ||
| 170 | + //dst.m().at<float>(0, 0) = response; | ||
| 123 | } else { | 171 | } else { |
| 124 | - dst.file.set(outputVariable, response); | 172 | + //dst.file.set(outputVariable, response); |
| 125 | } | 173 | } |
| 174 | + | ||
| 175 | + //qDebug() << timer.elapsed(); | ||
| 126 | } | 176 | } |
| 127 | 177 | ||
| 128 | void load(QDataStream &stream) | 178 | void load(QDataStream &stream) |
| 129 | { | 179 | { |
| 130 | loadModel(forest,stream); | 180 | loadModel(forest,stream); |
| 181 | + for (int i=0; i<forest.get_tree_count(); i++) { | ||
| 182 | + nodes.append(QList<const CvDTreeNode*>()); | ||
| 183 | + const CvDTreeNode* node = forest.get_tree(i)->get_root(); | ||
| 184 | + | ||
| 185 | + // traverse the tree and save all the nodes in depth-first order | ||
| 186 | + for(;;) | ||
| 187 | + { | ||
| 188 | + CvDTreeNode* parent; | ||
| 189 | + for(;;) | ||
| 190 | + { | ||
| 191 | + if( !node->left ) | ||
| 192 | + break; | ||
| 193 | + node = node->left; | ||
| 194 | + } | ||
| 195 | + | ||
| 196 | + nodes.last().append(node); | ||
| 197 | + | ||
| 198 | + for( parent = node->parent; parent && parent->right == node; | ||
| 199 | + node = parent, parent = parent->parent ) | ||
| 200 | + ; | ||
| 201 | + | ||
| 202 | + if( !parent ) | ||
| 203 | + break; | ||
| 204 | + | ||
| 205 | + node = parent->right; | ||
| 206 | + } | ||
| 207 | + | ||
| 208 | + totalSize += nodes.last().size(); | ||
| 209 | + } | ||
| 131 | } | 210 | } |
| 132 | 211 | ||
| 133 | void store(QDataStream &stream) const | 212 | void store(QDataStream &stream) const |
| @@ -137,6 +216,8 @@ class ForestTransform : public Transform | @@ -137,6 +216,8 @@ class ForestTransform : public Transform | ||
| 137 | 216 | ||
| 138 | void init() | 217 | void init() |
| 139 | { | 218 | { |
| 219 | + totalSize = 0; | ||
| 220 | + | ||
| 140 | if (outputVariable.isEmpty()) | 221 | if (outputVariable.isEmpty()) |
| 141 | outputVariable = inputVariable; | 222 | outputVariable = inputVariable; |
| 142 | } | 223 | } |