Commit 9e9e1a9b627e2b045d06a98586887cbdd24f930d
1 parent
914912b1
More liblinear cleanup
Showing
1 changed file
with
32 additions
and
17 deletions
openbr/plugins/liblinear.cpp
| ... | ... | @@ -13,7 +13,7 @@ using namespace cv; |
| 13 | 13 | namespace br |
| 14 | 14 | { |
| 15 | 15 | |
| 16 | -class LinearSVM : public Transform | |
| 16 | +class Linear : public Transform | |
| 17 | 17 | { |
| 18 | 18 | Q_OBJECT |
| 19 | 19 | Q_ENUMS(Solver) |
| ... | ... | @@ -22,9 +22,7 @@ class LinearSVM : public Transform |
| 22 | 22 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) |
| 23 | 23 | Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) |
| 24 | 24 | Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false) |
| 25 | - Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) | |
| 26 | - Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) | |
| 27 | - Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) | |
| 25 | + Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false) | |
| 28 | 26 | |
| 29 | 27 | public: |
| 30 | 28 | enum Solver { L2R_LR, |
| ... | ... | @@ -45,9 +43,7 @@ private: |
| 45 | 43 | BR_PROPERTY(QString, inputVariable, "Label") |
| 46 | 44 | BR_PROPERTY(QString, outputVariable, "") |
| 47 | 45 | BR_PROPERTY(bool, returnDFVal, false) |
| 48 | - BR_PROPERTY(int, termCriteria, 1000) | |
| 49 | - BR_PROPERTY(int, folds, 5) | |
| 50 | - BR_PROPERTY(bool, balanceFolds, false) | |
| 46 | + BR_PROPERTY(bool, overwriteMat, true) | |
| 51 | 47 | |
| 52 | 48 | model *m; |
| 53 | 49 | |
| ... | ... | @@ -56,9 +52,6 @@ private: |
| 56 | 52 | Mat samples = OpenCVUtils::toMat(data.data()); |
| 57 | 53 | Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable)); |
| 58 | 54 | |
| 59 | - // Number of features = n | |
| 60 | - // Number of instances = l | |
| 61 | - | |
| 62 | 55 | problem prob; |
| 63 | 56 | prob.n = samples.cols; |
| 64 | 57 | prob.l = samples.rows; |
| ... | ... | @@ -84,6 +77,7 @@ private: |
| 84 | 77 | } |
| 85 | 78 | |
| 86 | 79 | parameter param; |
| 80 | + | |
| 87 | 81 | // TODO: Support grid search |
| 88 | 82 | param.C = C; |
| 89 | 83 | param.eps = FLT_EPSILON; |
| ... | ... | @@ -109,19 +103,40 @@ private: |
| 109 | 103 | Mat sample = src.m().reshape(1,1); |
| 110 | 104 | feature_node *x_space = new feature_node[sample.cols+1]; |
| 111 | 105 | |
| 112 | - // Assign the address of the ith instance to be the address of the jth feature | |
| 113 | 106 | for (int j=0; j<sample.cols; j++) { |
| 114 | 107 | x_space[j].index = j+1; |
| 115 | 108 | x_space[j].value = sample.at<float>(0,j); |
| 116 | 109 | } |
| 117 | 110 | x_space[sample.cols].index = -1; |
| 118 | 111 | |
| 119 | - // TODO: Call appropriate function based on solver | |
| 120 | - double prob_estimates[1]; | |
| 121 | - float prediction = predict_values(m,x_space,prob_estimates); | |
| 112 | + float prediction; | |
| 113 | + double prob_estimates[m->nr_class]; | |
| 114 | + | |
| 115 | + if (solver == L2R_L2LOSS_SVR || | |
| 116 | + solver == L2R_L1LOSS_SVR_DUAL || | |
| 117 | + solver == L2R_L2LOSS_SVR_DUAL || | |
| 118 | + solver == L2R_L2LOSS_SVC_DUAL || | |
| 119 | + solver == L2R_L2LOSS_SVC || | |
| 120 | + solver == L2R_L1LOSS_SVC_DUAL || | |
| 121 | + solver == MCSVM_CS || | |
| 122 | + solver == L1R_L2LOSS_SVC) | |
| 123 | + { | |
| 124 | + prediction = predict_values(m,x_space,prob_estimates); | |
| 125 | + if (returnDFVal) prediction = prob_estimates[0]; | |
| 126 | + } else if (solver == L2R_LR || | |
| 127 | + solver == L2R_LR_DUAL || | |
| 128 | + solver == L1R_LR) | |
| 129 | + { | |
| 130 | + prediction = predict_probability(m,x_space,prob_estimates); | |
| 131 | + if (returnDFVal) prediction = prob_estimates[0]; | |
| 132 | + } | |
| 122 | 133 | |
| 123 | - dst.m() = Mat(1, 1, CV_32F); | |
| 124 | - dst.m().at<float>(0, 0) = prob_estimates[0]; | |
| 134 | + if (overwriteMat) { | |
| 135 | + dst.m() = Mat(1, 1, CV_32F); | |
| 136 | + dst.m().at<float>(0, 0) = prediction; | |
| 137 | + } else { | |
| 138 | + dst.file.set(outputVariable,prediction); | |
| 139 | + } | |
| 125 | 140 | |
| 126 | 141 | delete[] x_space; |
| 127 | 142 | } |
| ... | ... | @@ -141,7 +156,7 @@ private: |
| 141 | 156 | } |
| 142 | 157 | }; |
| 143 | 158 | |
| 144 | -BR_REGISTER(Transform, LinearSVM) | |
| 159 | +BR_REGISTER(Transform, Linear) | |
| 145 | 160 | |
| 146 | 161 | } // namespace br |
| 147 | 162 | ... | ... |