Commit fb0027a6e60acae160c2b4bc6252b03d0302c268

Authored by dgcrouse
1 parent 48ca4edd

Updated SVM to allow for setting degree for POLY classification.\n Wrapped some …

…parameters to trainSVM in a CvSVMParams object to cut down on clutter.
openbr/plugins/classification/svm.cpp
@@ -26,19 +26,12 @@ using namespace cv; @@ -26,19 +26,12 @@ using namespace cv;
26 namespace br 26 namespace br
27 { 27 {
28 28
29 -static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma, int folds, bool balanceFolds, int termCriteria) 29 +static void trainSVM(SVM &svm, Mat data, Mat lab, CvSVMParams params, float C, float gamma, int folds, bool balanceFolds)
30 { 30 {
31 if (data.type() != CV_32FC1) 31 if (data.type() != CV_32FC1)
32 qFatal("Expected single channel floating point training data."); 32 qFatal("Expected single channel floating point training data.");
33 33
34 - CvSVMParams params;  
35 - params.kernel_type = kernel;  
36 - params.svm_type = type;  
37 - params.p = 0.1;  
38 - params.nu = 0.5;  
39 - params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON);  
40 -  
41 - if ((C == -1) || ((gamma == -1) && (kernel == CvSVM::RBF))) { 34 + if ((C == -1) || ((gamma == -1) && (params.kernel_type == CvSVM::RBF))) {
42 try { 35 try {
43 svm.train_auto(data, lab, Mat(), Mat(), params, folds, 36 svm.train_auto(data, lab, Mat(), Mat(), params, folds,
44 CvSVM::get_default_grid(CvSVM::C), 37 CvSVM::get_default_grid(CvSVM::C),
@@ -83,6 +76,7 @@ class SVMTransform : public Transform @@ -83,6 +76,7 @@ class SVMTransform : public Transform
83 Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) 76 Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false)
84 Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) 77 Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false)
85 Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) 78 Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false)
  79 + Q_PROPERTY(int degree READ get_degree WRITE set_degree RESET reset_degree STORED false)
86 80
87 public: 81 public:
88 enum Kernel { Linear = CvSVM::LINEAR, 82 enum Kernel { Linear = CvSVM::LINEAR,
@@ -107,6 +101,7 @@ private: @@ -107,6 +101,7 @@ private:
107 BR_PROPERTY(int, termCriteria, 1000) 101 BR_PROPERTY(int, termCriteria, 1000)
108 BR_PROPERTY(int, folds, 5) 102 BR_PROPERTY(int, folds, 5)
109 BR_PROPERTY(bool, balanceFolds, false) 103 BR_PROPERTY(bool, balanceFolds, false)
  104 + BR_PROPERTY(int, degree, 2)
110 105
111 SVM svm; 106 SVM svm;
112 QHash<QString, int> labelMap; 107 QHash<QString, int> labelMap;
@@ -127,8 +122,16 @@ private: @@ -127,8 +122,16 @@ private:
127 QList<int> dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup); 122 QList<int> dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup);
128 lab = OpenCVUtils::toMat(dataLabels); 123 lab = OpenCVUtils::toMat(dataLabels);
129 } 124 }
130 -  
131 - trainSVM(svm, data, lab, kernel, type, C, gamma, folds, balanceFolds, termCriteria); 125 +
  126 + CvSVMParams params;
  127 + params.kernel_type = kernel;
  128 + params.svm_type = type;
  129 + params.p = 0.1;
  130 + params.nu = 0.5;
  131 + params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON);
  132 + params.degree = degree;
  133 +
  134 + trainSVM(svm, data, lab, params, C, gamma, folds, balanceFolds);
132 } 135 }
133 136
134 void project(const Template &src, Template &dst) const 137 void project(const Template &src, Template &dst) const
@@ -192,6 +195,7 @@ class SVMDistance : public Distance @@ -192,6 +195,7 @@ class SVMDistance : public Distance
192 Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) 195 Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false)
193 Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) 196 Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false)
194 Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) 197 Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false)
  198 + Q_PROPERTY(int degree READ get_degree WRITE set_degree RESET reset_degree STORED false)
195 199
196 public: 200 public:
197 enum Kernel { Linear = CvSVM::LINEAR, 201 enum Kernel { Linear = CvSVM::LINEAR,
@@ -212,6 +216,7 @@ private: @@ -212,6 +216,7 @@ private:
212 BR_PROPERTY(int, termCriteria, 1000) 216 BR_PROPERTY(int, termCriteria, 1000)
213 BR_PROPERTY(int, folds, 5) 217 BR_PROPERTY(int, folds, 5)
214 BR_PROPERTY(bool, balanceFolds, false) 218 BR_PROPERTY(bool, balanceFolds, false)
  219 + BR_PROPERTY(int, degree, 2)
215 220
216 SVM svm; 221 SVM svm;
217 222
@@ -235,8 +240,16 @@ private: @@ -235,8 +240,16 @@ private:
235 } 240 }
236 deltaData = deltaData.rowRange(0, index); 241 deltaData = deltaData.rowRange(0, index);
237 deltaLab = deltaLab.rowRange(0, index); 242 deltaLab = deltaLab.rowRange(0, index);
238 -  
239 - trainSVM(svm, deltaData, deltaLab, kernel, type, -1, -1, folds, balanceFolds, termCriteria); 243 +
  244 + CvSVMParams params;
  245 + params.kernel_type = kernel;
  246 + params.svm_type = type;
  247 + params.p = 0.1;
  248 + params.nu = 0.5;
  249 + params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON);
  250 + params.degree = degree;
  251 +
  252 + trainSVM(svm, deltaData, deltaLab, params, -1, -1, folds, balanceFolds);
240 } 253 }
241 254
242 float compare(const Mat &a, const Mat &b) const 255 float compare(const Mat &a, const Mat &b) const