Commit b95668200403b4ecd450b5658749a771f6abf783
Merge branch 'master' of https://github.com/biometrics/openbr into flip_rotate
Showing
1 changed file
with
26 additions
and
6 deletions
openbr/plugins/svm.cpp
| ... | ... | @@ -59,7 +59,7 @@ static void loadSVM(SVM &svm, QDataStream &stream) |
| 59 | 59 | svm.load(qPrintable(tempFile.fileName())); |
| 60 | 60 | } |
| 61 | 61 | |
| 62 | -static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma) | |
| 62 | +static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma, int folds, bool balanceFolds, int termCriteria) | |
| 63 | 63 | { |
| 64 | 64 | if (data.type() != CV_32FC1) |
| 65 | 65 | qFatal("Expected single channel floating point training data."); |
| ... | ... | @@ -69,9 +69,18 @@ static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, |
| 69 | 69 | params.svm_type = type; |
| 70 | 70 | params.p = 0.1; |
| 71 | 71 | params.nu = 0.5; |
| 72 | + params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON); | |
| 73 | + | |
| 72 | 74 | if ((C == -1) || ((gamma == -1) && (kernel == CvSVM::RBF))) { |
| 73 | 75 | try { |
| 74 | - svm.train_auto(data, lab, Mat(), Mat(), params, 5); | |
| 76 | + svm.train_auto(data, lab, Mat(), Mat(), params, folds, | |
| 77 | + CvSVM::get_default_grid(CvSVM::C), | |
| 78 | + CvSVM::get_default_grid(CvSVM::GAMMA), | |
| 79 | + CvSVM::get_default_grid(CvSVM::P), | |
| 80 | + CvSVM::get_default_grid(CvSVM::NU), | |
| 81 | + CvSVM::get_default_grid(CvSVM::COEF), | |
| 82 | + CvSVM::get_default_grid(CvSVM::DEGREE), | |
| 83 | + balanceFolds); | |
| 75 | 84 | } catch (...) { |
| 76 | 85 | qWarning("Some classes do not contain sufficient examples or are not discriminative enough for accurate SVM classification."); |
| 77 | 86 | svm.train(data, lab, Mat(), Mat(), params); |
| ... | ... | @@ -104,6 +113,9 @@ class SVMTransform : public Transform |
| 104 | 113 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) |
| 105 | 114 | Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) |
| 106 | 115 | Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false) |
| 116 | + Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) | |
| 117 | + Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) | |
| 118 | + Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) | |
| 107 | 119 | |
| 108 | 120 | public: |
| 109 | 121 | enum Kernel { Linear = CvSVM::LINEAR, |
| ... | ... | @@ -125,7 +137,9 @@ private: |
| 125 | 137 | BR_PROPERTY(QString, inputVariable, "Label") |
| 126 | 138 | BR_PROPERTY(QString, outputVariable, "") |
| 127 | 139 | BR_PROPERTY(bool, returnDFVal, false) |
| 128 | - | |
| 140 | + BR_PROPERTY(int, termCriteria, 1000) | |
| 141 | + BR_PROPERTY(int, folds, 5) | |
| 142 | + BR_PROPERTY(bool, balanceFolds, false) | |
| 129 | 143 | |
| 130 | 144 | SVM svm; |
| 131 | 145 | QHash<QString, int> labelMap; |
| ... | ... | @@ -146,7 +160,8 @@ private: |
| 146 | 160 | QList<int> dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup); |
| 147 | 161 | lab = OpenCVUtils::toMat(dataLabels); |
| 148 | 162 | } |
| 149 | - trainSVM(svm, data, lab, kernel, type, C, gamma); | |
| 163 | + | |
| 164 | + trainSVM(svm, data, lab, kernel, type, C, gamma, folds, balanceFolds, termCriteria); | |
| 150 | 165 | } |
| 151 | 166 | |
| 152 | 167 | void project(const Template &src, Template &dst) const |
| ... | ... | @@ -207,7 +222,9 @@ class SVMDistance : public Distance |
| 207 | 222 | Q_PROPERTY(Kernel kernel READ get_kernel WRITE set_kernel RESET reset_kernel STORED false) |
| 208 | 223 | Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false) |
| 209 | 224 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) |
| 210 | - | |
| 225 | + Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) | |
| 226 | + Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) | |
| 227 | + Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) | |
| 211 | 228 | |
| 212 | 229 | public: |
| 213 | 230 | enum Kernel { Linear = CvSVM::LINEAR, |
| ... | ... | @@ -225,6 +242,9 @@ private: |
| 225 | 242 | BR_PROPERTY(Kernel, kernel, Linear) |
| 226 | 243 | BR_PROPERTY(Type, type, EPS_SVR) |
| 227 | 244 | BR_PROPERTY(QString, inputVariable, "Label") |
| 245 | + BR_PROPERTY(int, termCriteria, 1000) | |
| 246 | + BR_PROPERTY(int, folds, 5) | |
| 247 | + BR_PROPERTY(bool, balanceFolds, false) | |
| 228 | 248 | |
| 229 | 249 | SVM svm; |
| 230 | 250 | |
| ... | ... | @@ -249,7 +269,7 @@ private: |
| 249 | 269 | deltaData = deltaData.rowRange(0, index); |
| 250 | 270 | deltaLab = deltaLab.rowRange(0, index); |
| 251 | 271 | |
| 252 | - trainSVM(svm, deltaData, deltaLab, kernel, type, -1, -1); | |
| 272 | + trainSVM(svm, deltaData, deltaLab, kernel, type, -1, -1, folds, balanceFolds, termCriteria); | |
| 253 | 273 | } |
| 254 | 274 | |
| 255 | 275 | float compare(const Mat &a, const Mat &b) const | ... | ... |