diff --git a/openbr/plugins/classification/svm.cpp b/openbr/plugins/classification/svm.cpp index c361957..e6f42d1 100644 --- a/openbr/plugins/classification/svm.cpp +++ b/openbr/plugins/classification/svm.cpp @@ -26,12 +26,19 @@ using namespace cv; namespace br { -static void trainSVM(SVM &svm, Mat data, Mat lab, CvSVMParams params, float C, float gamma, int folds, bool balanceFolds) +static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma, int folds, bool balanceFolds, int termCriteria) { if (data.type() != CV_32FC1) qFatal("Expected single channel floating point training data."); - if ((C == -1) || ((gamma == -1) && (params.kernel_type == CvSVM::RBF))) { + CvSVMParams params; + params.kernel_type = kernel; + params.svm_type = type; + params.p = 0.1; + params.nu = 0.5; + params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON); + + if ((C == -1) || ((gamma == -1) && (kernel == CvSVM::RBF))) { try { svm.train_auto(data, lab, Mat(), Mat(), params, folds, CvSVM::get_default_grid(CvSVM::C), @@ -76,7 +83,6 @@ class SVMTransform : public Transform Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) - Q_PROPERTY(int degree READ get_degree WRITE set_degree RESET reset_degree STORED false) public: enum Kernel { Linear = CvSVM::LINEAR, @@ -101,7 +107,6 @@ private: BR_PROPERTY(int, termCriteria, 1000) BR_PROPERTY(int, folds, 5) BR_PROPERTY(bool, balanceFolds, false) - BR_PROPERTY(int, degree, 2) SVM svm; QHash labelMap; @@ -122,16 +127,8 @@ private: QList dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup); lab = OpenCVUtils::toMat(dataLabels); } - - CvSVMParams params; - params.kernel_type = kernel; - params.svm_type = type; - params.p = 0.1; - params.nu = 0.5; - params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON); - params.degree = degree; - - trainSVM(svm, data, lab, params, C, gamma, folds, balanceFolds); + + trainSVM(svm, data, lab, kernel, type, C, gamma, folds, balanceFolds, termCriteria); } void project(const Template &src, Template &dst) const @@ -195,7 +192,6 @@ class SVMDistance : public Distance Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) - Q_PROPERTY(int degree READ get_degree WRITE set_degree RESET reset_degree STORED false) public: enum Kernel { Linear = CvSVM::LINEAR, @@ -216,7 +212,6 @@ private: BR_PROPERTY(int, termCriteria, 1000) BR_PROPERTY(int, folds, 5) BR_PROPERTY(bool, balanceFolds, false) - BR_PROPERTY(int, degree, 2) SVM svm; @@ -240,16 +235,8 @@ private: } deltaData = deltaData.rowRange(0, index); deltaLab = deltaLab.rowRange(0, index); - - CvSVMParams params; - params.kernel_type = kernel; - params.svm_type = type; - params.p = 0.1; - params.nu = 0.5; - params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON); - params.degree = degree; - - trainSVM(svm, deltaData, deltaLab, params, -1, -1, folds, balanceFolds); + + trainSVM(svm, deltaData, deltaLab, kernel, type, -1, -1, folds, balanceFolds, termCriteria); } float compare(const Mat &a, const Mat &b) const