Commit b33957530fa03b491f6b12f41e4c3266731eb05e
1 parent
abba4b7d
Allow user to get distance from margin in Dist metadata
Showing
1 changed file
with
12 additions
and
1 deletions
openbr/plugins/svm.cpp
| @@ -103,6 +103,7 @@ class SVMTransform : public Transform | @@ -103,6 +103,7 @@ class SVMTransform : public Transform | ||
| 103 | Q_PROPERTY(float gamma READ get_gamma WRITE set_gamma RESET reset_gamma STORED false) | 103 | Q_PROPERTY(float gamma READ get_gamma WRITE set_gamma RESET reset_gamma STORED false) |
| 104 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | 104 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) |
| 105 | Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) | 105 | Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) |
| 106 | + Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false) | ||
| 106 | 107 | ||
| 107 | public: | 108 | public: |
| 108 | enum Kernel { Linear = CvSVM::LINEAR, | 109 | enum Kernel { Linear = CvSVM::LINEAR, |
| @@ -123,6 +124,7 @@ private: | @@ -123,6 +124,7 @@ private: | ||
| 123 | BR_PROPERTY(float, gamma, -1) | 124 | BR_PROPERTY(float, gamma, -1) |
| 124 | BR_PROPERTY(QString, inputVariable, "") | 125 | BR_PROPERTY(QString, inputVariable, "") |
| 125 | BR_PROPERTY(QString, outputVariable, "") | 126 | BR_PROPERTY(QString, outputVariable, "") |
| 127 | + BR_PROPERTY(bool, returnDFVal, false) | ||
| 126 | 128 | ||
| 127 | 129 | ||
| 128 | SVM svm; | 130 | SVM svm; |
| @@ -149,8 +151,17 @@ private: | @@ -149,8 +151,17 @@ private: | ||
| 149 | 151 | ||
| 150 | void project(const Template &src, Template &dst) const | 152 | void project(const Template &src, Template &dst) const |
| 151 | { | 153 | { |
| 154 | + if (returnDFVal && reverseLookup.size() > 2) | ||
| 155 | + qFatal("Decision function for multiclass classification not implemented."); | ||
| 156 | + | ||
| 152 | dst = src; | 157 | dst = src; |
| 153 | - float prediction = svm.predict(src.m().reshape(1, 1)); | 158 | + float prediction = svm.predict(src.m().reshape(1, 1), returnDFVal); |
| 159 | + if (returnDFVal) { | ||
| 160 | + dst.file.set("Dist", prediction); | ||
| 161 | + // positive values ==> first class | ||
| 162 | + // negative values ==> second class | ||
| 163 | + prediction = prediction > 0 ? 0 : 1; | ||
| 164 | + } | ||
| 154 | if (type == EPS_SVR || type == NU_SVR) | 165 | if (type == EPS_SVR || type == NU_SVR) |
| 155 | dst.file.set(outputVariable, prediction); | 166 | dst.file.set(outputVariable, prediction); |
| 156 | else | 167 | else |