Commit fb0027a6e60acae160c2b4bc6252b03d0302c268

Authored by dgcrouse
1 parent 48ca4edd

Updated SVM to allow for setting degree for POLY classification.\n Wrapped some …

…parameters to trainSVM in a CvSVMParams object to cut down on clutter.
openbr/plugins/classification/svm.cpp
... ... @@ -26,19 +26,12 @@ using namespace cv;
26 26 namespace br
27 27 {
28 28  
29   -static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma, int folds, bool balanceFolds, int termCriteria)
  29 +static void trainSVM(SVM &svm, Mat data, Mat lab, CvSVMParams params, float C, float gamma, int folds, bool balanceFolds)
30 30 {
31 31 if (data.type() != CV_32FC1)
32 32 qFatal("Expected single channel floating point training data.");
33 33  
34   - CvSVMParams params;
35   - params.kernel_type = kernel;
36   - params.svm_type = type;
37   - params.p = 0.1;
38   - params.nu = 0.5;
39   - params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON);
40   -
41   - if ((C == -1) || ((gamma == -1) && (kernel == CvSVM::RBF))) {
  34 + if ((C == -1) || ((gamma == -1) && (params.kernel_type == CvSVM::RBF))) {
42 35 try {
43 36 svm.train_auto(data, lab, Mat(), Mat(), params, folds,
44 37 CvSVM::get_default_grid(CvSVM::C),
... ... @@ -83,6 +76,7 @@ class SVMTransform : public Transform
83 76 Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false)
84 77 Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false)
85 78 Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false)
  79 + Q_PROPERTY(int degree READ get_degree WRITE set_degree RESET reset_degree STORED false)
86 80  
87 81 public:
88 82 enum Kernel { Linear = CvSVM::LINEAR,
... ... @@ -107,6 +101,7 @@ private:
107 101 BR_PROPERTY(int, termCriteria, 1000)
108 102 BR_PROPERTY(int, folds, 5)
109 103 BR_PROPERTY(bool, balanceFolds, false)
  104 + BR_PROPERTY(int, degree, 2)
110 105  
111 106 SVM svm;
112 107 QHash<QString, int> labelMap;
... ... @@ -127,8 +122,16 @@ private:
127 122 QList<int> dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup);
128 123 lab = OpenCVUtils::toMat(dataLabels);
129 124 }
130   -
131   - trainSVM(svm, data, lab, kernel, type, C, gamma, folds, balanceFolds, termCriteria);
  125 +
  126 + CvSVMParams params;
  127 + params.kernel_type = kernel;
  128 + params.svm_type = type;
  129 + params.p = 0.1;
  130 + params.nu = 0.5;
  131 + params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON);
  132 + params.degree = degree;
  133 +
  134 + trainSVM(svm, data, lab, params, C, gamma, folds, balanceFolds);
132 135 }
133 136  
134 137 void project(const Template &src, Template &dst) const
... ... @@ -192,6 +195,7 @@ class SVMDistance : public Distance
192 195 Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false)
193 196 Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false)
194 197 Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false)
  198 + Q_PROPERTY(int degree READ get_degree WRITE set_degree RESET reset_degree STORED false)
195 199  
196 200 public:
197 201 enum Kernel { Linear = CvSVM::LINEAR,
... ... @@ -212,6 +216,7 @@ private:
212 216 BR_PROPERTY(int, termCriteria, 1000)
213 217 BR_PROPERTY(int, folds, 5)
214 218 BR_PROPERTY(bool, balanceFolds, false)
  219 + BR_PROPERTY(int, degree, 2)
215 220  
216 221 SVM svm;
217 222  
... ... @@ -235,8 +240,16 @@ private:
235 240 }
236 241 deltaData = deltaData.rowRange(0, index);
237 242 deltaLab = deltaLab.rowRange(0, index);
238   -
239   - trainSVM(svm, deltaData, deltaLab, kernel, type, -1, -1, folds, balanceFolds, termCriteria);
  243 +
  244 + CvSVMParams params;
  245 + params.kernel_type = kernel;
  246 + params.svm_type = type;
  247 + params.p = 0.1;
  248 + params.nu = 0.5;
  249 + params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON);
  250 + params.degree = degree;
  251 +
  252 + trainSVM(svm, deltaData, deltaLab, params, -1, -1, folds, balanceFolds);
240 253 }
241 254  
242 255 float compare(const Mat &a, const Mat &b) const
... ...