Commit 60f9377867e8765b358bbd1288ed8bca2d6ad679
1 parent
31ac65fc
Revert "Added additional options to SVM"
This reverts commit 31ac65fc2826a05a49d85d4043b717b694af2db4.
Showing
3 changed files
with
7 additions
and
22 deletions
.gitignore
openbr/plugins/stasm4.cpp
| @@ -186,7 +186,7 @@ private: | @@ -186,7 +186,7 @@ private: | ||
| 186 | for (int j = 0; j < 3; j++, cnt++) | 186 | for (int j = 0; j < 3; j++, cnt++) |
| 187 | affine(i, j) = paramList[cnt]; | 187 | affine(i, j) = paramList[cnt]; |
| 188 | affine(2, 2) = 1; | 188 | affine(2, 2) = 1; |
| 189 | - //affine = affine.inverse(); | 189 | + affine = affine.inverse(); |
| 190 | Eigen::MatrixXf affineInv = affine.block(0, 0, 2, 3); | 190 | Eigen::MatrixXf affineInv = affine.block(0, 0, 2, 3); |
| 191 | Eigen::MatrixXf pointsT = points.transpose(); | 191 | Eigen::MatrixXf pointsT = points.transpose(); |
| 192 | points = affineInv * pointsT; | 192 | points = affineInv * pointsT; |
openbr/plugins/svm.cpp
| @@ -40,7 +40,6 @@ static void storeSVM(const SVM &svm, QDataStream &stream) | @@ -40,7 +40,6 @@ static void storeSVM(const SVM &svm, QDataStream &stream) | ||
| 40 | tempFile.open(); | 40 | tempFile.open(); |
| 41 | QByteArray data = tempFile.readAll(); | 41 | QByteArray data = tempFile.readAll(); |
| 42 | tempFile.close(); | 42 | tempFile.close(); |
| 43 | - qDebug() << "Storing" << data.size() << "bytes for SVM"; | ||
| 44 | stream << data; | 43 | stream << data; |
| 45 | } | 44 | } |
| 46 | 45 | ||
| @@ -60,7 +59,7 @@ static void loadSVM(SVM &svm, QDataStream &stream) | @@ -60,7 +59,7 @@ static void loadSVM(SVM &svm, QDataStream &stream) | ||
| 60 | svm.load(qPrintable(tempFile.fileName())); | 59 | svm.load(qPrintable(tempFile.fileName())); |
| 61 | } | 60 | } |
| 62 | 61 | ||
| 63 | -static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma, int folds, bool balanceFolds, int termCriteria) | 62 | +static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma) |
| 64 | { | 63 | { |
| 65 | if (data.type() != CV_32FC1) | 64 | if (data.type() != CV_32FC1) |
| 66 | qFatal("Expected single channel floating point training data."); | 65 | qFatal("Expected single channel floating point training data."); |
| @@ -68,18 +67,11 @@ static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, | @@ -68,18 +67,11 @@ static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, | ||
| 68 | CvSVMParams params; | 67 | CvSVMParams params; |
| 69 | params.kernel_type = kernel; | 68 | params.kernel_type = kernel; |
| 70 | params.svm_type = type; | 69 | params.svm_type = type; |
| 71 | - params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON); | ||
| 72 | - | 70 | + params.p = 0.1; |
| 71 | + params.nu = 0.5; | ||
| 73 | if ((C == -1) || ((gamma == -1) && (kernel == CvSVM::RBF))) { | 72 | if ((C == -1) || ((gamma == -1) && (kernel == CvSVM::RBF))) { |
| 74 | try { | 73 | try { |
| 75 | - svm.train_auto(data, lab, Mat(), Mat(), params, folds, | ||
| 76 | - CvSVM::get_default_grid(CvSVM::C), | ||
| 77 | - CvSVM::get_default_grid(CvSVM::GAMMA), | ||
| 78 | - CvSVM::get_default_grid(CvSVM::P), | ||
| 79 | - CvSVM::get_default_grid(CvSVM::NU), | ||
| 80 | - CvSVM::get_default_grid(CvSVM::COEF), | ||
| 81 | - CvSVM::get_default_grid(CvSVM::DEGREE), | ||
| 82 | - balanceFolds); | 74 | + svm.train_auto(data, lab, Mat(), Mat(), params, 5); |
| 83 | } catch (...) { | 75 | } catch (...) { |
| 84 | qWarning("Some classes do not contain sufficient examples or are not discriminative enough for accurate SVM classification."); | 76 | qWarning("Some classes do not contain sufficient examples or are not discriminative enough for accurate SVM classification."); |
| 85 | svm.train(data, lab, Mat(), Mat(), params); | 77 | svm.train(data, lab, Mat(), Mat(), params); |
| @@ -112,9 +104,6 @@ class SVMTransform : public Transform | @@ -112,9 +104,6 @@ class SVMTransform : public Transform | ||
| 112 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | 104 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) |
| 113 | Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) | 105 | Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) |
| 114 | Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false) | 106 | Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false) |
| 115 | - Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) | ||
| 116 | - Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) | ||
| 117 | - Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) | ||
| 118 | 107 | ||
| 119 | public: | 108 | public: |
| 120 | enum Kernel { Linear = CvSVM::LINEAR, | 109 | enum Kernel { Linear = CvSVM::LINEAR, |
| @@ -136,9 +125,7 @@ private: | @@ -136,9 +125,7 @@ private: | ||
| 136 | BR_PROPERTY(QString, inputVariable, "Label") | 125 | BR_PROPERTY(QString, inputVariable, "Label") |
| 137 | BR_PROPERTY(QString, outputVariable, "") | 126 | BR_PROPERTY(QString, outputVariable, "") |
| 138 | BR_PROPERTY(bool, returnDFVal, false) | 127 | BR_PROPERTY(bool, returnDFVal, false) |
| 139 | - BR_PROPERTY(int, termCriteria, 1000) | ||
| 140 | - BR_PROPERTY(int, folds, 5) | ||
| 141 | - BR_PROPERTY(bool, balanceFolds, false) | 128 | + |
| 142 | 129 | ||
| 143 | SVM svm; | 130 | SVM svm; |
| 144 | QHash<QString, int> labelMap; | 131 | QHash<QString, int> labelMap; |
| @@ -159,8 +146,7 @@ private: | @@ -159,8 +146,7 @@ private: | ||
| 159 | QList<int> dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup); | 146 | QList<int> dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup); |
| 160 | lab = OpenCVUtils::toMat(dataLabels); | 147 | lab = OpenCVUtils::toMat(dataLabels); |
| 161 | } | 148 | } |
| 162 | - | ||
| 163 | - trainSVM(svm, data, lab, kernel, type, C, gamma, folds, balanceFolds, termCriteria); | 149 | + trainSVM(svm, data, lab, kernel, type, C, gamma); |
| 164 | } | 150 | } |
| 165 | 151 | ||
| 166 | void project(const Template &src, Template &dst) const | 152 | void project(const Template &src, Template &dst) const |