Commit b33957530fa03b491f6b12f41e4c3266731eb05e

Authored by Austin Blanton
1 parent abba4b7d

Allow user to get distance from margin in Dist metadata

Showing 1 changed file with 12 additions and 1 deletions
openbr/plugins/svm.cpp
... ... @@ -103,6 +103,7 @@ class SVMTransform : public Transform
103 103 Q_PROPERTY(float gamma READ get_gamma WRITE set_gamma RESET reset_gamma STORED false)
104 104 Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
105 105 Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)
  106 + Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false)
106 107  
107 108 public:
108 109 enum Kernel { Linear = CvSVM::LINEAR,
... ... @@ -123,6 +124,7 @@ private:
123 124 BR_PROPERTY(float, gamma, -1)
124 125 BR_PROPERTY(QString, inputVariable, "")
125 126 BR_PROPERTY(QString, outputVariable, "")
  127 + BR_PROPERTY(bool, returnDFVal, false)
126 128  
127 129  
128 130 SVM svm;
... ... @@ -149,8 +151,17 @@ private:
149 151  
150 152 void project(const Template &src, Template &dst) const
151 153 {
  154 + if (returnDFVal && reverseLookup.size() > 2)
  155 + qFatal("Decision function for multiclass classification not implemented.");
  156 +
152 157 dst = src;
153   - float prediction = svm.predict(src.m().reshape(1, 1));
  158 + float prediction = svm.predict(src.m().reshape(1, 1), returnDFVal);
  159 + if (returnDFVal) {
  160 + dst.file.set("Dist", prediction);
  161 + // positive values ==> first class
  162 + // negative values ==> second class
  163 + prediction = prediction > 0 ? 0 : 1;
  164 + }
154 165 if (type == EPS_SVR || type == NU_SVR)
155 166 dst.file.set(outputVariable, prediction);
156 167 else
... ...