Commit fb0027a6e60acae160c2b4bc6252b03d0302c268
1 parent
48ca4edd
Updated SVM to allow for setting degree for POLY classification.\n Wrapped some …
…parameters to trainSVM in a CvSVMParams object to cut down on clutter.
Showing
1 changed file
with
26 additions
and
13 deletions
openbr/plugins/classification/svm.cpp
| @@ -26,19 +26,12 @@ using namespace cv; | @@ -26,19 +26,12 @@ using namespace cv; | ||
| 26 | namespace br | 26 | namespace br |
| 27 | { | 27 | { |
| 28 | 28 | ||
| 29 | -static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma, int folds, bool balanceFolds, int termCriteria) | 29 | +static void trainSVM(SVM &svm, Mat data, Mat lab, CvSVMParams params, float C, float gamma, int folds, bool balanceFolds) |
| 30 | { | 30 | { |
| 31 | if (data.type() != CV_32FC1) | 31 | if (data.type() != CV_32FC1) |
| 32 | qFatal("Expected single channel floating point training data."); | 32 | qFatal("Expected single channel floating point training data."); |
| 33 | 33 | ||
| 34 | - CvSVMParams params; | ||
| 35 | - params.kernel_type = kernel; | ||
| 36 | - params.svm_type = type; | ||
| 37 | - params.p = 0.1; | ||
| 38 | - params.nu = 0.5; | ||
| 39 | - params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON); | ||
| 40 | - | ||
| 41 | - if ((C == -1) || ((gamma == -1) && (kernel == CvSVM::RBF))) { | 34 | + if ((C == -1) || ((gamma == -1) && (params.kernel_type == CvSVM::RBF))) { |
| 42 | try { | 35 | try { |
| 43 | svm.train_auto(data, lab, Mat(), Mat(), params, folds, | 36 | svm.train_auto(data, lab, Mat(), Mat(), params, folds, |
| 44 | CvSVM::get_default_grid(CvSVM::C), | 37 | CvSVM::get_default_grid(CvSVM::C), |
| @@ -83,6 +76,7 @@ class SVMTransform : public Transform | @@ -83,6 +76,7 @@ class SVMTransform : public Transform | ||
| 83 | Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) | 76 | Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) |
| 84 | Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) | 77 | Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) |
| 85 | Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) | 78 | Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) |
| 79 | + Q_PROPERTY(int degree READ get_degree WRITE set_degree RESET reset_degree STORED false) | ||
| 86 | 80 | ||
| 87 | public: | 81 | public: |
| 88 | enum Kernel { Linear = CvSVM::LINEAR, | 82 | enum Kernel { Linear = CvSVM::LINEAR, |
| @@ -107,6 +101,7 @@ private: | @@ -107,6 +101,7 @@ private: | ||
| 107 | BR_PROPERTY(int, termCriteria, 1000) | 101 | BR_PROPERTY(int, termCriteria, 1000) |
| 108 | BR_PROPERTY(int, folds, 5) | 102 | BR_PROPERTY(int, folds, 5) |
| 109 | BR_PROPERTY(bool, balanceFolds, false) | 103 | BR_PROPERTY(bool, balanceFolds, false) |
| 104 | + BR_PROPERTY(int, degree, 2) | ||
| 110 | 105 | ||
| 111 | SVM svm; | 106 | SVM svm; |
| 112 | QHash<QString, int> labelMap; | 107 | QHash<QString, int> labelMap; |
| @@ -127,8 +122,16 @@ private: | @@ -127,8 +122,16 @@ private: | ||
| 127 | QList<int> dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup); | 122 | QList<int> dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup); |
| 128 | lab = OpenCVUtils::toMat(dataLabels); | 123 | lab = OpenCVUtils::toMat(dataLabels); |
| 129 | } | 124 | } |
| 130 | - | ||
| 131 | - trainSVM(svm, data, lab, kernel, type, C, gamma, folds, balanceFolds, termCriteria); | 125 | + |
| 126 | + CvSVMParams params; | ||
| 127 | + params.kernel_type = kernel; | ||
| 128 | + params.svm_type = type; | ||
| 129 | + params.p = 0.1; | ||
| 130 | + params.nu = 0.5; | ||
| 131 | + params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON); | ||
| 132 | + params.degree = degree; | ||
| 133 | + | ||
| 134 | + trainSVM(svm, data, lab, params, C, gamma, folds, balanceFolds); | ||
| 132 | } | 135 | } |
| 133 | 136 | ||
| 134 | void project(const Template &src, Template &dst) const | 137 | void project(const Template &src, Template &dst) const |
| @@ -192,6 +195,7 @@ class SVMDistance : public Distance | @@ -192,6 +195,7 @@ class SVMDistance : public Distance | ||
| 192 | Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) | 195 | Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) |
| 193 | Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) | 196 | Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) |
| 194 | Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) | 197 | Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) |
| 198 | + Q_PROPERTY(int degree READ get_degree WRITE set_degree RESET reset_degree STORED false) | ||
| 195 | 199 | ||
| 196 | public: | 200 | public: |
| 197 | enum Kernel { Linear = CvSVM::LINEAR, | 201 | enum Kernel { Linear = CvSVM::LINEAR, |
| @@ -212,6 +216,7 @@ private: | @@ -212,6 +216,7 @@ private: | ||
| 212 | BR_PROPERTY(int, termCriteria, 1000) | 216 | BR_PROPERTY(int, termCriteria, 1000) |
| 213 | BR_PROPERTY(int, folds, 5) | 217 | BR_PROPERTY(int, folds, 5) |
| 214 | BR_PROPERTY(bool, balanceFolds, false) | 218 | BR_PROPERTY(bool, balanceFolds, false) |
| 219 | + BR_PROPERTY(int, degree, 2) | ||
| 215 | 220 | ||
| 216 | SVM svm; | 221 | SVM svm; |
| 217 | 222 | ||
| @@ -235,8 +240,16 @@ private: | @@ -235,8 +240,16 @@ private: | ||
| 235 | } | 240 | } |
| 236 | deltaData = deltaData.rowRange(0, index); | 241 | deltaData = deltaData.rowRange(0, index); |
| 237 | deltaLab = deltaLab.rowRange(0, index); | 242 | deltaLab = deltaLab.rowRange(0, index); |
| 238 | - | ||
| 239 | - trainSVM(svm, deltaData, deltaLab, kernel, type, -1, -1, folds, balanceFolds, termCriteria); | 243 | + |
| 244 | + CvSVMParams params; | ||
| 245 | + params.kernel_type = kernel; | ||
| 246 | + params.svm_type = type; | ||
| 247 | + params.p = 0.1; | ||
| 248 | + params.nu = 0.5; | ||
| 249 | + params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON); | ||
| 250 | + params.degree = degree; | ||
| 251 | + | ||
| 252 | + trainSVM(svm, deltaData, deltaLab, params, -1, -1, folds, balanceFolds); | ||
| 240 | } | 253 | } |
| 241 | 254 | ||
| 242 | float compare(const Mat &a, const Mat &b) const | 255 | float compare(const Mat &a, const Mat &b) const |