Commit 4691ca211f37134b7ba7a111a466e5f36ef1a4fb

Authored by Scott Klum
1 parent b0f41b6e

Liblinear cleanup

Showing 1 changed file with 71 additions and 27 deletions
openbr/plugins/liblinear.cpp
@@ -13,6 +13,39 @@ using namespace cv; @@ -13,6 +13,39 @@ using namespace cv;
13 namespace br 13 namespace br
14 { 14 {
15 15
  16 +static void storeModel(const model &m, QDataStream &stream)
  17 +{
  18 + // Create local file
  19 + QTemporaryFile tempFile;
  20 + tempFile.open();
  21 + tempFile.close();
  22 +
  23 + // Save MLP to local file
  24 + save_model(qPrintable(tempFile.fileName()),&m);
  25 +
  26 + // Copy local file contents to stream
  27 + tempFile.open();
  28 + QByteArray data = tempFile.readAll();
  29 + tempFile.close();
  30 + stream << data;
  31 +}
  32 +
  33 +static void loadModel(model &m, QDataStream &stream)
  34 +{
  35 + // Copy local file contents from stream
  36 + QByteArray data;
  37 + stream >> data;
  38 +
  39 + // Create local file
  40 + QTemporaryFile tempFile(QDir::tempPath()+"/model");
  41 + tempFile.open();
  42 + tempFile.write(data);
  43 + tempFile.close();
  44 +
  45 + // Load MLP from local file
  46 + m = *load_model(qPrintable(tempFile.fileName()));
  47 +}
  48 +
16 class Linear : public Transform 49 class Linear : public Transform
17 { 50 {
18 Q_OBJECT 51 Q_OBJECT
@@ -23,19 +56,20 @@ class Linear : public Transform @@ -23,19 +56,20 @@ class Linear : public Transform
23 Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) 56 Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)
24 Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false) 57 Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false)
25 Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false) 58 Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false)
  59 + Q_PROPERTY(bool weight READ get_weight WRITE set_weight RESET reset_weight STORED false)
26 60
27 public: 61 public:
28 - enum Solver { L2R_LR,  
29 - L2R_L2LOSS_SVC_DUAL,  
30 - L2R_L2LOSS_SVC,  
31 - L2R_L1LOSS_SVC_DUAL,  
32 - MCSVM_CS,  
33 - L1R_L2LOSS_SVC,  
34 - L1R_LR,  
35 - L2R_LR_DUAL,  
36 - L2R_L2LOSS_SVR,  
37 - L2R_L2LOSS_SVR_DUAL,  
38 - L2R_L1LOSS_SVR_DUAL }; 62 + enum Solver { L2R_LR = ::L2R_LR,
  63 + L2R_L2LOSS_SVC_DUAL = ::L2R_L2LOSS_SVC_DUAL,
  64 + L2R_L2LOSS_SVC = ::L2R_L2LOSS_SVC,
  65 + L2R_L1LOSS_SVC_DUAL = ::L2R_L1LOSS_SVC_DUAL,
  66 + MCSVM_CS = ::MCSVM_CS,
  67 + L1R_L2LOSS_SVC = ::L1R_L2LOSS_SVC,
  68 + L1R_LR = ::L1R_LR,
  69 + L2R_LR_DUAL = ::L2R_LR_DUAL,
  70 + L2R_L2LOSS_SVR = ::L2R_L2LOSS_SVR,
  71 + L2R_L2LOSS_SVR_DUAL = ::L2R_L2LOSS_SVR_DUAL,
  72 + L2R_L1LOSS_SVR_DUAL = ::L2R_L1LOSS_SVR_DUAL };
39 73
40 private: 74 private:
41 BR_PROPERTY(Solver, solver, L2R_L2LOSS_SVC_DUAL) 75 BR_PROPERTY(Solver, solver, L2R_L2LOSS_SVC_DUAL)
@@ -44,8 +78,9 @@ private: @@ -44,8 +78,9 @@ private:
44 BR_PROPERTY(QString, outputVariable, "") 78 BR_PROPERTY(QString, outputVariable, "")
45 BR_PROPERTY(bool, returnDFVal, false) 79 BR_PROPERTY(bool, returnDFVal, false)
46 BR_PROPERTY(bool, overwriteMat, true) 80 BR_PROPERTY(bool, overwriteMat, true)
  81 + BR_PROPERTY(bool, weight, false)
47 82
48 - model *m; 83 + model m;
49 84
50 void train(const TemplateList &data) 85 void train(const TemplateList &data)
51 { 86 {
@@ -80,17 +115,30 @@ private: @@ -80,17 +115,30 @@ private:
80 115
81 // TODO: Support grid search 116 // TODO: Support grid search
82 param.C = C; 117 param.C = C;
  118 + param.p = 1;
83 param.eps = FLT_EPSILON; 119 param.eps = FLT_EPSILON;
84 param.solver_type = solver; 120 param.solver_type = solver;
85 121
86 - // TODO: Support weights  
87 - param.nr_weight = 0;  
88 - param.p = 1;  
89 - param.weight_label = NULL;  
90 - param.weight = NULL; 122 + if (weight) {
  123 + param.nr_weight = 2;
  124 + param.weight_label = new int[2];
  125 + param.weight = new double[2];
  126 + param.weight_label[0] = 0;
  127 + param.weight_label[1] = 1;
  128 + int nonZero = countNonZero(labels);
  129 + param.weight[0] = 1;
  130 + param.weight[1] = (double)(prob.l-nonZero)/nonZero;
  131 + qDebug() << param.weight[0] << param.weight[1];
  132 + } else {
  133 + param.nr_weight = 0;
  134 + param.weight_label = NULL;
  135 + param.weight = NULL;
  136 + }
91 137
92 - m = train_svm(&prob, &param); 138 + m = *train_svm(&prob, &param);
93 139
  140 + delete[] param.weight;
  141 + delete[] param.weight_label;
94 delete[] prob.y; 142 delete[] prob.y;
95 delete[] prob.x; 143 delete[] prob.x;
96 delete[] x_space; 144 delete[] x_space;
@@ -110,7 +158,7 @@ private: @@ -110,7 +158,7 @@ private:
110 x_space[sample.cols].index = -1; 158 x_space[sample.cols].index = -1;
111 159
112 float prediction; 160 float prediction;
113 - double prob_estimates[m->nr_class]; 161 + double prob_estimates[m.nr_class];
114 162
115 if (solver == L2R_L2LOSS_SVR || 163 if (solver == L2R_L2LOSS_SVR ||
116 solver == L2R_L1LOSS_SVR_DUAL || 164 solver == L2R_L1LOSS_SVR_DUAL ||
@@ -121,13 +169,13 @@ private: @@ -121,13 +169,13 @@ private:
121 solver == MCSVM_CS || 169 solver == MCSVM_CS ||
122 solver == L1R_L2LOSS_SVC) 170 solver == L1R_L2LOSS_SVC)
123 { 171 {
124 - prediction = predict_values(m,x_space,prob_estimates); 172 + prediction = predict_values(&m,x_space,prob_estimates);
125 if (returnDFVal) prediction = prob_estimates[0]; 173 if (returnDFVal) prediction = prob_estimates[0];
126 } else if (solver == L2R_LR || 174 } else if (solver == L2R_LR ||
127 solver == L2R_LR_DUAL || 175 solver == L2R_LR_DUAL ||
128 solver == L1R_LR) 176 solver == L1R_LR)
129 { 177 {
130 - prediction = predict_probability(m,x_space,prob_estimates); 178 + prediction = predict_probability(&m,x_space,prob_estimates);
131 if (returnDFVal) prediction = prob_estimates[0]; 179 if (returnDFVal) prediction = prob_estimates[0];
132 } 180 }
133 181
@@ -143,16 +191,12 @@ private: @@ -143,16 +191,12 @@ private:
143 191
144 void store(QDataStream &stream) const 192 void store(QDataStream &stream) const
145 { 193 {
146 - QString filename = QString::number(qrand());  
147 - stream << filename;  
148 - save_model(filename.toStdString().c_str(),m); 194 + storeModel(m,stream);
149 } 195 }
150 196
151 void load(QDataStream &stream) 197 void load(QDataStream &stream)
152 { 198 {
153 - QString filename;  
154 - stream >> filename;  
155 - m = load_model(filename.toStdString().c_str()); 199 + loadModel(m,stream);
156 } 200 }
157 }; 201 };
158 202