Commit 9e9e1a9b627e2b045d06a98586887cbdd24f930d

Authored by Scott Klum
1 parent 914912b1

More liblinear cleanup

Showing 1 changed file with 32 additions and 17 deletions
openbr/plugins/liblinear.cpp
... ... @@ -13,7 +13,7 @@ using namespace cv;
13 13 namespace br
14 14 {
15 15  
16   -class LinearSVM : public Transform
  16 +class Linear : public Transform
17 17 {
18 18 Q_OBJECT
19 19 Q_ENUMS(Solver)
... ... @@ -22,9 +22,7 @@ class LinearSVM : public Transform
22 22 Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
23 23 Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)
24 24 Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false)
25   - Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false)
26   - Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false)
27   - Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false)
  25 + Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false)
28 26  
29 27 public:
30 28 enum Solver { L2R_LR,
... ... @@ -45,9 +43,7 @@ private:
45 43 BR_PROPERTY(QString, inputVariable, "Label")
46 44 BR_PROPERTY(QString, outputVariable, "")
47 45 BR_PROPERTY(bool, returnDFVal, false)
48   - BR_PROPERTY(int, termCriteria, 1000)
49   - BR_PROPERTY(int, folds, 5)
50   - BR_PROPERTY(bool, balanceFolds, false)
  46 + BR_PROPERTY(bool, overwriteMat, true)
51 47  
52 48 model *m;
53 49  
... ... @@ -56,9 +52,6 @@ private:
56 52 Mat samples = OpenCVUtils::toMat(data.data());
57 53 Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable));
58 54  
59   - // Number of features = n
60   - // Number of instances = l
61   -
62 55 problem prob;
63 56 prob.n = samples.cols;
64 57 prob.l = samples.rows;
... ... @@ -84,6 +77,7 @@ private:
84 77 }
85 78  
86 79 parameter param;
  80 +
87 81 // TODO: Support grid search
88 82 param.C = C;
89 83 param.eps = FLT_EPSILON;
... ... @@ -109,19 +103,40 @@ private:
109 103 Mat sample = src.m().reshape(1,1);
110 104 feature_node *x_space = new feature_node[sample.cols+1];
111 105  
112   - // Assign the address of the ith instance to be the address of the jth feature
113 106 for (int j=0; j<sample.cols; j++) {
114 107 x_space[j].index = j+1;
115 108 x_space[j].value = sample.at<float>(0,j);
116 109 }
117 110 x_space[sample.cols].index = -1;
118 111  
119   - // TODO: Call appropriate function based on solver
120   - double prob_estimates[1];
121   - float prediction = predict_values(m,x_space,prob_estimates);
  112 + float prediction;
  113 + double prob_estimates[m->nr_class];
  114 +
  115 + if (solver == L2R_L2LOSS_SVR ||
  116 + solver == L2R_L1LOSS_SVR_DUAL ||
  117 + solver == L2R_L2LOSS_SVR_DUAL ||
  118 + solver == L2R_L2LOSS_SVC_DUAL ||
  119 + solver == L2R_L2LOSS_SVC ||
  120 + solver == L2R_L1LOSS_SVC_DUAL ||
  121 + solver == MCSVM_CS ||
  122 + solver == L1R_L2LOSS_SVC)
  123 + {
  124 + prediction = predict_values(m,x_space,prob_estimates);
  125 + if (returnDFVal) prediction = prob_estimates[0];
  126 + } else if (solver == L2R_LR ||
  127 + solver == L2R_LR_DUAL ||
  128 + solver == L1R_LR)
  129 + {
  130 + prediction = predict_probability(m,x_space,prob_estimates);
  131 + if (returnDFVal) prediction = prob_estimates[0];
  132 + }
122 133  
123   - dst.m() = Mat(1, 1, CV_32F);
124   - dst.m().at<float>(0, 0) = prob_estimates[0];
  134 + if (overwriteMat) {
  135 + dst.m() = Mat(1, 1, CV_32F);
  136 + dst.m().at<float>(0, 0) = prediction;
  137 + } else {
  138 + dst.file.set(outputVariable,prediction);
  139 + }
125 140  
126 141 delete[] x_space;
127 142 }
... ... @@ -141,7 +156,7 @@ private:
141 156 }
142 157 };
143 158  
144   -BR_REGISTER(Transform, LinearSVM)
  159 +BR_REGISTER(Transform, Linear)
145 160  
146 161 } // namespace br
147 162  
... ...