diff --git a/openbr/plugins/classification/svm.cpp b/openbr/plugins/classification/svm.cpp index e6f42d1..c361957 100644 --- a/openbr/plugins/classification/svm.cpp +++ b/openbr/plugins/classification/svm.cpp @@ -26,19 +26,12 @@ using namespace cv; namespace br { -static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma, int folds, bool balanceFolds, int termCriteria) +static void trainSVM(SVM &svm, Mat data, Mat lab, CvSVMParams params, float C, float gamma, int folds, bool balanceFolds) { if (data.type() != CV_32FC1) qFatal("Expected single channel floating point training data."); - CvSVMParams params; - params.kernel_type = kernel; - params.svm_type = type; - params.p = 0.1; - params.nu = 0.5; - params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON); - - if ((C == -1) || ((gamma == -1) && (kernel == CvSVM::RBF))) { + if ((C == -1) || ((gamma == -1) && (params.kernel_type == CvSVM::RBF))) { try { svm.train_auto(data, lab, Mat(), Mat(), params, folds, CvSVM::get_default_grid(CvSVM::C), @@ -83,6 +76,7 @@ class SVMTransform : public Transform Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) + Q_PROPERTY(int degree READ get_degree WRITE set_degree RESET reset_degree STORED false) public: enum Kernel { Linear = CvSVM::LINEAR, @@ -107,6 +101,7 @@ private: BR_PROPERTY(int, termCriteria, 1000) BR_PROPERTY(int, folds, 5) BR_PROPERTY(bool, balanceFolds, false) + BR_PROPERTY(int, degree, 2) SVM svm; QHash labelMap; @@ -127,8 +122,16 @@ private: QList dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup); lab = OpenCVUtils::toMat(dataLabels); } - - trainSVM(svm, data, lab, kernel, type, C, gamma, folds, balanceFolds, termCriteria); + + CvSVMParams params; + params.kernel_type = kernel; + params.svm_type = type; + params.p = 0.1; + params.nu = 0.5; + params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON); + params.degree = degree; + + trainSVM(svm, data, lab, params, C, gamma, folds, balanceFolds); } void project(const Template &src, Template &dst) const @@ -192,6 +195,7 @@ class SVMDistance : public Distance Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) + Q_PROPERTY(int degree READ get_degree WRITE set_degree RESET reset_degree STORED false) public: enum Kernel { Linear = CvSVM::LINEAR, @@ -212,6 +216,7 @@ private: BR_PROPERTY(int, termCriteria, 1000) BR_PROPERTY(int, folds, 5) BR_PROPERTY(bool, balanceFolds, false) + BR_PROPERTY(int, degree, 2) SVM svm; @@ -235,8 +240,16 @@ private: } deltaData = deltaData.rowRange(0, index); deltaLab = deltaLab.rowRange(0, index); - - trainSVM(svm, deltaData, deltaLab, kernel, type, -1, -1, folds, balanceFolds, termCriteria); + + CvSVMParams params; + params.kernel_type = kernel; + params.svm_type = type; + params.p = 0.1; + params.nu = 0.5; + params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON); + params.degree = degree; + + trainSVM(svm, deltaData, deltaLab, params, -1, -1, folds, balanceFolds); } float compare(const Mat &a, const Mat &b) const