From 9e9e1a9b627e2b045d06a98586887cbdd24f930d Mon Sep 17 00:00:00 2001 From: Scott Klum Date: Fri, 30 Jan 2015 12:51:23 -0500 Subject: [PATCH] More liblinear cleanup --- openbr/plugins/liblinear.cpp | 49 ++++++++++++++++++++++++++++++++----------------- 1 file changed, 32 insertions(+), 17 deletions(-) diff --git a/openbr/plugins/liblinear.cpp b/openbr/plugins/liblinear.cpp index 7094b9a..7de2485 100644 --- a/openbr/plugins/liblinear.cpp +++ b/openbr/plugins/liblinear.cpp @@ -13,7 +13,7 @@ using namespace cv; namespace br { -class LinearSVM : public Transform +class Linear : public Transform { Q_OBJECT Q_ENUMS(Solver) @@ -22,9 +22,7 @@ class LinearSVM : public Transform Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false) - Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) - Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) - Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) + Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false) public: enum Solver { L2R_LR, @@ -45,9 +43,7 @@ private: BR_PROPERTY(QString, inputVariable, "Label") BR_PROPERTY(QString, outputVariable, "") BR_PROPERTY(bool, returnDFVal, false) - BR_PROPERTY(int, termCriteria, 1000) - BR_PROPERTY(int, folds, 5) - BR_PROPERTY(bool, balanceFolds, false) + BR_PROPERTY(bool, overwriteMat, true) model *m; @@ -56,9 +52,6 @@ private: Mat samples = OpenCVUtils::toMat(data.data()); Mat labels = OpenCVUtils::toMat(File::get(data, inputVariable)); - // Number of features = n - // Number of instances = l - problem prob; prob.n = samples.cols; prob.l = samples.rows; @@ -84,6 +77,7 @@ private: } parameter param; + // TODO: Support grid search param.C = C; param.eps = FLT_EPSILON; @@ -109,19 +103,40 @@ private: Mat sample = src.m().reshape(1,1); feature_node *x_space = new feature_node[sample.cols+1]; - // Assign the address of the ith instance to be the address of the jth feature for (int j=0; j(0,j); } x_space[sample.cols].index = -1; - // TODO: Call appropriate function based on solver - double prob_estimates[1]; - float prediction = predict_values(m,x_space,prob_estimates); + float prediction; + double prob_estimates[m->nr_class]; + + if (solver == L2R_L2LOSS_SVR || + solver == L2R_L1LOSS_SVR_DUAL || + solver == L2R_L2LOSS_SVR_DUAL || + solver == L2R_L2LOSS_SVC_DUAL || + solver == L2R_L2LOSS_SVC || + solver == L2R_L1LOSS_SVC_DUAL || + solver == MCSVM_CS || + solver == L1R_L2LOSS_SVC) + { + prediction = predict_values(m,x_space,prob_estimates); + if (returnDFVal) prediction = prob_estimates[0]; + } else if (solver == L2R_LR || + solver == L2R_LR_DUAL || + solver == L1R_LR) + { + prediction = predict_probability(m,x_space,prob_estimates); + if (returnDFVal) prediction = prob_estimates[0]; + } - dst.m() = Mat(1, 1, CV_32F); - dst.m().at(0, 0) = prob_estimates[0]; + if (overwriteMat) { + dst.m() = Mat(1, 1, CV_32F); + dst.m().at(0, 0) = prediction; + } else { + dst.file.set(outputVariable,prediction); + } delete[] x_space; } @@ -141,7 +156,7 @@ private: } }; -BR_REGISTER(Transform, LinearSVM) +BR_REGISTER(Transform, Linear) } // namespace br -- libgit2 0.21.4