Commit e76f4ef80e1eee2a3d2c7c7c258e321ddba11e4e
1 parent
65672355
Starting tree rep
Showing
1 changed file
with
86 additions
and
5 deletions
openbr/plugins/tree.cpp
| ... | ... | @@ -73,6 +73,8 @@ class ForestTransform : public Transform |
| 73 | 73 | BR_PROPERTY(QString, outputVariable, "") |
| 74 | 74 | |
| 75 | 75 | CvRTrees forest; |
| 76 | + int totalSize; | |
| 77 | + QList< QList<const CvDTreeNode*> > nodes; | |
| 76 | 78 | |
| 77 | 79 | void train(const TemplateList &data) |
| 78 | 80 | { |
| ... | ... | @@ -95,39 +97,116 @@ class ForestTransform : public Transform |
| 95 | 97 | 0, |
| 96 | 98 | false, |
| 97 | 99 | 2, |
| 98 | - 0, // priors | |
| 100 | + 0, | |
| 99 | 101 | false, |
| 100 | 102 | 0, |
| 101 | 103 | maxTrees, |
| 102 | 104 | forestAccuracy, |
| 103 | - CV_TERMCRIT_ITER | CV_TERMCRIT_EPS)); | |
| 105 | + CV_TERMCRIT_ITER)); | |
| 104 | 106 | |
| 105 | 107 | qDebug() << "Number of trees:" << forest.get_tree_count(); |
| 108 | + | |
| 109 | + for (int i=0; i<forest.get_tree_count(); i++) { | |
| 110 | + nodes.append(QList<const CvDTreeNode*>()); | |
| 111 | + const CvDTreeNode* node = forest.get_tree(i)->get_root(); | |
| 112 | + | |
| 113 | + // traverse the tree and save all the nodes in depth-first order | |
| 114 | + for(;;) | |
| 115 | + { | |
| 116 | + CvDTreeNode* parent; | |
| 117 | + for(;;) | |
| 118 | + { | |
| 119 | + if( !node->left ) | |
| 120 | + break; | |
| 121 | + node = node->left; | |
| 122 | + } | |
| 123 | + | |
| 124 | + nodes.last().append(node); | |
| 125 | + | |
| 126 | + for( parent = node->parent; parent && parent->right == node; | |
| 127 | + node = parent, parent = parent->parent ) | |
| 128 | + ; | |
| 129 | + | |
| 130 | + if( !parent ) | |
| 131 | + break; | |
| 132 | + | |
| 133 | + node = parent->right; | |
| 134 | + } | |
| 135 | + | |
| 136 | + totalSize += nodes.last().size(); | |
| 137 | + } | |
| 106 | 138 | } |
| 107 | 139 | |
| 108 | 140 | void project(const Template &src, Template &dst) const |
| 109 | 141 | { |
| 110 | 142 | dst = src; |
| 111 | 143 | |
| 144 | + /* | |
| 112 | 145 | float response; |
| 113 | 146 | if (classification && returnConfidence) { |
| 114 | 147 | // Fuzzy class label |
| 115 | 148 | response = forest.predict_prob(src.m().reshape(1,1)); |
| 116 | 149 | } else { |
| 117 | 150 | response = forest.predict(src.m().reshape(1,1)); |
| 151 | + }*/ | |
| 152 | + | |
| 153 | + // QTime timer; | |
| 154 | + // timer.start(); | |
| 155 | + | |
| 156 | + //qDebug() << forest.get_tree(0)->get_var_count(); | |
| 157 | + | |
| 158 | + Mat responses = Mat::zeros(totalSize,1,CV_32F); | |
| 159 | + | |
| 160 | + int offset = 0; | |
| 161 | + for (int i=0; i<nodes.size(); i++) { | |
| 162 | + int index = nodes[i].indexOf(forest.get_tree(i)->predict(src.m().reshape(1,1))); | |
| 163 | + responses.at<float>(offset+index,0) = 1; | |
| 164 | + offset += nodes[i].size(); | |
| 118 | 165 | } |
| 119 | 166 | |
| 120 | 167 | if (overwriteMat) { |
| 121 | - dst.m() = Mat(1, 1, CV_32F); | |
| 122 | - dst.m().at<float>(0, 0) = response; | |
| 168 | + dst.m() = responses; | |
| 169 | + //dst.m() = Mat(1, 1, CV_32F); | |
| 170 | + //dst.m().at<float>(0, 0) = response; | |
| 123 | 171 | } else { |
| 124 | - dst.file.set(outputVariable, response); | |
| 172 | + //dst.file.set(outputVariable, response); | |
| 125 | 173 | } |
| 174 | + | |
| 175 | + //qDebug() << timer.elapsed(); | |
| 126 | 176 | } |
| 127 | 177 | |
| 128 | 178 | void load(QDataStream &stream) |
| 129 | 179 | { |
| 130 | 180 | loadModel(forest,stream); |
| 181 | + for (int i=0; i<forest.get_tree_count(); i++) { | |
| 182 | + nodes.append(QList<const CvDTreeNode*>()); | |
| 183 | + const CvDTreeNode* node = forest.get_tree(i)->get_root(); | |
| 184 | + | |
| 185 | + // traverse the tree and save all the nodes in depth-first order | |
| 186 | + for(;;) | |
| 187 | + { | |
| 188 | + CvDTreeNode* parent; | |
| 189 | + for(;;) | |
| 190 | + { | |
| 191 | + if( !node->left ) | |
| 192 | + break; | |
| 193 | + node = node->left; | |
| 194 | + } | |
| 195 | + | |
| 196 | + nodes.last().append(node); | |
| 197 | + | |
| 198 | + for( parent = node->parent; parent && parent->right == node; | |
| 199 | + node = parent, parent = parent->parent ) | |
| 200 | + ; | |
| 201 | + | |
| 202 | + if( !parent ) | |
| 203 | + break; | |
| 204 | + | |
| 205 | + node = parent->right; | |
| 206 | + } | |
| 207 | + | |
| 208 | + totalSize += nodes.last().size(); | |
| 209 | + } | |
| 131 | 210 | } |
| 132 | 211 | |
| 133 | 212 | void store(QDataStream &stream) const |
| ... | ... | @@ -137,6 +216,8 @@ class ForestTransform : public Transform |
| 137 | 216 | |
| 138 | 217 | void init() |
| 139 | 218 | { |
| 219 | + totalSize = 0; | |
| 220 | + | |
| 140 | 221 | if (outputVariable.isEmpty()) |
| 141 | 222 | outputVariable = inputVariable; |
| 142 | 223 | } | ... | ... |