Commit 4691ca211f37134b7ba7a111a466e5f36ef1a4fb
1 parent
b0f41b6e
Liblinear cleanup
Showing
1 changed file
with
71 additions
and
27 deletions
openbr/plugins/liblinear.cpp
| ... | ... | @@ -13,6 +13,39 @@ using namespace cv; |
| 13 | 13 | namespace br |
| 14 | 14 | { |
| 15 | 15 | |
| 16 | +static void storeModel(const model &m, QDataStream &stream) | |
| 17 | +{ | |
| 18 | + // Create local file | |
| 19 | + QTemporaryFile tempFile; | |
| 20 | + tempFile.open(); | |
| 21 | + tempFile.close(); | |
| 22 | + | |
| 23 | + // Save MLP to local file | |
| 24 | + save_model(qPrintable(tempFile.fileName()),&m); | |
| 25 | + | |
| 26 | + // Copy local file contents to stream | |
| 27 | + tempFile.open(); | |
| 28 | + QByteArray data = tempFile.readAll(); | |
| 29 | + tempFile.close(); | |
| 30 | + stream << data; | |
| 31 | +} | |
| 32 | + | |
| 33 | +static void loadModel(model &m, QDataStream &stream) | |
| 34 | +{ | |
| 35 | + // Copy local file contents from stream | |
| 36 | + QByteArray data; | |
| 37 | + stream >> data; | |
| 38 | + | |
| 39 | + // Create local file | |
| 40 | + QTemporaryFile tempFile(QDir::tempPath()+"/model"); | |
| 41 | + tempFile.open(); | |
| 42 | + tempFile.write(data); | |
| 43 | + tempFile.close(); | |
| 44 | + | |
| 45 | + // Load MLP from local file | |
| 46 | + m = *load_model(qPrintable(tempFile.fileName())); | |
| 47 | +} | |
| 48 | + | |
| 16 | 49 | class Linear : public Transform |
| 17 | 50 | { |
| 18 | 51 | Q_OBJECT |
| ... | ... | @@ -23,19 +56,20 @@ class Linear : public Transform |
| 23 | 56 | Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) |
| 24 | 57 | Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false) |
| 25 | 58 | Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false) |
| 59 | + Q_PROPERTY(bool weight READ get_weight WRITE set_weight RESET reset_weight STORED false) | |
| 26 | 60 | |
| 27 | 61 | public: |
| 28 | - enum Solver { L2R_LR, | |
| 29 | - L2R_L2LOSS_SVC_DUAL, | |
| 30 | - L2R_L2LOSS_SVC, | |
| 31 | - L2R_L1LOSS_SVC_DUAL, | |
| 32 | - MCSVM_CS, | |
| 33 | - L1R_L2LOSS_SVC, | |
| 34 | - L1R_LR, | |
| 35 | - L2R_LR_DUAL, | |
| 36 | - L2R_L2LOSS_SVR, | |
| 37 | - L2R_L2LOSS_SVR_DUAL, | |
| 38 | - L2R_L1LOSS_SVR_DUAL }; | |
| 62 | + enum Solver { L2R_LR = ::L2R_LR, | |
| 63 | + L2R_L2LOSS_SVC_DUAL = ::L2R_L2LOSS_SVC_DUAL, | |
| 64 | + L2R_L2LOSS_SVC = ::L2R_L2LOSS_SVC, | |
| 65 | + L2R_L1LOSS_SVC_DUAL = ::L2R_L1LOSS_SVC_DUAL, | |
| 66 | + MCSVM_CS = ::MCSVM_CS, | |
| 67 | + L1R_L2LOSS_SVC = ::L1R_L2LOSS_SVC, | |
| 68 | + L1R_LR = ::L1R_LR, | |
| 69 | + L2R_LR_DUAL = ::L2R_LR_DUAL, | |
| 70 | + L2R_L2LOSS_SVR = ::L2R_L2LOSS_SVR, | |
| 71 | + L2R_L2LOSS_SVR_DUAL = ::L2R_L2LOSS_SVR_DUAL, | |
| 72 | + L2R_L1LOSS_SVR_DUAL = ::L2R_L1LOSS_SVR_DUAL }; | |
| 39 | 73 | |
| 40 | 74 | private: |
| 41 | 75 | BR_PROPERTY(Solver, solver, L2R_L2LOSS_SVC_DUAL) |
| ... | ... | @@ -44,8 +78,9 @@ private: |
| 44 | 78 | BR_PROPERTY(QString, outputVariable, "") |
| 45 | 79 | BR_PROPERTY(bool, returnDFVal, false) |
| 46 | 80 | BR_PROPERTY(bool, overwriteMat, true) |
| 81 | + BR_PROPERTY(bool, weight, false) | |
| 47 | 82 | |
| 48 | - model *m; | |
| 83 | + model m; | |
| 49 | 84 | |
| 50 | 85 | void train(const TemplateList &data) |
| 51 | 86 | { |
| ... | ... | @@ -80,17 +115,30 @@ private: |
| 80 | 115 | |
| 81 | 116 | // TODO: Support grid search |
| 82 | 117 | param.C = C; |
| 118 | + param.p = 1; | |
| 83 | 119 | param.eps = FLT_EPSILON; |
| 84 | 120 | param.solver_type = solver; |
| 85 | 121 | |
| 86 | - // TODO: Support weights | |
| 87 | - param.nr_weight = 0; | |
| 88 | - param.p = 1; | |
| 89 | - param.weight_label = NULL; | |
| 90 | - param.weight = NULL; | |
| 122 | + if (weight) { | |
| 123 | + param.nr_weight = 2; | |
| 124 | + param.weight_label = new int[2]; | |
| 125 | + param.weight = new double[2]; | |
| 126 | + param.weight_label[0] = 0; | |
| 127 | + param.weight_label[1] = 1; | |
| 128 | + int nonZero = countNonZero(labels); | |
| 129 | + param.weight[0] = 1; | |
| 130 | + param.weight[1] = (double)(prob.l-nonZero)/nonZero; | |
| 131 | + qDebug() << param.weight[0] << param.weight[1]; | |
| 132 | + } else { | |
| 133 | + param.nr_weight = 0; | |
| 134 | + param.weight_label = NULL; | |
| 135 | + param.weight = NULL; | |
| 136 | + } | |
| 91 | 137 | |
| 92 | - m = train_svm(&prob, ¶m); | |
| 138 | + m = *train_svm(&prob, ¶m); | |
| 93 | 139 | |
| 140 | + delete[] param.weight; | |
| 141 | + delete[] param.weight_label; | |
| 94 | 142 | delete[] prob.y; |
| 95 | 143 | delete[] prob.x; |
| 96 | 144 | delete[] x_space; |
| ... | ... | @@ -110,7 +158,7 @@ private: |
| 110 | 158 | x_space[sample.cols].index = -1; |
| 111 | 159 | |
| 112 | 160 | float prediction; |
| 113 | - double prob_estimates[m->nr_class]; | |
| 161 | + double prob_estimates[m.nr_class]; | |
| 114 | 162 | |
| 115 | 163 | if (solver == L2R_L2LOSS_SVR || |
| 116 | 164 | solver == L2R_L1LOSS_SVR_DUAL || |
| ... | ... | @@ -121,13 +169,13 @@ private: |
| 121 | 169 | solver == MCSVM_CS || |
| 122 | 170 | solver == L1R_L2LOSS_SVC) |
| 123 | 171 | { |
| 124 | - prediction = predict_values(m,x_space,prob_estimates); | |
| 172 | + prediction = predict_values(&m,x_space,prob_estimates); | |
| 125 | 173 | if (returnDFVal) prediction = prob_estimates[0]; |
| 126 | 174 | } else if (solver == L2R_LR || |
| 127 | 175 | solver == L2R_LR_DUAL || |
| 128 | 176 | solver == L1R_LR) |
| 129 | 177 | { |
| 130 | - prediction = predict_probability(m,x_space,prob_estimates); | |
| 178 | + prediction = predict_probability(&m,x_space,prob_estimates); | |
| 131 | 179 | if (returnDFVal) prediction = prob_estimates[0]; |
| 132 | 180 | } |
| 133 | 181 | |
| ... | ... | @@ -143,16 +191,12 @@ private: |
| 143 | 191 | |
| 144 | 192 | void store(QDataStream &stream) const |
| 145 | 193 | { |
| 146 | - QString filename = QString::number(qrand()); | |
| 147 | - stream << filename; | |
| 148 | - save_model(filename.toStdString().c_str(),m); | |
| 194 | + storeModel(m,stream); | |
| 149 | 195 | } |
| 150 | 196 | |
| 151 | 197 | void load(QDataStream &stream) |
| 152 | 198 | { |
| 153 | - QString filename; | |
| 154 | - stream >> filename; | |
| 155 | - m = load_model(filename.toStdString().c_str()); | |
| 199 | + loadModel(m,stream); | |
| 156 | 200 | } |
| 157 | 201 | }; |
| 158 | 202 | ... | ... |