Commit fb0027a6e60acae160c2b4bc6252b03d0302c268
1 parent
48ca4edd
Updated SVM to allow for setting degree for POLY classification.\n Wrapped some …
…parameters to trainSVM in a CvSVMParams object to cut down on clutter.
Showing
1 changed file
with
26 additions
and
13 deletions
openbr/plugins/classification/svm.cpp
| ... | ... | @@ -26,19 +26,12 @@ using namespace cv; |
| 26 | 26 | namespace br |
| 27 | 27 | { |
| 28 | 28 | |
| 29 | -static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma, int folds, bool balanceFolds, int termCriteria) | |
| 29 | +static void trainSVM(SVM &svm, Mat data, Mat lab, CvSVMParams params, float C, float gamma, int folds, bool balanceFolds) | |
| 30 | 30 | { |
| 31 | 31 | if (data.type() != CV_32FC1) |
| 32 | 32 | qFatal("Expected single channel floating point training data."); |
| 33 | 33 | |
| 34 | - CvSVMParams params; | |
| 35 | - params.kernel_type = kernel; | |
| 36 | - params.svm_type = type; | |
| 37 | - params.p = 0.1; | |
| 38 | - params.nu = 0.5; | |
| 39 | - params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON); | |
| 40 | - | |
| 41 | - if ((C == -1) || ((gamma == -1) && (kernel == CvSVM::RBF))) { | |
| 34 | + if ((C == -1) || ((gamma == -1) && (params.kernel_type == CvSVM::RBF))) { | |
| 42 | 35 | try { |
| 43 | 36 | svm.train_auto(data, lab, Mat(), Mat(), params, folds, |
| 44 | 37 | CvSVM::get_default_grid(CvSVM::C), |
| ... | ... | @@ -83,6 +76,7 @@ class SVMTransform : public Transform |
| 83 | 76 | Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) |
| 84 | 77 | Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) |
| 85 | 78 | Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) |
| 79 | + Q_PROPERTY(int degree READ get_degree WRITE set_degree RESET reset_degree STORED false) | |
| 86 | 80 | |
| 87 | 81 | public: |
| 88 | 82 | enum Kernel { Linear = CvSVM::LINEAR, |
| ... | ... | @@ -107,6 +101,7 @@ private: |
| 107 | 101 | BR_PROPERTY(int, termCriteria, 1000) |
| 108 | 102 | BR_PROPERTY(int, folds, 5) |
| 109 | 103 | BR_PROPERTY(bool, balanceFolds, false) |
| 104 | + BR_PROPERTY(int, degree, 2) | |
| 110 | 105 | |
| 111 | 106 | SVM svm; |
| 112 | 107 | QHash<QString, int> labelMap; |
| ... | ... | @@ -127,8 +122,16 @@ private: |
| 127 | 122 | QList<int> dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup); |
| 128 | 123 | lab = OpenCVUtils::toMat(dataLabels); |
| 129 | 124 | } |
| 130 | - | |
| 131 | - trainSVM(svm, data, lab, kernel, type, C, gamma, folds, balanceFolds, termCriteria); | |
| 125 | + | |
| 126 | + CvSVMParams params; | |
| 127 | + params.kernel_type = kernel; | |
| 128 | + params.svm_type = type; | |
| 129 | + params.p = 0.1; | |
| 130 | + params.nu = 0.5; | |
| 131 | + params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON); | |
| 132 | + params.degree = degree; | |
| 133 | + | |
| 134 | + trainSVM(svm, data, lab, params, C, gamma, folds, balanceFolds); | |
| 132 | 135 | } |
| 133 | 136 | |
| 134 | 137 | void project(const Template &src, Template &dst) const |
| ... | ... | @@ -192,6 +195,7 @@ class SVMDistance : public Distance |
| 192 | 195 | Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) |
| 193 | 196 | Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) |
| 194 | 197 | Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) |
| 198 | + Q_PROPERTY(int degree READ get_degree WRITE set_degree RESET reset_degree STORED false) | |
| 195 | 199 | |
| 196 | 200 | public: |
| 197 | 201 | enum Kernel { Linear = CvSVM::LINEAR, |
| ... | ... | @@ -212,6 +216,7 @@ private: |
| 212 | 216 | BR_PROPERTY(int, termCriteria, 1000) |
| 213 | 217 | BR_PROPERTY(int, folds, 5) |
| 214 | 218 | BR_PROPERTY(bool, balanceFolds, false) |
| 219 | + BR_PROPERTY(int, degree, 2) | |
| 215 | 220 | |
| 216 | 221 | SVM svm; |
| 217 | 222 | |
| ... | ... | @@ -235,8 +240,16 @@ private: |
| 235 | 240 | } |
| 236 | 241 | deltaData = deltaData.rowRange(0, index); |
| 237 | 242 | deltaLab = deltaLab.rowRange(0, index); |
| 238 | - | |
| 239 | - trainSVM(svm, deltaData, deltaLab, kernel, type, -1, -1, folds, balanceFolds, termCriteria); | |
| 243 | + | |
| 244 | + CvSVMParams params; | |
| 245 | + params.kernel_type = kernel; | |
| 246 | + params.svm_type = type; | |
| 247 | + params.p = 0.1; | |
| 248 | + params.nu = 0.5; | |
| 249 | + params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON); | |
| 250 | + params.degree = degree; | |
| 251 | + | |
| 252 | + trainSVM(svm, deltaData, deltaLab, params, -1, -1, folds, balanceFolds); | |
| 240 | 253 | } |
| 241 | 254 | |
| 242 | 255 | float compare(const Mat &a, const Mat &b) const | ... | ... |