diff --git a/openbr/plugins/svm.cpp b/openbr/plugins/svm.cpp index d4e2aad..a30fe9b 100644 --- a/openbr/plugins/svm.cpp +++ b/openbr/plugins/svm.cpp @@ -103,6 +103,7 @@ class SVMTransform : public Transform Q_PROPERTY(float gamma READ get_gamma WRITE set_gamma RESET reset_gamma STORED false) Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) + Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false) public: enum Kernel { Linear = CvSVM::LINEAR, @@ -123,6 +124,7 @@ private: BR_PROPERTY(float, gamma, -1) BR_PROPERTY(QString, inputVariable, "") BR_PROPERTY(QString, outputVariable, "") + BR_PROPERTY(bool, returnDFVal, false) SVM svm; @@ -149,8 +151,17 @@ private: void project(const Template &src, Template &dst) const { + if (returnDFVal && reverseLookup.size() > 2) + qFatal("Decision function for multiclass classification not implemented."); + dst = src; - float prediction = svm.predict(src.m().reshape(1, 1)); + float prediction = svm.predict(src.m().reshape(1, 1), returnDFVal); + if (returnDFVal) { + dst.file.set("Dist", prediction); + // positive values ==> first class + // negative values ==> second class + prediction = prediction > 0 ? 0 : 1; + } if (type == EPS_SVR || type == NU_SVR) dst.file.set(outputVariable, prediction); else