Commit 31ac65fc2826a05a49d85d4043b717b694af2db4

Authored by Scott Klum
1 parent d93507d9

Added additional options to SVM

.gitignore
... ... @@ -6,6 +6,7 @@ data/*/img
6 6 data/*/vid
7 7 data/PCSO/*
8 8 data/lfpw
  9 +data/lfw
9 10 build*
10 11 scripts/results
11 12  
... ...
openbr/plugins/stasm4.cpp
... ... @@ -186,7 +186,7 @@ private:
186 186 for (int j = 0; j < 3; j++, cnt++)
187 187 affine(i, j) = paramList[cnt];
188 188 affine(2, 2) = 1;
189   - affine = affine.inverse();
  189 + //affine = affine.inverse();
190 190 Eigen::MatrixXf affineInv = affine.block(0, 0, 2, 3);
191 191 Eigen::MatrixXf pointsT = points.transpose();
192 192 points = affineInv * pointsT;
... ...
openbr/plugins/svm.cpp
... ... @@ -40,6 +40,7 @@ static void storeSVM(const SVM &amp;svm, QDataStream &amp;stream)
40 40 tempFile.open();
41 41 QByteArray data = tempFile.readAll();
42 42 tempFile.close();
  43 + qDebug() << "Storing" << data.size() << "bytes for SVM";
43 44 stream << data;
44 45 }
45 46  
... ... @@ -59,7 +60,7 @@ static void loadSVM(SVM &amp;svm, QDataStream &amp;stream)
59 60 svm.load(qPrintable(tempFile.fileName()));
60 61 }
61 62  
62   -static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma)
  63 +static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma, int folds, bool balanceFolds, int termCriteria)
63 64 {
64 65 if (data.type() != CV_32FC1)
65 66 qFatal("Expected single channel floating point training data.");
... ... @@ -67,11 +68,18 @@ static void trainSVM(SVM &amp;svm, Mat data, Mat lab, int kernel, int type, float C,
67 68 CvSVMParams params;
68 69 params.kernel_type = kernel;
69 70 params.svm_type = type;
70   - params.p = 0.1;
71   - params.nu = 0.5;
  71 + params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON);
  72 +
72 73 if ((C == -1) || ((gamma == -1) && (kernel == CvSVM::RBF))) {
73 74 try {
74   - svm.train_auto(data, lab, Mat(), Mat(), params, 5);
  75 + svm.train_auto(data, lab, Mat(), Mat(), params, folds,
  76 + CvSVM::get_default_grid(CvSVM::C),
  77 + CvSVM::get_default_grid(CvSVM::GAMMA),
  78 + CvSVM::get_default_grid(CvSVM::P),
  79 + CvSVM::get_default_grid(CvSVM::NU),
  80 + CvSVM::get_default_grid(CvSVM::COEF),
  81 + CvSVM::get_default_grid(CvSVM::DEGREE),
  82 + balanceFolds);
75 83 } catch (...) {
76 84 qWarning("Some classes do not contain sufficient examples or are not discriminative enough for accurate SVM classification.");
77 85 svm.train(data, lab, Mat(), Mat(), params);
... ... @@ -104,6 +112,9 @@ class SVMTransform : public Transform
104 112 Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
105 113 Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)
106 114 Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false)
  115 + Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false)
  116 + Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false)
  117 + Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false)
107 118  
108 119 public:
109 120 enum Kernel { Linear = CvSVM::LINEAR,
... ... @@ -125,7 +136,9 @@ private:
125 136 BR_PROPERTY(QString, inputVariable, "Label")
126 137 BR_PROPERTY(QString, outputVariable, "")
127 138 BR_PROPERTY(bool, returnDFVal, false)
128   -
  139 + BR_PROPERTY(int, termCriteria, 1000)
  140 + BR_PROPERTY(int, folds, 5)
  141 + BR_PROPERTY(bool, balanceFolds, false)
129 142  
130 143 SVM svm;
131 144 QHash<QString, int> labelMap;
... ... @@ -146,7 +159,8 @@ private:
146 159 QList<int> dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup);
147 160 lab = OpenCVUtils::toMat(dataLabels);
148 161 }
149   - trainSVM(svm, data, lab, kernel, type, C, gamma);
  162 +
  163 + trainSVM(svm, data, lab, kernel, type, C, gamma, folds, balanceFolds, termCriteria);
150 164 }
151 165  
152 166 void project(const Template &src, Template &dst) const
... ...