diff --git a/openbr/plugins/liblinear.cpp b/openbr/plugins/liblinear.cpp index 4ecf72c..7094b9a 100644 --- a/openbr/plugins/liblinear.cpp +++ b/openbr/plugins/liblinear.cpp @@ -16,12 +16,9 @@ namespace br class LinearSVM : public Transform { Q_OBJECT - Q_ENUMS(Kernel) - Q_ENUMS(Type) - Q_PROPERTY(Kernel kernel READ get_kernel WRITE set_kernel RESET reset_kernel STORED false) - Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false) + Q_ENUMS(Solver) + Q_PROPERTY(Solver solver READ get_solver WRITE set_solver RESET reset_solver STORED false) Q_PROPERTY(float C READ get_C WRITE set_C RESET reset_C STORED false) - Q_PROPERTY(float gamma READ get_gamma WRITE set_gamma RESET reset_gamma STORED false) Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false) @@ -30,22 +27,21 @@ class LinearSVM : public Transform Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) public: - enum Kernel { Linear = CvSVM::LINEAR, - Poly = CvSVM::POLY, - RBF = CvSVM::RBF, - Sigmoid = CvSVM::SIGMOID }; - - enum Type { C_SVC = CvSVM::C_SVC, - NU_SVC = CvSVM::NU_SVC, - ONE_CLASS = CvSVM::ONE_CLASS, - EPS_SVR = CvSVM::EPS_SVR, - NU_SVR = CvSVM::NU_SVR}; + enum Solver { L2R_LR, + L2R_L2LOSS_SVC_DUAL, + L2R_L2LOSS_SVC, + L2R_L1LOSS_SVC_DUAL, + MCSVM_CS, + L1R_L2LOSS_SVC, + L1R_LR, + L2R_LR_DUAL, + L2R_L2LOSS_SVR, + L2R_L2LOSS_SVR_DUAL, + L2R_L1LOSS_SVR_DUAL }; private: - BR_PROPERTY(Kernel, kernel, Linear) - BR_PROPERTY(Type, type, C_SVC) - BR_PROPERTY(float, C, -1) - BR_PROPERTY(float, gamma, -1) + BR_PROPERTY(Solver, solver, L2R_L2LOSS_SVC_DUAL) + BR_PROPERTY(float, C, 1) BR_PROPERTY(QString, inputVariable, "Label") BR_PROPERTY(QString, outputVariable, "") BR_PROPERTY(bool, returnDFVal, false) @@ -88,9 +84,12 @@ private: } parameter param; - param.C = 1; + // TODO: Support grid search + param.C = C; param.eps = FLT_EPSILON; - param.solver_type = L2R_L2LOSS_SVC_DUAL; + param.solver_type = solver; + + // TODO: Support weights param.nr_weight = 0; param.p = 1; param.weight_label = NULL; @@ -98,9 +97,9 @@ private: m = train_svm(&prob, ¶m); - delete x_space; - delete prob.x; - delete prob.y; + delete[] prob.y; + delete[] prob.x; + delete[] x_space; } void project(const Template &src, Template &dst) const @@ -108,7 +107,7 @@ private: dst = src; Mat sample = src.m().reshape(1,1); - feature_node *x_space = new feature_node[sample.cols]; + feature_node *x_space = new feature_node[sample.cols+1]; // Assign the address of the ith instance to be the address of the jth feature for (int j=0; j(0, 0) = prob_estimates[0]; + + delete[] x_space; } void store(QDataStream &stream) const