Commit 3b38780bec9da8c964048e301aa02e533f7d0a25
1 parent
fb0027a6
Revert "Updated SVM to allow for setting degree for POLY classification.\n Wrapp…
…ed some parameters to trainSVM in a CvSVMParams object to cut down on clutter." This reverts commit fb0027a6e60acae160c2b4bc6252b03d0302c268.
Showing
1 changed file
with
13 additions
and
26 deletions
openbr/plugins/classification/svm.cpp
| @@ -26,12 +26,19 @@ using namespace cv; | @@ -26,12 +26,19 @@ using namespace cv; | ||
| 26 | namespace br | 26 | namespace br |
| 27 | { | 27 | { |
| 28 | 28 | ||
| 29 | -static void trainSVM(SVM &svm, Mat data, Mat lab, CvSVMParams params, float C, float gamma, int folds, bool balanceFolds) | 29 | +static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma, int folds, bool balanceFolds, int termCriteria) |
| 30 | { | 30 | { |
| 31 | if (data.type() != CV_32FC1) | 31 | if (data.type() != CV_32FC1) |
| 32 | qFatal("Expected single channel floating point training data."); | 32 | qFatal("Expected single channel floating point training data."); |
| 33 | 33 | ||
| 34 | - if ((C == -1) || ((gamma == -1) && (params.kernel_type == CvSVM::RBF))) { | 34 | + CvSVMParams params; |
| 35 | + params.kernel_type = kernel; | ||
| 36 | + params.svm_type = type; | ||
| 37 | + params.p = 0.1; | ||
| 38 | + params.nu = 0.5; | ||
| 39 | + params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON); | ||
| 40 | + | ||
| 41 | + if ((C == -1) || ((gamma == -1) && (kernel == CvSVM::RBF))) { | ||
| 35 | try { | 42 | try { |
| 36 | svm.train_auto(data, lab, Mat(), Mat(), params, folds, | 43 | svm.train_auto(data, lab, Mat(), Mat(), params, folds, |
| 37 | CvSVM::get_default_grid(CvSVM::C), | 44 | CvSVM::get_default_grid(CvSVM::C), |
| @@ -76,7 +83,6 @@ class SVMTransform : public Transform | @@ -76,7 +83,6 @@ class SVMTransform : public Transform | ||
| 76 | Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) | 83 | Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) |
| 77 | Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) | 84 | Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) |
| 78 | Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) | 85 | Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) |
| 79 | - Q_PROPERTY(int degree READ get_degree WRITE set_degree RESET reset_degree STORED false) | ||
| 80 | 86 | ||
| 81 | public: | 87 | public: |
| 82 | enum Kernel { Linear = CvSVM::LINEAR, | 88 | enum Kernel { Linear = CvSVM::LINEAR, |
| @@ -101,7 +107,6 @@ private: | @@ -101,7 +107,6 @@ private: | ||
| 101 | BR_PROPERTY(int, termCriteria, 1000) | 107 | BR_PROPERTY(int, termCriteria, 1000) |
| 102 | BR_PROPERTY(int, folds, 5) | 108 | BR_PROPERTY(int, folds, 5) |
| 103 | BR_PROPERTY(bool, balanceFolds, false) | 109 | BR_PROPERTY(bool, balanceFolds, false) |
| 104 | - BR_PROPERTY(int, degree, 2) | ||
| 105 | 110 | ||
| 106 | SVM svm; | 111 | SVM svm; |
| 107 | QHash<QString, int> labelMap; | 112 | QHash<QString, int> labelMap; |
| @@ -122,16 +127,8 @@ private: | @@ -122,16 +127,8 @@ private: | ||
| 122 | QList<int> dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup); | 127 | QList<int> dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup); |
| 123 | lab = OpenCVUtils::toMat(dataLabels); | 128 | lab = OpenCVUtils::toMat(dataLabels); |
| 124 | } | 129 | } |
| 125 | - | ||
| 126 | - CvSVMParams params; | ||
| 127 | - params.kernel_type = kernel; | ||
| 128 | - params.svm_type = type; | ||
| 129 | - params.p = 0.1; | ||
| 130 | - params.nu = 0.5; | ||
| 131 | - params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON); | ||
| 132 | - params.degree = degree; | ||
| 133 | - | ||
| 134 | - trainSVM(svm, data, lab, params, C, gamma, folds, balanceFolds); | 130 | + |
| 131 | + trainSVM(svm, data, lab, kernel, type, C, gamma, folds, balanceFolds, termCriteria); | ||
| 135 | } | 132 | } |
| 136 | 133 | ||
| 137 | void project(const Template &src, Template &dst) const | 134 | void project(const Template &src, Template &dst) const |
| @@ -195,7 +192,6 @@ class SVMDistance : public Distance | @@ -195,7 +192,6 @@ class SVMDistance : public Distance | ||
| 195 | Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) | 192 | Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) |
| 196 | Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) | 193 | Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) |
| 197 | Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) | 194 | Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) |
| 198 | - Q_PROPERTY(int degree READ get_degree WRITE set_degree RESET reset_degree STORED false) | ||
| 199 | 195 | ||
| 200 | public: | 196 | public: |
| 201 | enum Kernel { Linear = CvSVM::LINEAR, | 197 | enum Kernel { Linear = CvSVM::LINEAR, |
| @@ -216,7 +212,6 @@ private: | @@ -216,7 +212,6 @@ private: | ||
| 216 | BR_PROPERTY(int, termCriteria, 1000) | 212 | BR_PROPERTY(int, termCriteria, 1000) |
| 217 | BR_PROPERTY(int, folds, 5) | 213 | BR_PROPERTY(int, folds, 5) |
| 218 | BR_PROPERTY(bool, balanceFolds, false) | 214 | BR_PROPERTY(bool, balanceFolds, false) |
| 219 | - BR_PROPERTY(int, degree, 2) | ||
| 220 | 215 | ||
| 221 | SVM svm; | 216 | SVM svm; |
| 222 | 217 | ||
| @@ -240,16 +235,8 @@ private: | @@ -240,16 +235,8 @@ private: | ||
| 240 | } | 235 | } |
| 241 | deltaData = deltaData.rowRange(0, index); | 236 | deltaData = deltaData.rowRange(0, index); |
| 242 | deltaLab = deltaLab.rowRange(0, index); | 237 | deltaLab = deltaLab.rowRange(0, index); |
| 243 | - | ||
| 244 | - CvSVMParams params; | ||
| 245 | - params.kernel_type = kernel; | ||
| 246 | - params.svm_type = type; | ||
| 247 | - params.p = 0.1; | ||
| 248 | - params.nu = 0.5; | ||
| 249 | - params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON); | ||
| 250 | - params.degree = degree; | ||
| 251 | - | ||
| 252 | - trainSVM(svm, deltaData, deltaLab, params, -1, -1, folds, balanceFolds); | 238 | + |
| 239 | + trainSVM(svm, deltaData, deltaLab, kernel, type, -1, -1, folds, balanceFolds, termCriteria); | ||
| 253 | } | 240 | } |
| 254 | 241 | ||
| 255 | float compare(const Mat &a, const Mat &b) const | 242 | float compare(const Mat &a, const Mat &b) const |