Commit 8b1dd0029b1f1e6d6952c5df0ddbfa01e88bc728

Authored by Jordan Cheney
1 parent 9eaaf62b

Progress towards better organization

openbr/core/boost.cpp
@@ -178,21 +178,6 @@ CascadeBoostParams::CascadeBoostParams( int _boostType, @@ -178,21 +178,6 @@ CascadeBoostParams::CascadeBoostParams( int _boostType,
178 use_surrogates = use_1se_rule = truncate_pruned_tree = false; 178 use_surrogates = use_1se_rule = truncate_pruned_tree = false;
179 } 179 }
180 180
181 -void CascadeBoostParams::write( FileStorage &fs ) const  
182 -{  
183 - string boostTypeStr = boost_type == CvBoost::DISCRETE ? CC_DISCRETE_BOOST :  
184 - boost_type == CvBoost::REAL ? CC_REAL_BOOST :  
185 - boost_type == CvBoost::LOGIT ? CC_LOGIT_BOOST :  
186 - boost_type == CvBoost::GENTLE ? CC_GENTLE_BOOST : string();  
187 - CV_Assert( !boostTypeStr.empty() );  
188 - fs << CC_BOOST_TYPE << boostTypeStr;  
189 - fs << CC_MINHITRATE << minHitRate;  
190 - fs << CC_MAXFALSEALARM << maxFalseAlarm;  
191 - fs << CC_TRIM_RATE << weight_trim_rate;  
192 - fs << CC_MAX_DEPTH << max_depth;  
193 - fs << CC_WEAK_COUNT << weak_count;  
194 -}  
195 -  
196 //---------------------------- CascadeBoostTrainData ----------------------------- 181 //---------------------------- CascadeBoostTrainData -----------------------------
197 182
198 CvDTreeNode* CascadeBoostTrainData::subsample_data( const CvMat* _subsample_idx ) 183 CvDTreeNode* CascadeBoostTrainData::subsample_data( const CvMat* _subsample_idx )
@@ -826,139 +811,6 @@ CvDTreeNode* CascadeBoostTree::predict( int sampleIdx ) const @@ -826,139 +811,6 @@ CvDTreeNode* CascadeBoostTree::predict( int sampleIdx ) const
826 return node; 811 return node;
827 } 812 }
828 813
829 -/*  
830 -static void writeRecursive(FileStorage &fs, CvDTreeNode *node, int maxCatCount)  
831 -{  
832 - bool hasChildren = node->left ? true : false;  
833 - fs << "hasChildren" << hasChildren;  
834 -  
835 - if (!hasChildren) // Write the leaf value  
836 - fs << "value" << node->value; // value of the node. Only relevant for leaf nodes  
837 - else { // Write the splitting information and then the children  
838 - if (maxCatCount > 0) {  
839 - fs << "subset" << "[:";  
840 - for (int i = 0; i < ((maxCatCount + 31) / 32); i++)  
841 - fs << node->split->subset[i]; // subset to split on (categorical features)  
842 - fs << "]";  
843 - } else {  
844 - fs << "threshold" << node->split->ord.c; // threshold to split on (ordered features)  
845 - }  
846 -  
847 - fs << "feature_idx" << node->split->var_idx; // feature idx of node  
848 -  
849 - fs << "left" << "{"; writeRecursive(fs, node->left, maxCatCount); fs << "}"; // write left child  
850 - fs << "right" << "{"; writeRecursive(fs, node->right, maxCatCount); fs << "}"; // write right child  
851 - }  
852 -}  
853 -  
854 -void CascadeBoostTree::write(FileStorage &fs)  
855 -{  
856 - fs << "{";  
857 - writeRecursive(fs, root, ((CascadeBoostTrainData*)data)->featureEvaluator->getMaxCatCount());  
858 - fs << "}";  
859 -}  
860 -  
861 -static void readRecursive(const FileNode &fn, CvDTreeNode *node, CvDTreeTrainData *data)  
862 -{  
863 - bool hasChildren = (int)fn["hasChildren"];  
864 -  
865 - if (!hasChildren)  
866 - node->value = (float)fn["value"];  
867 - else {  
868 - int maxCatCount = ((CascadeBoostTrainData*)data)->featureEvaluator->getMaxCatCount();  
869 - if (maxCatCount > 0) {  
870 - node->split = data->new_split_cat(0, 0);  
871 - FileNode subset_node = fn["subset"]; FileNodeIterator subset_it = subset_node.begin();  
872 - for (int i = 0; i < (maxCatCount + 31) / 32; i++, ++subset_it)  
873 - node->split->subset[i] = (int)*subset_it;  
874 - } else {  
875 - float threshold = (float)fn["threshold"];  
876 - node->split = data->new_split_ord(0, threshold, 0, 0, 0);  
877 - }  
878 -  
879 - node->split->var_idx = (int)fn["feature_idx"];  
880 -  
881 - CvDTreeNode *leftChild = data->new_node(node, 0, 0, 0);  
882 - node->left = leftChild;  
883 - readRecursive(fn["left"], leftChild, data);  
884 -  
885 - CvDTreeNode *rightChild = data->new_node(node, 0, 0, 0);  
886 - node->right = rightChild;  
887 - readRecursive(fn["right"], rightChild, data);  
888 - }  
889 -}  
890 -  
891 -void CascadeBoostTree::read(const FileNode &fn, CvBoost* _ensemble, CvDTreeTrainData* _data)  
892 -{  
893 - clear();  
894 - data = _data;  
895 - ensemble = _ensemble;  
896 - pruned_tree_idx = 0;  
897 -  
898 - root = data->new_node(0, 0, 0, 0);  
899 - readRecursive(fn, root, data);  
900 -}*/  
901 -  
902 -void CascadeBoostTree::write(FileStorage &fs)  
903 -{  
904 - int maxCatCount = ((CascadeBoostTrainData*)data)->featureEvaluator->getMaxCatCount();  
905 - int subsetN = (maxCatCount + 31)/32;  
906 - queue<CvDTreeNode*> internalNodesQueue;  
907 - int size = (int)pow( 2.f, (float)ensemble->get_params().max_depth);  
908 - Ptr<float> leafVals = new float[size];  
909 - int leafValIdx = 0;  
910 - int internalNodeIdx = 1;  
911 - CvDTreeNode* tempNode;  
912 -  
913 - CV_DbgAssert( root );  
914 - internalNodesQueue.push( root );  
915 -  
916 - fs << "{";  
917 - fs << CC_INTERNAL_NODES << "[:";  
918 - while (!internalNodesQueue.empty())  
919 - {  
920 - tempNode = internalNodesQueue.front();  
921 - CV_Assert( tempNode->left );  
922 - if ( !tempNode->left->left && !tempNode->left->right) // left node is leaf  
923 - {  
924 - leafVals[-leafValIdx] = (float)tempNode->left->value;  
925 - fs << leafValIdx-- ;  
926 - }  
927 - else  
928 - {  
929 - internalNodesQueue.push( tempNode->left );  
930 - fs << internalNodeIdx++;  
931 - }  
932 - CV_Assert( tempNode->right );  
933 - if ( !tempNode->right->left && !tempNode->right->right) // right node is leaf  
934 - {  
935 - leafVals[-leafValIdx] = (float)tempNode->right->value;  
936 - fs << leafValIdx--;  
937 - }  
938 - else  
939 - {  
940 - internalNodesQueue.push( tempNode->right );  
941 - fs << internalNodeIdx++;  
942 - }  
943 - int fidx = tempNode->split->var_idx;  
944 -  
945 - fs << fidx;  
946 - if ( !maxCatCount )  
947 - fs << tempNode->split->ord.c;  
948 - else  
949 - for( int i = 0; i < subsetN; i++ )  
950 - fs << tempNode->split->subset[i];  
951 - internalNodesQueue.pop();  
952 - }  
953 - fs << "]"; // CC_INTERNAL_NODES  
954 -  
955 - fs << CC_LEAF_VALUES << "[:";  
956 - for (int ni = 0; ni < -leafValIdx; ni++)  
957 - fs << leafVals[ni];  
958 - fs << "]"; // CC_LEAF_VALUES  
959 - fs << "}";  
960 -}  
961 -  
962 void CascadeBoostTree::split_node_data( CvDTreeNode* node ) 814 void CascadeBoostTree::split_node_data( CvDTreeNode* node )
963 { 815 {
964 int n = node->sample_count, nl, nr, scount = data->sample_count; 816 int n = node->sample_count, nl, nr, scount = data->sample_count;
@@ -1214,6 +1066,7 @@ bool CascadeBoost::train( const FeatureEvaluator* _featureEvaluator, @@ -1214,6 +1066,7 @@ bool CascadeBoost::train( const FeatureEvaluator* _featureEvaluator,
1214 break; 1066 break;
1215 } 1067 }
1216 1068
  1069 + trees.append(tree);
1217 cvSeqPush( weak, &tree ); 1070 cvSeqPush( weak, &tree );
1218 update_weights( tree ); 1071 update_weights( tree );
1219 trim_weights(); 1072 trim_weights();
@@ -1541,21 +1394,3 @@ bool CascadeBoost::isErrDesired() @@ -1541,21 +1394,3 @@ bool CascadeBoost::isErrDesired()
1541 1394
1542 return falseAlarm <= maxFalseAlarm; 1395 return falseAlarm <= maxFalseAlarm;
1543 } 1396 }
1544 -  
1545 -void CascadeBoost::write(FileStorage &fs) const  
1546 -{  
1547 -// char cmnt[30];  
1548 - CascadeBoostTree* weakTree;  
1549 - fs << CC_WEAK_COUNT << weak->total;  
1550 - fs << CC_STAGE_THRESHOLD << threshold;  
1551 - fs << CC_WEAK_CLASSIFIERS << "[";  
1552 - for( int wi = 0; wi < weak->total; wi++)  
1553 - {  
1554 - /*sprintf( cmnt, "tree %i", wi );  
1555 - cvWriteComment( fs, cmnt, 0 );*/  
1556 - weakTree = *((CascadeBoostTree**) cvGetSeqElem( weak, wi ));  
1557 - weakTree->write(fs);  
1558 - }  
1559 - fs << "]";  
1560 -}  
1561 -  
openbr/core/boost.h
@@ -77,7 +77,6 @@ struct CascadeBoostParams : CvBoostParams @@ -77,7 +77,6 @@ struct CascadeBoostParams : CvBoostParams
77 CascadeBoostParams(int _boostType, float _minHitRate, float _maxFalseAlarm, 77 CascadeBoostParams(int _boostType, float _minHitRate, float _maxFalseAlarm,
78 double _weightTrimRate, int _maxDepth, int _maxWeakCount); 78 double _weightTrimRate, int _maxDepth, int _maxWeakCount);
79 virtual ~CascadeBoostParams() {} 79 virtual ~CascadeBoostParams() {}
80 - void write( cv::FileStorage &fs ) const;  
81 }; 80 };
82 81
83 struct CascadeBoostTrainData : CvDTreeTrainData 82 struct CascadeBoostTrainData : CvDTreeTrainData
@@ -113,7 +112,6 @@ class CascadeBoostTree : public CvBoostTree @@ -113,7 +112,6 @@ class CascadeBoostTree : public CvBoostTree
113 { 112 {
114 public: 113 public:
115 virtual CvDTreeNode* predict(int sampleIdx) const; 114 virtual CvDTreeNode* predict(int sampleIdx) const;
116 - void write(cv::FileStorage &fs);  
117 115
118 protected: 116 protected:
119 virtual void split_node_data(CvDTreeNode* n); 117 virtual void split_node_data(CvDTreeNode* n);
@@ -128,13 +126,15 @@ public: @@ -128,13 +126,15 @@ public:
128 virtual float predict( int sampleIdx, bool returnSum = false ) const; 126 virtual float predict( int sampleIdx, bool returnSum = false ) const;
129 127
130 float getThreshold() const { return threshold; } 128 float getThreshold() const { return threshold; }
131 - void write(cv::FileStorage &fs) const; 129 + const QList<CvBoostTree*> getTrees() const { return trees; }
132 130
133 protected: 131 protected:
134 virtual bool set_params(const CvBoostParams& _params); 132 virtual bool set_params(const CvBoostParams& _params);
135 virtual void update_weights(CvBoostTree* tree); 133 virtual void update_weights(CvBoostTree* tree);
136 virtual bool isErrDesired(); 134 virtual bool isErrDesired();
137 135
  136 + QList<CvBoostTree*> trees;
  137 +
138 float threshold; 138 float threshold;
139 float minHitRate, maxFalseAlarm; 139 float minHitRate, maxFalseAlarm;
140 }; 140 };
openbr/core/cascade.cpp
@@ -124,6 +124,7 @@ void br::groupRectangles(vector&lt;Rect&gt;&amp; rectList, vector&lt;int&gt;&amp; rejectLevels, vect @@ -124,6 +124,7 @@ void br::groupRectangles(vector&lt;Rect&gt;&amp; rectList, vector&lt;int&gt;&amp; rejectLevels, vect
124 static void loadRecursive(const FileNode &fn, _CascadeClassifier::Node *node, int maxCatCount) 124 static void loadRecursive(const FileNode &fn, _CascadeClassifier::Node *node, int maxCatCount)
125 { 125 {
126 bool hasChildren = (int)fn["hasChildren"]; 126 bool hasChildren = (int)fn["hasChildren"];
  127 +
127 if (hasChildren) { 128 if (hasChildren) {
128 if (maxCatCount > 1) { 129 if (maxCatCount > 1) {
129 FileNode subset_fn = fn["subset"]; 130 FileNode subset_fn = fn["subset"];
@@ -158,6 +159,7 @@ bool _CascadeClassifier::load(const string&amp; filename) @@ -158,6 +159,7 @@ bool _CascadeClassifier::load(const string&amp; filename)
158 159
159 // load stages 160 // load stages
160 FileNode stages_fn = root["stages"]; 161 FileNode stages_fn = root["stages"];
  162 +
161 if( stages_fn.empty() ) 163 if( stages_fn.empty() )
162 return false; 164 return false;
163 165
@@ -168,6 +170,7 @@ bool _CascadeClassifier::load(const string&amp; filename) @@ -168,6 +170,7 @@ bool _CascadeClassifier::load(const string&amp; filename)
168 stage.threshold = (float)stage_fn["stageThreshold"] - THRESHOLD_EPS; 170 stage.threshold = (float)stage_fn["stageThreshold"] - THRESHOLD_EPS;
169 171
170 FileNode nodes_fn = stage_fn["weakClassifiers"]; 172 FileNode nodes_fn = stage_fn["weakClassifiers"];
  173 +
171 if(nodes_fn.empty()) 174 if(nodes_fn.empty())
172 return false; 175 return false;
173 176
openbr/plugins/classification/boostedforest.cpp
@@ -6,6 +6,86 @@ using namespace cv; @@ -6,6 +6,86 @@ using namespace cv;
6 namespace br 6 namespace br
7 { 7 {
8 8
  9 +struct Node
  10 +{
  11 + Node() : left(NULL), right(NULL) {}
  12 +
  13 + int featureIdx;
  14 + float threshold; // for ordered features only
  15 + QList<int> subset; // for categorical features only
  16 + float value; // for leaf nodes only
  17 + Node *left;
  18 + Node *right;
  19 +};
  20 +
  21 +static void buildTreeRecursive(Node *node, const CvDTreeNode *tree_node, int maxCatCount)
  22 +{
  23 + if (tree_node->left) {
  24 + if (maxCatCount > 1) {
  25 + for (int i = 0; i < (maxCatCount + 31)/32; i++)
  26 + node->subset.append(tree_node->split->subset[i]);
  27 + } else {
  28 + node->threshold = tree_node->split->ord.c;
  29 + }
  30 +
  31 + node->featureIdx = tree_node->split->var_idx;
  32 +
  33 + node->left = new Node;
  34 + buildTreeRecursive(node->left, tree_node->left, maxCatCount);
  35 + node->right = new Node;
  36 + buildTreeRecursive(node->right, tree_node->right, maxCatCount);
  37 + } else {
  38 + node->value = tree_node->value;
  39 + }
  40 +}
  41 +
  42 +static void writeRecursive(FileStorage &fs, const Node *node, int maxCatCount)
  43 +{
  44 + bool hasChildren = node->left ? true : false;
  45 + fs << "hasChildren" << hasChildren;
  46 +
  47 + if (!hasChildren) // Write the leaf value
  48 + fs << "value" << node->value; // value of the node.
  49 + else { // Write the splitting information and then the children
  50 + if (maxCatCount > 1) {
  51 + fs << "subset" << "[";
  52 + for (int i = 0; i < ((maxCatCount + 31) / 32); i++)
  53 + fs << node->subset[i]; // subset to split on (categorical features)
  54 + fs << "]";
  55 + } else {
  56 + fs << "threshold" << node->threshold; // threshold to split on (ordered features)
  57 + }
  58 +
  59 + fs << "feature_idx" << node->featureIdx; // feature idx of node
  60 +
  61 + fs << "left" << "{"; writeRecursive(fs, node->left, maxCatCount); fs << "}"; // write left child
  62 + fs << "right" << "{"; writeRecursive(fs, node->right, maxCatCount); fs << "}"; // write right child
  63 + }
  64 +}
  65 +
  66 +static void readRecursive(const FileNode &fn, Node *node, int maxCatCount)
  67 +{
  68 + bool hasChildren = (int)fn["hasChildren"];
  69 + if (!hasChildren) {
  70 + node->value = (float)fn["value"];
  71 + } else {
  72 + if (maxCatCount > 1) {
  73 + FileNode subset_fn = fn["subset"];
  74 + for (FileNodeIterator subset_it = subset_fn.begin(); subset_it != subset_fn.end(); ++subset_it)
  75 + node->subset.append((int)*subset_it);
  76 + } else {
  77 + node->threshold = (float)fn["threshold"];
  78 + }
  79 +
  80 + node->featureIdx = (int)fn["feature_idx"];
  81 +
  82 + node->left = new Node;
  83 + readRecursive(fn["left"], node->left, maxCatCount);
  84 + node->right = new Node;
  85 + readRecursive(fn["right"], node->right, maxCatCount);
  86 + }
  87 +}
  88 +
9 class BoostedForestClassifier : public Classifier 89 class BoostedForestClassifier : public Classifier
10 { 90 {
11 Q_OBJECT 91 Q_OBJECT
@@ -24,27 +104,55 @@ class BoostedForestClassifier : public Classifier @@ -24,27 +104,55 @@ class BoostedForestClassifier : public Classifier
24 BR_PROPERTY(int, maxDepth, 1) 104 BR_PROPERTY(int, maxDepth, 1)
25 BR_PROPERTY(int, maxWeakCount, 100) 105 BR_PROPERTY(int, maxWeakCount, 100)
26 106
27 - CascadeBoost *boost;  
28 - FeatureEvaluator *featureEvaluator; 107 + QList<Node*> weakClassifiers;
  108 + float threshold;
29 109
30 void train(const QList<Mat> &images, const QList<float> &labels) 110 void train(const QList<Mat> &images, const QList<float> &labels)
31 { 111 {
32 CascadeBoostParams params(CvBoost::GENTLE, minTAR, maxFAR, trimRate, maxDepth, maxWeakCount); 112 CascadeBoostParams params(CvBoost::GENTLE, minTAR, maxFAR, trimRate, maxDepth, maxWeakCount);
33 113
34 - featureEvaluator = new FeatureEvaluator;  
35 - featureEvaluator->init(representation, images.size()); 114 + FeatureEvaluator featureEvaluator;
  115 + featureEvaluator.init(representation, images.size());
36 116
37 for (int i = 0; i < images.size(); i++) 117 for (int i = 0; i < images.size(); i++)
38 - featureEvaluator->setImage(images[i], labels[i], i); 118 + featureEvaluator.setImage(images[i], labels[i], i);
  119 +
  120 + CascadeBoost boost;
  121 + boost.train(&featureEvaluator, images.size(), 1024, 1024, params);
39 122
40 - boost = new CascadeBoost;  
41 - boost->train(featureEvaluator, images.size(), 1024, 1024, params); 123 + // Convert into simpler, cleaner cascade after training
  124 + threshold = boost.getThreshold();
  125 +
  126 + foreach (const CvBoostTree *tree, boost.getTrees()) {
  127 + Node *root = new Node;
  128 + buildTreeRecursive(root, tree->get_root(), representation->maxCatCount());
  129 + weakClassifiers.append(root);
  130 + }
42 } 131 }
43 132
44 float classify(const Mat &image) const 133 float classify(const Mat &image) const
45 { 134 {
46 - featureEvaluator->setImage(image, 0, 0);  
47 - return boost->predict(0); 135 + Mat pp;
  136 + representation->preprocess(image, pp);
  137 +
  138 + float sum = 0;
  139 +
  140 + foreach (const Node *node, weakClassifiers) {
  141 + while (node->left) {
  142 + if (representation->maxCatCount() > 1) {
  143 + int c = (int)representation->evaluate(pp, node->featureIdx);
  144 + node = (node->subset[c >> 5] & (1 << (c & 31))) ? node->left : node->right;
  145 + } else {
  146 + double val = representation->evaluate(pp, node->featureIdx);
  147 + node = val < node->threshold ? node->left : node->right;
  148 + }
  149 + }
  150 + sum += node->value;
  151 + }
  152 +
  153 + if (sum < threshold)
  154 + return -std::abs(sum);
  155 + return std::abs(sum);
48 } 156 }
49 157
50 int numFeatures() const 158 int numFeatures() const
@@ -64,7 +172,31 @@ class BoostedForestClassifier : public Classifier @@ -64,7 +172,31 @@ class BoostedForestClassifier : public Classifier
64 172
65 void write(FileStorage &fs) const 173 void write(FileStorage &fs) const
66 { 174 {
67 - boost->write(fs); 175 + fs << "numWeak" << weakClassifiers.size();
  176 + fs << "stageThreshold" << threshold;
  177 + fs << "weakClassifiers" << "[";
  178 + foreach (const Node *root, weakClassifiers) {
  179 + fs << "{";
  180 + writeRecursive(fs, root, representation->maxCatCount());
  181 + fs << "}";
  182 + }
  183 + fs << "]";
  184 + }
  185 +
  186 + void read(const FileNode &node)
  187 + {
  188 + weakClassifiers.reserve((int)node["numWeak"]);
  189 + threshold = (float)node["stageThreshold"];
  190 +
  191 + FileNode weaks_fn = node["weakClassifiers"];
  192 + for (FileNodeIterator weaks_it = weaks_fn.begin(); weaks_it != weaks_fn.end(); ++weaks_it) {
  193 + FileNode weak_fn = *weaks_it;
  194 +
  195 + Node *root = new Node;
  196 + readRecursive(weak_fn, root, representation->maxCatCount());
  197 +
  198 + weakClassifiers.append(root);
  199 + }
68 } 200 }
69 }; 201 };
70 202
openbr/plugins/classification/cascade.cpp
@@ -112,12 +112,14 @@ class CascadeClassifier : public Classifier @@ -112,12 +112,14 @@ class CascadeClassifier : public Classifier
112 Q_PROPERTY(int numPos READ get_numPos WRITE set_numPos RESET reset_numPos STORED false) 112 Q_PROPERTY(int numPos READ get_numPos WRITE set_numPos RESET reset_numPos STORED false)
113 Q_PROPERTY(int numNegs READ get_numNegs WRITE set_numNegs RESET reset_numNegs STORED false) 113 Q_PROPERTY(int numNegs READ get_numNegs WRITE set_numNegs RESET reset_numNegs STORED false)
114 Q_PROPERTY(float maxFAR READ get_maxFAR WRITE set_maxFAR RESET reset_maxFAR STORED false) 114 Q_PROPERTY(float maxFAR READ get_maxFAR WRITE set_maxFAR RESET reset_maxFAR STORED false)
  115 + Q_PROPERTY(bool ROCMode READ get_ROCMode WRITE set_ROCMode RESET reset_ROCMode STORED false)
115 116
116 BR_PROPERTY(QString, stageDescription, "") 117 BR_PROPERTY(QString, stageDescription, "")
117 BR_PROPERTY(int, numStages, 20) 118 BR_PROPERTY(int, numStages, 20)
118 BR_PROPERTY(int, numPos, 1000) 119 BR_PROPERTY(int, numPos, 1000)
119 BR_PROPERTY(int, numNegs, 1000) 120 BR_PROPERTY(int, numNegs, 1000)
120 BR_PROPERTY(float, maxFAR, pow(0.5, numStages)) 121 BR_PROPERTY(float, maxFAR, pow(0.5, numStages))
  122 + BR_PROPERTY(bool, ROCMode, false)
121 123
122 QList<Classifier *> stages; 124 QList<Classifier *> stages;
123 125
@@ -155,10 +157,17 @@ class CascadeClassifier : public Classifier @@ -155,10 +157,17 @@ class CascadeClassifier : public Classifier
155 157
156 float classify(const Mat &image) const 158 float classify(const Mat &image) const
157 { 159 {
158 - foreach (const Classifier *stage, stages)  
159 - if (stage->classify(image) == 0.0f)  
160 - return 0.0f;  
161 - return 1.0f; 160 + if (stages.size() == 0) // special case for empty cascade
  161 + return 1.0f;
  162 +
  163 + float result = 0.0f;
  164 + for (int stageIdx = 0; stageIdx < stages.size(); stageIdx++) {
  165 + result = stages[stageIdx]->classify(image);
  166 +
  167 + if (result < 0)
  168 + return stageIdx > (stages.size() - 4) ? stageIdx * result : 0.0f;
  169 + }
  170 + return std::abs(stages.size() * result);
162 } 171 }
163 172
164 int numFeatures() const 173 int numFeatures() const
@@ -178,21 +187,6 @@ class CascadeClassifier : public Classifier @@ -178,21 +187,6 @@ class CascadeClassifier : public Classifier
178 187
179 void write(FileStorage &fs) const 188 void write(FileStorage &fs) const
180 { 189 {
181 - fs << CC_STAGE_TYPE << CC_BOOST;  
182 - fs << CC_FEATURE_TYPE << CC_LBP;  
183 - fs << CC_HEIGHT << 24;  
184 - fs << CC_WIDTH << 24;  
185 -  
186 - CascadeBoostParams stageParams(CvBoost::GINI, 0.999, 0.5, 0.95, 1, 200);  
187 - fs << CC_STAGE_PARAMS << "{"; stageParams.write( fs ); fs << "}";  
188 -  
189 - fs << CC_FEATURE_PARAMS << "{";  
190 - fs << CC_MAX_CAT_COUNT << stages.first()->maxCatCount();  
191 - fs << CC_FEATURE_SIZE << 1;  
192 - fs << "}";  
193 -  
194 - fs << CC_STAGE_NUM << stages.size();  
195 -  
196 fs << CC_STAGES << "["; 190 fs << CC_STAGES << "[";
197 foreach (const Classifier *stage, stages) { 191 foreach (const Classifier *stage, stages) {
198 fs << "{"; 192 fs << "{";
@@ -212,7 +206,7 @@ private: @@ -212,7 +206,7 @@ private:
212 if (!imgHandler.getPos(pos)) 206 if (!imgHandler.getPos(pos))
213 qFatal("Cannot get another positive sample!"); 207 qFatal("Cannot get another positive sample!");
214 208
215 - if (classify(pos) == 1.0f) { 209 + if (classify(pos) > 0.0f) {
216 printf("POS current samples: %d\r", images.size()); 210 printf("POS current samples: %d\r", images.size());
217 images.append(pos); 211 images.append(pos);
218 labels.append(1.0f); 212 labels.append(1.0f);
@@ -228,7 +222,7 @@ private: @@ -228,7 +222,7 @@ private:
228 if (!imgHandler.getNeg(neg)) 222 if (!imgHandler.getNeg(neg))
229 qFatal("Cannot get another negative sample!"); 223 qFatal("Cannot get another negative sample!");
230 224
231 - if (classify(neg) == 1.0f) { 225 + if (classify(neg) > 0.0f) {
232 printf("NEG current samples: %d\r", images.size() - posCount); 226 printf("NEG current samples: %d\r", images.size() - posCount);
233 images.append(neg); 227 images.append(neg);
234 labels.append(0.0f); 228 labels.append(0.0f);
openbr/plugins/imgproc/slidingwindow.cpp
@@ -21,6 +21,8 @@ @@ -21,6 +21,8 @@
21 #include <openbr/core/qtutils.h> 21 #include <openbr/core/qtutils.h>
22 22
23 #include <opencv2/highgui/highgui.hpp> 23 #include <opencv2/highgui/highgui.hpp>
  24 +#include <opencv2/imgproc/imgproc.hpp>
  25 +#include <opencv2/objdetect/objdetect.hpp>
24 26
25 using namespace cv; 27 using namespace cv;
26 28
@@ -39,11 +41,23 @@ class SlidingWindowTransform : public Transform @@ -39,11 +41,23 @@ class SlidingWindowTransform : public Transform
39 Q_OBJECT 41 Q_OBJECT
40 42
41 Q_PROPERTY(br::Classifier *classifier READ get_classifier WRITE set_classifier RESET reset_classifier STORED false) 43 Q_PROPERTY(br::Classifier *classifier READ get_classifier WRITE set_classifier RESET reset_classifier STORED false)
  44 + Q_PROPERTY(int minSize READ get_minSize WRITE set_minSize RESET reset_minSize STORED false)
  45 + Q_PROPERTY(int maxSize READ get_maxSize WRITE set_maxSize RESET reset_maxSize STORED false)
  46 + Q_PROPERTY(float scaleFactor READ get_scaleFactor WRITE set_scaleFactor RESET reset_scaleFactor STORED false)
  47 + Q_PROPERTY(int minNeighbors READ get_minNeighbors WRITE set_minNeighbors RESET reset_minNeighbors STORED false)
  48 + Q_PROPERTY(float eps READ get_eps WRITE set_eps RESET reset_eps STORED false)
  49 +
42 Q_PROPERTY(QString cascadeDir READ get_cascadeDir WRITE set_cascadeDir RESET reset_cascadeDir STORED false) 50 Q_PROPERTY(QString cascadeDir READ get_cascadeDir WRITE set_cascadeDir RESET reset_cascadeDir STORED false)
43 Q_PROPERTY(QString vecFile READ get_vecFile WRITE set_vecFile RESET reset_vecFile STORED false) 51 Q_PROPERTY(QString vecFile READ get_vecFile WRITE set_vecFile RESET reset_vecFile STORED false)
44 Q_PROPERTY(QString negFile READ get_negFile WRITE set_negFile RESET reset_negFile STORED false) 52 Q_PROPERTY(QString negFile READ get_negFile WRITE set_negFile RESET reset_negFile STORED false)
45 53
46 BR_PROPERTY(br::Classifier *, classifier, NULL) 54 BR_PROPERTY(br::Classifier *, classifier, NULL)
  55 + BR_PROPERTY(int, minSize, 24)
  56 + BR_PROPERTY(int, maxSize, -1)
  57 + BR_PROPERTY(float, scaleFactor, 1.2)
  58 + BR_PROPERTY(int, minNeighbors, 5)
  59 + BR_PROPERTY(float, eps, 0.2)
  60 +
47 BR_PROPERTY(QString, cascadeDir, "") 61 BR_PROPERTY(QString, cascadeDir, "")
48 BR_PROPERTY(QString, vecFile, "vec.vec") 62 BR_PROPERTY(QString, vecFile, "vec.vec")
49 BR_PROPERTY(QString, negFile, "neg.txt") 63 BR_PROPERTY(QString, negFile, "neg.txt")
@@ -162,7 +176,91 @@ class SlidingWindowTransform : public Transform @@ -162,7 +176,91 @@ class SlidingWindowTransform : public Transform
162 176
163 void project(const Template &src, Template &dst) const 177 void project(const Template &src, Template &dst) const
164 { 178 {
165 - (void)src; (void)dst; 179 + TemplateList temp;
  180 + project(TemplateList() << src, temp);
  181 + if (!temp.isEmpty()) dst = temp.first();
  182 + }
  183 +
  184 + void project(const TemplateList &src, TemplateList &dst) const
  185 + {
  186 + foreach (const Template &t, src) {
  187 + const bool enrollAll = t.file.getBool("enrollAll");
  188 +
  189 + // Mirror the behavior of ExpandTransform in the special case
  190 + // of an empty template.
  191 + if (t.empty() && !enrollAll) {
  192 + dst.append(t);
  193 + continue;
  194 + }
  195 +
  196 + for (int i = 0; i < t.size(); i++) {
  197 + Mat image;
  198 + OpenCVUtils::cvtUChar(t[i], image);
  199 +
  200 + std::vector<Rect> rects;
  201 + std::vector<int> rejectLevels;
  202 + std::vector<double> levelWeights;
  203 +
  204 + Size minObjectSize(minSize, minSize);
  205 + Size maxObjectSize(maxSize, maxSize);
  206 + if (maxObjectSize.height == 0 || maxObjectSize.width == 0)
  207 + maxObjectSize = image.size();
  208 +
  209 + Mat imageBuffer(image.rows + 1, image.cols + 1, CV_8U);
  210 +
  211 + for (double factor = 1; ; factor *= scaleFactor) {
  212 + Size originalWindowSize = classifier->windowSize();
  213 +
  214 + Size windowSize(cvRound(originalWindowSize.width*factor), cvRound(originalWindowSize.height*factor) );
  215 + Size scaledImageSize(cvRound(image.cols/factor ), cvRound(image.rows/factor));
  216 + Size processingRectSize(scaledImageSize.width - originalWindowSize.width, scaledImageSize.height - originalWindowSize.height);
  217 +
  218 + if (processingRectSize.width <= 0 || processingRectSize.height <= 0)
  219 + break;
  220 + if (windowSize.width > maxObjectSize.width || windowSize.height > maxObjectSize.height)
  221 + break;
  222 + if (windowSize.width < minObjectSize.width || windowSize.height < minObjectSize.height)
  223 + continue;
  224 +
  225 + Mat scaledImage(scaledImageSize, CV_8U, imageBuffer.data);
  226 + resize(image, scaledImage, scaledImageSize, 0, 0, CV_INTER_LINEAR);
  227 +
  228 + int yStep = factor > 2. ? 1 : 2;
  229 + for (int y = 0; y < processingRectSize.height; y += yStep) {
  230 + for (int x = 0; x < processingRectSize.width; x += yStep) {
  231 + Mat window = scaledImage(Rect(Point(x, y), classifier->windowSize())).clone();
  232 +
  233 + float result = classifier->classify(window);
  234 +
  235 + if (result > 0) {
  236 + rects.push_back(Rect(cvRound(x*factor), cvRound(y*factor), windowSize.width, windowSize.height));
  237 + rejectLevels.push_back(1);
  238 + levelWeights.push_back(result);
  239 + }
  240 + if (result == 0)
  241 + x += yStep;
  242 + }
  243 + }
  244 + }
  245 +
  246 + groupRectangles(rects, rejectLevels, levelWeights, minNeighbors, eps);
  247 +
  248 + if (!enrollAll && rects.empty())
  249 + rects.push_back(Rect(0, 0, image.cols, image.rows));
  250 +
  251 + for (size_t j = 0; j < rects.size(); j++) {
  252 + Template u(t.file, image);
  253 + if (rejectLevels.size() > j)
  254 + u.file.set("Confidence", rejectLevels[j]*levelWeights[j]);
  255 + else
  256 + u.file.set("Confidence", 1);
  257 + const QRectF rect = OpenCVUtils::fromRect(rects[j]);
  258 + u.file.appendRect(rect);
  259 + u.file.set("Face", rect);
  260 + dst.append(u);
  261 + }
  262 + }
  263 + }
166 } 264 }
167 }; 265 };
168 266