Commit b95668200403b4ecd450b5658749a771f6abf783
Merge branch 'master' of https://github.com/biometrics/openbr into flip_rotate
Showing
1 changed file
with
26 additions
and
6 deletions
openbr/plugins/svm.cpp
| @@ -59,7 +59,7 @@ static void loadSVM(SVM &svm, QDataStream &stream) | @@ -59,7 +59,7 @@ static void loadSVM(SVM &svm, QDataStream &stream) | ||
| 59 | svm.load(qPrintable(tempFile.fileName())); | 59 | svm.load(qPrintable(tempFile.fileName())); |
| 60 | } | 60 | } |
| 61 | 61 | ||
| 62 | -static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma) | 62 | +static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma, int folds, bool balanceFolds, int termCriteria) |
| 63 | { | 63 | { |
| 64 | if (data.type() != CV_32FC1) | 64 | if (data.type() != CV_32FC1) |
| 65 | qFatal("Expected single channel floating point training data."); | 65 | qFatal("Expected single channel floating point training data."); |
| @@ -69,9 +69,18 @@ static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, | @@ -69,9 +69,18 @@ static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, | ||
| 69 | params.svm_type = type; | 69 | params.svm_type = type; |
| 70 | params.p = 0.1; | 70 | params.p = 0.1; |
| 71 | params.nu = 0.5; | 71 | params.nu = 0.5; |
| 72 | + params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON); | ||
| 73 | + | ||
| 72 | if ((C == -1) || ((gamma == -1) && (kernel == CvSVM::RBF))) { | 74 | if ((C == -1) || ((gamma == -1) && (kernel == CvSVM::RBF))) { |
| 73 | try { | 75 | try { |
| 74 | - svm.train_auto(data, lab, Mat(), Mat(), params, 5); | 76 | + svm.train_auto(data, lab, Mat(), Mat(), params, folds, |
| 77 | + CvSVM::get_default_grid(CvSVM::C), | ||
| 78 | + CvSVM::get_default_grid(CvSVM::GAMMA), | ||
| 79 | + CvSVM::get_default_grid(CvSVM::P), | ||
| 80 | + CvSVM::get_default_grid(CvSVM::NU), | ||
| 81 | + CvSVM::get_default_grid(CvSVM::COEF), | ||
| 82 | + CvSVM::get_default_grid(CvSVM::DEGREE), | ||
| 83 | + balanceFolds); | ||
| 75 | } catch (...) { | 84 | } catch (...) { |
| 76 | qWarning("Some classes do not contain sufficient examples or are not discriminative enough for accurate SVM classification."); | 85 | qWarning("Some classes do not contain sufficient examples or are not discriminative enough for accurate SVM classification."); |
| 77 | svm.train(data, lab, Mat(), Mat(), params); | 86 | svm.train(data, lab, Mat(), Mat(), params); |
| @@ -104,6 +113,9 @@ class SVMTransform : public Transform | @@ -104,6 +113,9 @@ class SVMTransform : public Transform | ||
| 104 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | 113 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) |
| 105 | Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) | 114 | Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) |
| 106 | Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false) | 115 | Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false) |
| 116 | + Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) | ||
| 117 | + Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) | ||
| 118 | + Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) | ||
| 107 | 119 | ||
| 108 | public: | 120 | public: |
| 109 | enum Kernel { Linear = CvSVM::LINEAR, | 121 | enum Kernel { Linear = CvSVM::LINEAR, |
| @@ -125,7 +137,9 @@ private: | @@ -125,7 +137,9 @@ private: | ||
| 125 | BR_PROPERTY(QString, inputVariable, "Label") | 137 | BR_PROPERTY(QString, inputVariable, "Label") |
| 126 | BR_PROPERTY(QString, outputVariable, "") | 138 | BR_PROPERTY(QString, outputVariable, "") |
| 127 | BR_PROPERTY(bool, returnDFVal, false) | 139 | BR_PROPERTY(bool, returnDFVal, false) |
| 128 | - | 140 | + BR_PROPERTY(int, termCriteria, 1000) |
| 141 | + BR_PROPERTY(int, folds, 5) | ||
| 142 | + BR_PROPERTY(bool, balanceFolds, false) | ||
| 129 | 143 | ||
| 130 | SVM svm; | 144 | SVM svm; |
| 131 | QHash<QString, int> labelMap; | 145 | QHash<QString, int> labelMap; |
| @@ -146,7 +160,8 @@ private: | @@ -146,7 +160,8 @@ private: | ||
| 146 | QList<int> dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup); | 160 | QList<int> dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup); |
| 147 | lab = OpenCVUtils::toMat(dataLabels); | 161 | lab = OpenCVUtils::toMat(dataLabels); |
| 148 | } | 162 | } |
| 149 | - trainSVM(svm, data, lab, kernel, type, C, gamma); | 163 | + |
| 164 | + trainSVM(svm, data, lab, kernel, type, C, gamma, folds, balanceFolds, termCriteria); | ||
| 150 | } | 165 | } |
| 151 | 166 | ||
| 152 | void project(const Template &src, Template &dst) const | 167 | void project(const Template &src, Template &dst) const |
| @@ -207,7 +222,9 @@ class SVMDistance : public Distance | @@ -207,7 +222,9 @@ class SVMDistance : public Distance | ||
| 207 | Q_PROPERTY(Kernel kernel READ get_kernel WRITE set_kernel RESET reset_kernel STORED false) | 222 | Q_PROPERTY(Kernel kernel READ get_kernel WRITE set_kernel RESET reset_kernel STORED false) |
| 208 | Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false) | 223 | Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false) |
| 209 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | 224 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) |
| 210 | - | 225 | + Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) |
| 226 | + Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) | ||
| 227 | + Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) | ||
| 211 | 228 | ||
| 212 | public: | 229 | public: |
| 213 | enum Kernel { Linear = CvSVM::LINEAR, | 230 | enum Kernel { Linear = CvSVM::LINEAR, |
| @@ -225,6 +242,9 @@ private: | @@ -225,6 +242,9 @@ private: | ||
| 225 | BR_PROPERTY(Kernel, kernel, Linear) | 242 | BR_PROPERTY(Kernel, kernel, Linear) |
| 226 | BR_PROPERTY(Type, type, EPS_SVR) | 243 | BR_PROPERTY(Type, type, EPS_SVR) |
| 227 | BR_PROPERTY(QString, inputVariable, "Label") | 244 | BR_PROPERTY(QString, inputVariable, "Label") |
| 245 | + BR_PROPERTY(int, termCriteria, 1000) | ||
| 246 | + BR_PROPERTY(int, folds, 5) | ||
| 247 | + BR_PROPERTY(bool, balanceFolds, false) | ||
| 228 | 248 | ||
| 229 | SVM svm; | 249 | SVM svm; |
| 230 | 250 | ||
| @@ -249,7 +269,7 @@ private: | @@ -249,7 +269,7 @@ private: | ||
| 249 | deltaData = deltaData.rowRange(0, index); | 269 | deltaData = deltaData.rowRange(0, index); |
| 250 | deltaLab = deltaLab.rowRange(0, index); | 270 | deltaLab = deltaLab.rowRange(0, index); |
| 251 | 271 | ||
| 252 | - trainSVM(svm, deltaData, deltaLab, kernel, type, -1, -1); | 272 | + trainSVM(svm, deltaData, deltaLab, kernel, type, -1, -1, folds, balanceFolds, termCriteria); |
| 253 | } | 273 | } |
| 254 | 274 | ||
| 255 | float compare(const Mat &a, const Mat &b) const | 275 | float compare(const Mat &a, const Mat &b) const |