diff --git a/openbr/plugins/svm.cpp b/openbr/plugins/svm.cpp index f37eb88..ca01cdd 100644 --- a/openbr/plugins/svm.cpp +++ b/openbr/plugins/svm.cpp @@ -26,10 +26,8 @@ using namespace cv; namespace br { -static void storeSVM(float a, float b, const SVM &svm, QDataStream &stream) +static void storeSVM(const SVM &svm, QDataStream &stream) { - stream << a << b; - // Create local file QTemporaryFile tempFile; tempFile.open(); @@ -45,10 +43,8 @@ static void storeSVM(float a, float b, const SVM &svm, QDataStream &stream) stream << data; } -static void loadSVM(float &a, float &b, SVM &svm, QDataStream &stream) +static void loadSVM(SVM &svm, QDataStream &stream) { - stream >> a >> b; - // Copy local file contents from stream QByteArray data; stream >> data; @@ -63,22 +59,8 @@ static void loadSVM(float &a, float &b, SVM &svm, QDataStream &stream) svm.load(qPrintable(tempFile.fileName())); } -static void trainSVM(float &a, float &b, SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma) +static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma) { - if ((type == CvSVM::EPS_SVR) || (type == CvSVM::NU_SVR)) { - // Scale labels to [-1,1] - double min, max; - minMaxLoc(lab, &min, &max); - if (max > min) { - a = 2.0/(max-min); - b = -(min*a+1); - lab = (lab * a) + b; - } - } else { - a = 1; - b = 0; - } - if (data.type() != CV_32FC1) qFatal("Expected single channel floating point training data."); @@ -139,29 +121,28 @@ private: BR_PROPERTY(float, gamma, -1) SVM svm; - float a, b; void train(const TemplateList &_data) { Mat data = OpenCVUtils::toMat(_data.data()); Mat lab = OpenCVUtils::toMat(_data.labels()); - trainSVM(a, b, svm, data, lab, kernel, type, C, gamma); + trainSVM(svm, data, lab, kernel, type, C, gamma); } void project(const Template &src, Template &dst) const { dst = src; - dst.file.set("Label", ((svm.predict(src.m().reshape(1, 1)) - b)/a)); + dst.file.set("Label", svm.predict(src.m().reshape(1, 1))); } void store(QDataStream &stream) const { - storeSVM(a, b, svm, stream); + storeSVM(svm, stream); } void load(QDataStream &stream) { - loadSVM(a, b, svm, stream); + loadSVM(svm, stream); } }; @@ -197,7 +178,6 @@ private: BR_PROPERTY(Type, type, EPS_SVR) SVM svm; - float a, b; void train(const TemplateList &src) { @@ -220,24 +200,24 @@ private: deltaData = deltaData.rowRange(0, index); deltaLab = deltaLab.rowRange(0, index); - trainSVM(a, b, svm, deltaData, deltaLab, kernel, type, -1, -1); + trainSVM(svm, deltaData, deltaLab, kernel, type, -1, -1); } float compare(const Template &ta, const Template &tb) const { Mat delta; absdiff(ta, tb, delta); - return (svm.predict(delta.reshape(1, 1)) - b)/a; + return svm.predict(delta.reshape(1, 1)); } void store(QDataStream &stream) const { - storeSVM(a, b, svm, stream); + storeSVM(svm, stream); } void load(QDataStream &stream) { - loadSVM(a, b, svm, stream); + loadSVM(svm, stream); } }; diff --git a/share/openbr/models b/share/openbr/models index 1394fbe..cddd8fc 160000 --- a/share/openbr/models +++ b/share/openbr/models @@ -1 +1 @@ -Subproject commit 1394fbec42897c1a53373443c32a6f62a010294f +Subproject commit cddd8fc2231191ab5d62dc29573f73e97d5065c5