Commit e76f4ef80e1eee2a3d2c7c7c258e321ddba11e4e

Authored by Scott Klum
1 parent 65672355

Starting tree rep

Showing 1 changed file with 86 additions and 5 deletions
openbr/plugins/tree.cpp
... ... @@ -73,6 +73,8 @@ class ForestTransform : public Transform
73 73 BR_PROPERTY(QString, outputVariable, "")
74 74  
75 75 CvRTrees forest;
  76 + int totalSize;
  77 + QList< QList<const CvDTreeNode*> > nodes;
76 78  
77 79 void train(const TemplateList &data)
78 80 {
... ... @@ -95,39 +97,116 @@ class ForestTransform : public Transform
95 97 0,
96 98 false,
97 99 2,
98   - 0, // priors
  100 + 0,
99 101 false,
100 102 0,
101 103 maxTrees,
102 104 forestAccuracy,
103   - CV_TERMCRIT_ITER | CV_TERMCRIT_EPS));
  105 + CV_TERMCRIT_ITER));
104 106  
105 107 qDebug() << "Number of trees:" << forest.get_tree_count();
  108 +
  109 + for (int i=0; i<forest.get_tree_count(); i++) {
  110 + nodes.append(QList<const CvDTreeNode*>());
  111 + const CvDTreeNode* node = forest.get_tree(i)->get_root();
  112 +
  113 + // traverse the tree and save all the nodes in depth-first order
  114 + for(;;)
  115 + {
  116 + CvDTreeNode* parent;
  117 + for(;;)
  118 + {
  119 + if( !node->left )
  120 + break;
  121 + node = node->left;
  122 + }
  123 +
  124 + nodes.last().append(node);
  125 +
  126 + for( parent = node->parent; parent && parent->right == node;
  127 + node = parent, parent = parent->parent )
  128 + ;
  129 +
  130 + if( !parent )
  131 + break;
  132 +
  133 + node = parent->right;
  134 + }
  135 +
  136 + totalSize += nodes.last().size();
  137 + }
106 138 }
107 139  
108 140 void project(const Template &src, Template &dst) const
109 141 {
110 142 dst = src;
111 143  
  144 + /*
112 145 float response;
113 146 if (classification && returnConfidence) {
114 147 // Fuzzy class label
115 148 response = forest.predict_prob(src.m().reshape(1,1));
116 149 } else {
117 150 response = forest.predict(src.m().reshape(1,1));
  151 + }*/
  152 +
  153 + // QTime timer;
  154 + // timer.start();
  155 +
  156 + //qDebug() << forest.get_tree(0)->get_var_count();
  157 +
  158 + Mat responses = Mat::zeros(totalSize,1,CV_32F);
  159 +
  160 + int offset = 0;
  161 + for (int i=0; i<nodes.size(); i++) {
  162 + int index = nodes[i].indexOf(forest.get_tree(i)->predict(src.m().reshape(1,1)));
  163 + responses.at<float>(offset+index,0) = 1;
  164 + offset += nodes[i].size();
118 165 }
119 166  
120 167 if (overwriteMat) {
121   - dst.m() = Mat(1, 1, CV_32F);
122   - dst.m().at<float>(0, 0) = response;
  168 + dst.m() = responses;
  169 + //dst.m() = Mat(1, 1, CV_32F);
  170 + //dst.m().at<float>(0, 0) = response;
123 171 } else {
124   - dst.file.set(outputVariable, response);
  172 + //dst.file.set(outputVariable, response);
125 173 }
  174 +
  175 + //qDebug() << timer.elapsed();
126 176 }
127 177  
128 178 void load(QDataStream &stream)
129 179 {
130 180 loadModel(forest,stream);
  181 + for (int i=0; i<forest.get_tree_count(); i++) {
  182 + nodes.append(QList<const CvDTreeNode*>());
  183 + const CvDTreeNode* node = forest.get_tree(i)->get_root();
  184 +
  185 + // traverse the tree and save all the nodes in depth-first order
  186 + for(;;)
  187 + {
  188 + CvDTreeNode* parent;
  189 + for(;;)
  190 + {
  191 + if( !node->left )
  192 + break;
  193 + node = node->left;
  194 + }
  195 +
  196 + nodes.last().append(node);
  197 +
  198 + for( parent = node->parent; parent && parent->right == node;
  199 + node = parent, parent = parent->parent )
  200 + ;
  201 +
  202 + if( !parent )
  203 + break;
  204 +
  205 + node = parent->right;
  206 + }
  207 +
  208 + totalSize += nodes.last().size();
  209 + }
131 210 }
132 211  
133 212 void store(QDataStream &stream) const
... ... @@ -137,6 +216,8 @@ class ForestTransform : public Transform
137 216  
138 217 void init()
139 218 {
  219 + totalSize = 0;
  220 +
140 221 if (outputVariable.isEmpty())
141 222 outputVariable = inputVariable;
142 223 }
... ...