Commit b95668200403b4ecd450b5658749a771f6abf783

Authored by Scott Klum
2 parents a65a90b2 b0b6e5cd

Merge branch 'master' of https://github.com/biometrics/openbr into flip_rotate

Showing 1 changed file with 26 additions and 6 deletions
openbr/plugins/svm.cpp
@@ -59,7 +59,7 @@ static void loadSVM(SVM &svm, QDataStream &stream) @@ -59,7 +59,7 @@ static void loadSVM(SVM &svm, QDataStream &stream)
59 svm.load(qPrintable(tempFile.fileName())); 59 svm.load(qPrintable(tempFile.fileName()));
60 } 60 }
61 61
62 -static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma) 62 +static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma, int folds, bool balanceFolds, int termCriteria)
63 { 63 {
64 if (data.type() != CV_32FC1) 64 if (data.type() != CV_32FC1)
65 qFatal("Expected single channel floating point training data."); 65 qFatal("Expected single channel floating point training data.");
@@ -69,9 +69,18 @@ static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, @@ -69,9 +69,18 @@ static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C,
69 params.svm_type = type; 69 params.svm_type = type;
70 params.p = 0.1; 70 params.p = 0.1;
71 params.nu = 0.5; 71 params.nu = 0.5;
  72 + params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON);
  73 +
72 if ((C == -1) || ((gamma == -1) && (kernel == CvSVM::RBF))) { 74 if ((C == -1) || ((gamma == -1) && (kernel == CvSVM::RBF))) {
73 try { 75 try {
74 - svm.train_auto(data, lab, Mat(), Mat(), params, 5); 76 + svm.train_auto(data, lab, Mat(), Mat(), params, folds,
  77 + CvSVM::get_default_grid(CvSVM::C),
  78 + CvSVM::get_default_grid(CvSVM::GAMMA),
  79 + CvSVM::get_default_grid(CvSVM::P),
  80 + CvSVM::get_default_grid(CvSVM::NU),
  81 + CvSVM::get_default_grid(CvSVM::COEF),
  82 + CvSVM::get_default_grid(CvSVM::DEGREE),
  83 + balanceFolds);
75 } catch (...) { 84 } catch (...) {
76 qWarning("Some classes do not contain sufficient examples or are not discriminative enough for accurate SVM classification."); 85 qWarning("Some classes do not contain sufficient examples or are not discriminative enough for accurate SVM classification.");
77 svm.train(data, lab, Mat(), Mat(), params); 86 svm.train(data, lab, Mat(), Mat(), params);
@@ -104,6 +113,9 @@ class SVMTransform : public Transform @@ -104,6 +113,9 @@ class SVMTransform : public Transform
104 Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) 113 Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
105 Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) 114 Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)
106 Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false) 115 Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false)
  116 + Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false)
  117 + Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false)
  118 + Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false)
107 119
108 public: 120 public:
109 enum Kernel { Linear = CvSVM::LINEAR, 121 enum Kernel { Linear = CvSVM::LINEAR,
@@ -125,7 +137,9 @@ private: @@ -125,7 +137,9 @@ private:
125 BR_PROPERTY(QString, inputVariable, "Label") 137 BR_PROPERTY(QString, inputVariable, "Label")
126 BR_PROPERTY(QString, outputVariable, "") 138 BR_PROPERTY(QString, outputVariable, "")
127 BR_PROPERTY(bool, returnDFVal, false) 139 BR_PROPERTY(bool, returnDFVal, false)
128 - 140 + BR_PROPERTY(int, termCriteria, 1000)
  141 + BR_PROPERTY(int, folds, 5)
  142 + BR_PROPERTY(bool, balanceFolds, false)
129 143
130 SVM svm; 144 SVM svm;
131 QHash<QString, int> labelMap; 145 QHash<QString, int> labelMap;
@@ -146,7 +160,8 @@ private: @@ -146,7 +160,8 @@ private:
146 QList<int> dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup); 160 QList<int> dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup);
147 lab = OpenCVUtils::toMat(dataLabels); 161 lab = OpenCVUtils::toMat(dataLabels);
148 } 162 }
149 - trainSVM(svm, data, lab, kernel, type, C, gamma); 163 +
  164 + trainSVM(svm, data, lab, kernel, type, C, gamma, folds, balanceFolds, termCriteria);
150 } 165 }
151 166
152 void project(const Template &src, Template &dst) const 167 void project(const Template &src, Template &dst) const
@@ -207,7 +222,9 @@ class SVMDistance : public Distance @@ -207,7 +222,9 @@ class SVMDistance : public Distance
207 Q_PROPERTY(Kernel kernel READ get_kernel WRITE set_kernel RESET reset_kernel STORED false) 222 Q_PROPERTY(Kernel kernel READ get_kernel WRITE set_kernel RESET reset_kernel STORED false)
208 Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false) 223 Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false)
209 Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) 224 Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
210 - 225 + Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false)
  226 + Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false)
  227 + Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false)
211 228
212 public: 229 public:
213 enum Kernel { Linear = CvSVM::LINEAR, 230 enum Kernel { Linear = CvSVM::LINEAR,
@@ -225,6 +242,9 @@ private: @@ -225,6 +242,9 @@ private:
225 BR_PROPERTY(Kernel, kernel, Linear) 242 BR_PROPERTY(Kernel, kernel, Linear)
226 BR_PROPERTY(Type, type, EPS_SVR) 243 BR_PROPERTY(Type, type, EPS_SVR)
227 BR_PROPERTY(QString, inputVariable, "Label") 244 BR_PROPERTY(QString, inputVariable, "Label")
  245 + BR_PROPERTY(int, termCriteria, 1000)
  246 + BR_PROPERTY(int, folds, 5)
  247 + BR_PROPERTY(bool, balanceFolds, false)
228 248
229 SVM svm; 249 SVM svm;
230 250
@@ -249,7 +269,7 @@ private: @@ -249,7 +269,7 @@ private:
249 deltaData = deltaData.rowRange(0, index); 269 deltaData = deltaData.rowRange(0, index);
250 deltaLab = deltaLab.rowRange(0, index); 270 deltaLab = deltaLab.rowRange(0, index);
251 271
252 - trainSVM(svm, deltaData, deltaLab, kernel, type, -1, -1); 272 + trainSVM(svm, deltaData, deltaLab, kernel, type, -1, -1, folds, balanceFolds, termCriteria);
253 } 273 }
254 274
255 float compare(const Mat &a, const Mat &b) const 275 float compare(const Mat &a, const Mat &b) const