Commit 1aed6ade0f84f6d7b819289edffa8236295586bd

Authored by Scott Klum
1 parent 60f93778

Added additional SVM parameters

Showing 1 changed file with 26 additions and 6 deletions
openbr/plugins/svm.cpp
... ... @@ -59,7 +59,7 @@ static void loadSVM(SVM &svm, QDataStream &stream)
59 59 svm.load(qPrintable(tempFile.fileName()));
60 60 }
61 61  
62   -static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma)
  62 +static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma, int folds, bool balanceFolds, int termCriteria)
63 63 {
64 64 if (data.type() != CV_32FC1)
65 65 qFatal("Expected single channel floating point training data.");
... ... @@ -69,9 +69,18 @@ static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C,
69 69 params.svm_type = type;
70 70 params.p = 0.1;
71 71 params.nu = 0.5;
  72 + params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON);
  73 +
72 74 if ((C == -1) || ((gamma == -1) && (kernel == CvSVM::RBF))) {
73 75 try {
74   - svm.train_auto(data, lab, Mat(), Mat(), params, 5);
  76 + svm.train_auto(data, lab, Mat(), Mat(), params, folds,
  77 + CvSVM::get_default_grid(CvSVM::C),
  78 + CvSVM::get_default_grid(CvSVM::GAMMA),
  79 + CvSVM::get_default_grid(CvSVM::P),
  80 + CvSVM::get_default_grid(CvSVM::NU),
  81 + CvSVM::get_default_grid(CvSVM::COEF),
  82 + CvSVM::get_default_grid(CvSVM::DEGREE),
  83 + balanceFolds);
75 84 } catch (...) {
76 85 qWarning("Some classes do not contain sufficient examples or are not discriminative enough for accurate SVM classification.");
77 86 svm.train(data, lab, Mat(), Mat(), params);
... ... @@ -104,6 +113,9 @@ class SVMTransform : public Transform
104 113 Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
105 114 Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)
106 115 Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false)
  116 + Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false)
  117 + Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false)
  118 + Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false)
107 119  
108 120 public:
109 121 enum Kernel { Linear = CvSVM::LINEAR,
... ... @@ -125,7 +137,9 @@ private:
125 137 BR_PROPERTY(QString, inputVariable, "Label")
126 138 BR_PROPERTY(QString, outputVariable, "")
127 139 BR_PROPERTY(bool, returnDFVal, false)
128   -
  140 + BR_PROPERTY(int, termCriteria, 1000)
  141 + BR_PROPERTY(int, folds, 5)
  142 + BR_PROPERTY(bool, balanceFolds, false)
129 143  
130 144 SVM svm;
131 145 QHash<QString, int> labelMap;
... ... @@ -146,7 +160,8 @@ private:
146 160 QList<int> dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup);
147 161 lab = OpenCVUtils::toMat(dataLabels);
148 162 }
149   - trainSVM(svm, data, lab, kernel, type, C, gamma);
  163 +
  164 + trainSVM(svm, data, lab, kernel, type, C, gamma, folds, balanceFolds, termCriteria);
150 165 }
151 166  
152 167 void project(const Template &src, Template &dst) const
... ... @@ -207,7 +222,9 @@ class SVMDistance : public Distance
207 222 Q_PROPERTY(Kernel kernel READ get_kernel WRITE set_kernel RESET reset_kernel STORED false)
208 223 Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false)
209 224 Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
210   -
  225 + Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false)
  226 + Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false)
  227 + Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false)
211 228  
212 229 public:
213 230 enum Kernel { Linear = CvSVM::LINEAR,
... ... @@ -225,6 +242,9 @@ private:
225 242 BR_PROPERTY(Kernel, kernel, Linear)
226 243 BR_PROPERTY(Type, type, EPS_SVR)
227 244 BR_PROPERTY(QString, inputVariable, "Label")
  245 + BR_PROPERTY(int, termCriteria, 1000)
  246 + BR_PROPERTY(int, folds, 5)
  247 + BR_PROPERTY(bool, balanceFolds, false)
228 248  
229 249 SVM svm;
230 250  
... ... @@ -249,7 +269,7 @@ private:
249 269 deltaData = deltaData.rowRange(0, index);
250 270 deltaLab = deltaLab.rowRange(0, index);
251 271  
252   - trainSVM(svm, deltaData, deltaLab, kernel, type, -1, -1);
  272 + trainSVM(svm, deltaData, deltaLab, kernel, type, -1, -1, folds, balanceFolds, termCriteria);
253 273 }
254 274  
255 275 float compare(const Mat &a, const Mat &b) const
... ...