Commit 31ac65fc2826a05a49d85d4043b717b694af2db4
1 parent
d93507d9
Added additional options to SVM
Showing
3 changed files
with
22 additions
and
7 deletions
.gitignore
openbr/plugins/stasm4.cpp
| @@ -186,7 +186,7 @@ private: | @@ -186,7 +186,7 @@ private: | ||
| 186 | for (int j = 0; j < 3; j++, cnt++) | 186 | for (int j = 0; j < 3; j++, cnt++) |
| 187 | affine(i, j) = paramList[cnt]; | 187 | affine(i, j) = paramList[cnt]; |
| 188 | affine(2, 2) = 1; | 188 | affine(2, 2) = 1; |
| 189 | - affine = affine.inverse(); | 189 | + //affine = affine.inverse(); |
| 190 | Eigen::MatrixXf affineInv = affine.block(0, 0, 2, 3); | 190 | Eigen::MatrixXf affineInv = affine.block(0, 0, 2, 3); |
| 191 | Eigen::MatrixXf pointsT = points.transpose(); | 191 | Eigen::MatrixXf pointsT = points.transpose(); |
| 192 | points = affineInv * pointsT; | 192 | points = affineInv * pointsT; |
openbr/plugins/svm.cpp
| @@ -40,6 +40,7 @@ static void storeSVM(const SVM &svm, QDataStream &stream) | @@ -40,6 +40,7 @@ static void storeSVM(const SVM &svm, QDataStream &stream) | ||
| 40 | tempFile.open(); | 40 | tempFile.open(); |
| 41 | QByteArray data = tempFile.readAll(); | 41 | QByteArray data = tempFile.readAll(); |
| 42 | tempFile.close(); | 42 | tempFile.close(); |
| 43 | + qDebug() << "Storing" << data.size() << "bytes for SVM"; | ||
| 43 | stream << data; | 44 | stream << data; |
| 44 | } | 45 | } |
| 45 | 46 | ||
| @@ -59,7 +60,7 @@ static void loadSVM(SVM &svm, QDataStream &stream) | @@ -59,7 +60,7 @@ static void loadSVM(SVM &svm, QDataStream &stream) | ||
| 59 | svm.load(qPrintable(tempFile.fileName())); | 60 | svm.load(qPrintable(tempFile.fileName())); |
| 60 | } | 61 | } |
| 61 | 62 | ||
| 62 | -static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma) | 63 | +static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma, int folds, bool balanceFolds, int termCriteria) |
| 63 | { | 64 | { |
| 64 | if (data.type() != CV_32FC1) | 65 | if (data.type() != CV_32FC1) |
| 65 | qFatal("Expected single channel floating point training data."); | 66 | qFatal("Expected single channel floating point training data."); |
| @@ -67,11 +68,18 @@ static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, | @@ -67,11 +68,18 @@ static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, | ||
| 67 | CvSVMParams params; | 68 | CvSVMParams params; |
| 68 | params.kernel_type = kernel; | 69 | params.kernel_type = kernel; |
| 69 | params.svm_type = type; | 70 | params.svm_type = type; |
| 70 | - params.p = 0.1; | ||
| 71 | - params.nu = 0.5; | 71 | + params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON); |
| 72 | + | ||
| 72 | if ((C == -1) || ((gamma == -1) && (kernel == CvSVM::RBF))) { | 73 | if ((C == -1) || ((gamma == -1) && (kernel == CvSVM::RBF))) { |
| 73 | try { | 74 | try { |
| 74 | - svm.train_auto(data, lab, Mat(), Mat(), params, 5); | 75 | + svm.train_auto(data, lab, Mat(), Mat(), params, folds, |
| 76 | + CvSVM::get_default_grid(CvSVM::C), | ||
| 77 | + CvSVM::get_default_grid(CvSVM::GAMMA), | ||
| 78 | + CvSVM::get_default_grid(CvSVM::P), | ||
| 79 | + CvSVM::get_default_grid(CvSVM::NU), | ||
| 80 | + CvSVM::get_default_grid(CvSVM::COEF), | ||
| 81 | + CvSVM::get_default_grid(CvSVM::DEGREE), | ||
| 82 | + balanceFolds); | ||
| 75 | } catch (...) { | 83 | } catch (...) { |
| 76 | qWarning("Some classes do not contain sufficient examples or are not discriminative enough for accurate SVM classification."); | 84 | qWarning("Some classes do not contain sufficient examples or are not discriminative enough for accurate SVM classification."); |
| 77 | svm.train(data, lab, Mat(), Mat(), params); | 85 | svm.train(data, lab, Mat(), Mat(), params); |
| @@ -104,6 +112,9 @@ class SVMTransform : public Transform | @@ -104,6 +112,9 @@ class SVMTransform : public Transform | ||
| 104 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | 112 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) |
| 105 | Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) | 113 | Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) |
| 106 | Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false) | 114 | Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false) |
| 115 | + Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) | ||
| 116 | + Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) | ||
| 117 | + Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) | ||
| 107 | 118 | ||
| 108 | public: | 119 | public: |
| 109 | enum Kernel { Linear = CvSVM::LINEAR, | 120 | enum Kernel { Linear = CvSVM::LINEAR, |
| @@ -125,7 +136,9 @@ private: | @@ -125,7 +136,9 @@ private: | ||
| 125 | BR_PROPERTY(QString, inputVariable, "Label") | 136 | BR_PROPERTY(QString, inputVariable, "Label") |
| 126 | BR_PROPERTY(QString, outputVariable, "") | 137 | BR_PROPERTY(QString, outputVariable, "") |
| 127 | BR_PROPERTY(bool, returnDFVal, false) | 138 | BR_PROPERTY(bool, returnDFVal, false) |
| 128 | - | 139 | + BR_PROPERTY(int, termCriteria, 1000) |
| 140 | + BR_PROPERTY(int, folds, 5) | ||
| 141 | + BR_PROPERTY(bool, balanceFolds, false) | ||
| 129 | 142 | ||
| 130 | SVM svm; | 143 | SVM svm; |
| 131 | QHash<QString, int> labelMap; | 144 | QHash<QString, int> labelMap; |
| @@ -146,7 +159,8 @@ private: | @@ -146,7 +159,8 @@ private: | ||
| 146 | QList<int> dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup); | 159 | QList<int> dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup); |
| 147 | lab = OpenCVUtils::toMat(dataLabels); | 160 | lab = OpenCVUtils::toMat(dataLabels); |
| 148 | } | 161 | } |
| 149 | - trainSVM(svm, data, lab, kernel, type, C, gamma); | 162 | + |
| 163 | + trainSVM(svm, data, lab, kernel, type, C, gamma, folds, balanceFolds, termCriteria); | ||
| 150 | } | 164 | } |
| 151 | 165 | ||
| 152 | void project(const Template &src, Template &dst) const | 166 | void project(const Template &src, Template &dst) const |