Commit 4691ca211f37134b7ba7a111a466e5f36ef1a4fb

Authored by Scott Klum
1 parent b0f41b6e

Liblinear cleanup

Showing 1 changed file with 71 additions and 27 deletions
openbr/plugins/liblinear.cpp
... ... @@ -13,6 +13,39 @@ using namespace cv;
13 13 namespace br
14 14 {
15 15  
  16 +static void storeModel(const model &m, QDataStream &stream)
  17 +{
  18 + // Create local file
  19 + QTemporaryFile tempFile;
  20 + tempFile.open();
  21 + tempFile.close();
  22 +
  23 + // Save MLP to local file
  24 + save_model(qPrintable(tempFile.fileName()),&m);
  25 +
  26 + // Copy local file contents to stream
  27 + tempFile.open();
  28 + QByteArray data = tempFile.readAll();
  29 + tempFile.close();
  30 + stream << data;
  31 +}
  32 +
  33 +static void loadModel(model &m, QDataStream &stream)
  34 +{
  35 + // Copy local file contents from stream
  36 + QByteArray data;
  37 + stream >> data;
  38 +
  39 + // Create local file
  40 + QTemporaryFile tempFile(QDir::tempPath()+"/model");
  41 + tempFile.open();
  42 + tempFile.write(data);
  43 + tempFile.close();
  44 +
  45 + // Load MLP from local file
  46 + m = *load_model(qPrintable(tempFile.fileName()));
  47 +}
  48 +
16 49 class Linear : public Transform
17 50 {
18 51 Q_OBJECT
... ... @@ -23,19 +56,20 @@ class Linear : public Transform
23 56 Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)
24 57 Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false)
25 58 Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false)
  59 + Q_PROPERTY(bool weight READ get_weight WRITE set_weight RESET reset_weight STORED false)
26 60  
27 61 public:
28   - enum Solver { L2R_LR,
29   - L2R_L2LOSS_SVC_DUAL,
30   - L2R_L2LOSS_SVC,
31   - L2R_L1LOSS_SVC_DUAL,
32   - MCSVM_CS,
33   - L1R_L2LOSS_SVC,
34   - L1R_LR,
35   - L2R_LR_DUAL,
36   - L2R_L2LOSS_SVR,
37   - L2R_L2LOSS_SVR_DUAL,
38   - L2R_L1LOSS_SVR_DUAL };
  62 + enum Solver { L2R_LR = ::L2R_LR,
  63 + L2R_L2LOSS_SVC_DUAL = ::L2R_L2LOSS_SVC_DUAL,
  64 + L2R_L2LOSS_SVC = ::L2R_L2LOSS_SVC,
  65 + L2R_L1LOSS_SVC_DUAL = ::L2R_L1LOSS_SVC_DUAL,
  66 + MCSVM_CS = ::MCSVM_CS,
  67 + L1R_L2LOSS_SVC = ::L1R_L2LOSS_SVC,
  68 + L1R_LR = ::L1R_LR,
  69 + L2R_LR_DUAL = ::L2R_LR_DUAL,
  70 + L2R_L2LOSS_SVR = ::L2R_L2LOSS_SVR,
  71 + L2R_L2LOSS_SVR_DUAL = ::L2R_L2LOSS_SVR_DUAL,
  72 + L2R_L1LOSS_SVR_DUAL = ::L2R_L1LOSS_SVR_DUAL };
39 73  
40 74 private:
41 75 BR_PROPERTY(Solver, solver, L2R_L2LOSS_SVC_DUAL)
... ... @@ -44,8 +78,9 @@ private:
44 78 BR_PROPERTY(QString, outputVariable, "")
45 79 BR_PROPERTY(bool, returnDFVal, false)
46 80 BR_PROPERTY(bool, overwriteMat, true)
  81 + BR_PROPERTY(bool, weight, false)
47 82  
48   - model *m;
  83 + model m;
49 84  
50 85 void train(const TemplateList &data)
51 86 {
... ... @@ -80,17 +115,30 @@ private:
80 115  
81 116 // TODO: Support grid search
82 117 param.C = C;
  118 + param.p = 1;
83 119 param.eps = FLT_EPSILON;
84 120 param.solver_type = solver;
85 121  
86   - // TODO: Support weights
87   - param.nr_weight = 0;
88   - param.p = 1;
89   - param.weight_label = NULL;
90   - param.weight = NULL;
  122 + if (weight) {
  123 + param.nr_weight = 2;
  124 + param.weight_label = new int[2];
  125 + param.weight = new double[2];
  126 + param.weight_label[0] = 0;
  127 + param.weight_label[1] = 1;
  128 + int nonZero = countNonZero(labels);
  129 + param.weight[0] = 1;
  130 + param.weight[1] = (double)(prob.l-nonZero)/nonZero;
  131 + qDebug() << param.weight[0] << param.weight[1];
  132 + } else {
  133 + param.nr_weight = 0;
  134 + param.weight_label = NULL;
  135 + param.weight = NULL;
  136 + }
91 137  
92   - m = train_svm(&prob, &param);
  138 + m = *train_svm(&prob, &param);
93 139  
  140 + delete[] param.weight;
  141 + delete[] param.weight_label;
94 142 delete[] prob.y;
95 143 delete[] prob.x;
96 144 delete[] x_space;
... ... @@ -110,7 +158,7 @@ private:
110 158 x_space[sample.cols].index = -1;
111 159  
112 160 float prediction;
113   - double prob_estimates[m->nr_class];
  161 + double prob_estimates[m.nr_class];
114 162  
115 163 if (solver == L2R_L2LOSS_SVR ||
116 164 solver == L2R_L1LOSS_SVR_DUAL ||
... ... @@ -121,13 +169,13 @@ private:
121 169 solver == MCSVM_CS ||
122 170 solver == L1R_L2LOSS_SVC)
123 171 {
124   - prediction = predict_values(m,x_space,prob_estimates);
  172 + prediction = predict_values(&m,x_space,prob_estimates);
125 173 if (returnDFVal) prediction = prob_estimates[0];
126 174 } else if (solver == L2R_LR ||
127 175 solver == L2R_LR_DUAL ||
128 176 solver == L1R_LR)
129 177 {
130   - prediction = predict_probability(m,x_space,prob_estimates);
  178 + prediction = predict_probability(&m,x_space,prob_estimates);
131 179 if (returnDFVal) prediction = prob_estimates[0];
132 180 }
133 181  
... ... @@ -143,16 +191,12 @@ private:
143 191  
144 192 void store(QDataStream &stream) const
145 193 {
146   - QString filename = QString::number(qrand());
147   - stream << filename;
148   - save_model(filename.toStdString().c_str(),m);
  194 + storeModel(m,stream);
149 195 }
150 196  
151 197 void load(QDataStream &stream)
152 198 {
153   - QString filename;
154   - stream >> filename;
155   - m = load_model(filename.toStdString().c_str());
  199 + loadModel(m,stream);
156 200 }
157 201 };
158 202  
... ...