From 1aed6ade0f84f6d7b819289edffa8236295586bd Mon Sep 17 00:00:00 2001 From: Scott Klum Date: Fri, 26 Dec 2014 10:58:39 -0500 Subject: [PATCH] Added additional SVM parameters --- openbr/plugins/svm.cpp | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/openbr/plugins/svm.cpp b/openbr/plugins/svm.cpp index ad4594f..b80901e 100644 --- a/openbr/plugins/svm.cpp +++ b/openbr/plugins/svm.cpp @@ -59,7 +59,7 @@ static void loadSVM(SVM &svm, QDataStream &stream) svm.load(qPrintable(tempFile.fileName())); } -static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma) +static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma, int folds, bool balanceFolds, int termCriteria) { if (data.type() != CV_32FC1) qFatal("Expected single channel floating point training data."); @@ -69,9 +69,18 @@ static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, params.svm_type = type; params.p = 0.1; params.nu = 0.5; + params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON); + if ((C == -1) || ((gamma == -1) && (kernel == CvSVM::RBF))) { try { - svm.train_auto(data, lab, Mat(), Mat(), params, 5); + svm.train_auto(data, lab, Mat(), Mat(), params, folds, + CvSVM::get_default_grid(CvSVM::C), + CvSVM::get_default_grid(CvSVM::GAMMA), + CvSVM::get_default_grid(CvSVM::P), + CvSVM::get_default_grid(CvSVM::NU), + CvSVM::get_default_grid(CvSVM::COEF), + CvSVM::get_default_grid(CvSVM::DEGREE), + balanceFolds); } catch (...) { qWarning("Some classes do not contain sufficient examples or are not discriminative enough for accurate SVM classification."); svm.train(data, lab, Mat(), Mat(), params); @@ -104,6 +113,9 @@ class SVMTransform : public Transform Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false) + Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) + Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) + Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) public: enum Kernel { Linear = CvSVM::LINEAR, @@ -125,7 +137,9 @@ private: BR_PROPERTY(QString, inputVariable, "Label") BR_PROPERTY(QString, outputVariable, "") BR_PROPERTY(bool, returnDFVal, false) - + BR_PROPERTY(int, termCriteria, 1000) + BR_PROPERTY(int, folds, 5) + BR_PROPERTY(bool, balanceFolds, false) SVM svm; QHash labelMap; @@ -146,7 +160,8 @@ private: QList dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup); lab = OpenCVUtils::toMat(dataLabels); } - trainSVM(svm, data, lab, kernel, type, C, gamma); + + trainSVM(svm, data, lab, kernel, type, C, gamma, folds, balanceFolds, termCriteria); } void project(const Template &src, Template &dst) const @@ -207,7 +222,9 @@ class SVMDistance : public Distance Q_PROPERTY(Kernel kernel READ get_kernel WRITE set_kernel RESET reset_kernel STORED false) Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false) Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) - + Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) + Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) + Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) public: enum Kernel { Linear = CvSVM::LINEAR, @@ -225,6 +242,9 @@ private: BR_PROPERTY(Kernel, kernel, Linear) BR_PROPERTY(Type, type, EPS_SVR) BR_PROPERTY(QString, inputVariable, "Label") + BR_PROPERTY(int, termCriteria, 1000) + BR_PROPERTY(int, folds, 5) + BR_PROPERTY(bool, balanceFolds, false) SVM svm; @@ -249,7 +269,7 @@ private: deltaData = deltaData.rowRange(0, index); deltaLab = deltaLab.rowRange(0, index); - trainSVM(svm, deltaData, deltaLab, kernel, type, -1, -1); + trainSVM(svm, deltaData, deltaLab, kernel, type, -1, -1, folds, balanceFolds, termCriteria); } float compare(const Mat &a, const Mat &b) const -- libgit2 0.21.4