Commit 9e9e1a9b627e2b045d06a98586887cbdd24f930d

Authored by Scott Klum
1 parent 914912b1

More liblinear cleanup

Showing 1 changed file with 32 additions and 17 deletions
openbr/plugins/liblinear.cpp
@@ -13,7 +13,7 @@ using namespace cv; @@ -13,7 +13,7 @@ using namespace cv;
13 namespace br 13 namespace br
14 { 14 {
15 15
16 -class LinearSVM : public Transform 16 +class Linear : public Transform
17 { 17 {
18 Q_OBJECT 18 Q_OBJECT
19 Q_ENUMS(Solver) 19 Q_ENUMS(Solver)
@@ -22,9 +22,7 @@ class LinearSVM : public Transform @@ -22,9 +22,7 @@ class LinearSVM : public Transform
22 Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) 22 Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
23 Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) 23 Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)
24 Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false) 24 Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false)
25 - Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false)  
26 - Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false)  
27 - Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) 25 + Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false)
28 26
29 public: 27 public:
30 enum Solver { L2R_LR, 28 enum Solver { L2R_LR,
@@ -45,9 +43,7 @@ private: @@ -45,9 +43,7 @@ private:
45 BR_PROPERTY(QString, inputVariable, "Label") 43 BR_PROPERTY(QString, inputVariable, "Label")
46 BR_PROPERTY(QString, outputVariable, "") 44 BR_PROPERTY(QString, outputVariable, "")
47 BR_PROPERTY(bool, returnDFVal, false) 45 BR_PROPERTY(bool, returnDFVal, false)
48 - BR_PROPERTY(int, termCriteria, 1000)  
49 - BR_PROPERTY(int, folds, 5)  
50 - BR_PROPERTY(bool, balanceFolds, false) 46 + BR_PROPERTY(bool, overwriteMat, true)
51 47
52 model *m; 48 model *m;
53 49
@@ -56,9 +52,6 @@ private: @@ -56,9 +52,6 @@ private:
56 Mat samples = OpenCVUtils::toMat(data.data()); 52 Mat samples = OpenCVUtils::toMat(data.data());
57 Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable)); 53 Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable));
58 54
59 - // Number of features = n  
60 - // Number of instances = l  
61 -  
62 problem prob; 55 problem prob;
63 prob.n = samples.cols; 56 prob.n = samples.cols;
64 prob.l = samples.rows; 57 prob.l = samples.rows;
@@ -84,6 +77,7 @@ private: @@ -84,6 +77,7 @@ private:
84 } 77 }
85 78
86 parameter param; 79 parameter param;
  80 +
87 // TODO: Support grid search 81 // TODO: Support grid search
88 param.C = C; 82 param.C = C;
89 param.eps = FLT_EPSILON; 83 param.eps = FLT_EPSILON;
@@ -109,19 +103,40 @@ private: @@ -109,19 +103,40 @@ private:
109 Mat sample = src.m().reshape(1,1); 103 Mat sample = src.m().reshape(1,1);
110 feature_node *x_space = new feature_node[sample.cols+1]; 104 feature_node *x_space = new feature_node[sample.cols+1];
111 105
112 - // Assign the address of the ith instance to be the address of the jth feature  
113 for (int j=0; j<sample.cols; j++) { 106 for (int j=0; j<sample.cols; j++) {
114 x_space[j].index = j+1; 107 x_space[j].index = j+1;
115 x_space[j].value = sample.at<float>(0,j); 108 x_space[j].value = sample.at<float>(0,j);
116 } 109 }
117 x_space[sample.cols].index = -1; 110 x_space[sample.cols].index = -1;
118 111
119 - // TODO: Call appropriate function based on solver  
120 - double prob_estimates[1];  
121 - float prediction = predict_values(m,x_space,prob_estimates); 112 + float prediction;
  113 + double prob_estimates[m->nr_class];
  114 +
  115 + if (solver == L2R_L2LOSS_SVR ||
  116 + solver == L2R_L1LOSS_SVR_DUAL ||
  117 + solver == L2R_L2LOSS_SVR_DUAL ||
  118 + solver == L2R_L2LOSS_SVC_DUAL ||
  119 + solver == L2R_L2LOSS_SVC ||
  120 + solver == L2R_L1LOSS_SVC_DUAL ||
  121 + solver == MCSVM_CS ||
  122 + solver == L1R_L2LOSS_SVC)
  123 + {
  124 + prediction = predict_values(m,x_space,prob_estimates);
  125 + if (returnDFVal) prediction = prob_estimates[0];
  126 + } else if (solver == L2R_LR ||
  127 + solver == L2R_LR_DUAL ||
  128 + solver == L1R_LR)
  129 + {
  130 + prediction = predict_probability(m,x_space,prob_estimates);
  131 + if (returnDFVal) prediction = prob_estimates[0];
  132 + }
122 133
123 - dst.m() = Mat(1, 1, CV_32F);  
124 - dst.m().at<float>(0, 0) = prob_estimates[0]; 134 + if (overwriteMat) {
  135 + dst.m() = Mat(1, 1, CV_32F);
  136 + dst.m().at<float>(0, 0) = prediction;
  137 + } else {
  138 + dst.file.set(outputVariable,prediction);
  139 + }
125 140
126 delete[] x_space; 141 delete[] x_space;
127 } 142 }
@@ -141,7 +156,7 @@ private: @@ -141,7 +156,7 @@ private:
141 } 156 }
142 }; 157 };
143 158
144 -BR_REGISTER(Transform, LinearSVM) 159 +BR_REGISTER(Transform, Linear)
145 160
146 } // namespace br 161 } // namespace br
147 162