Commit 4691ca211f37134b7ba7a111a466e5f36ef1a4fb
1 parent
b0f41b6e
Liblinear cleanup
Showing
1 changed file
with
71 additions
and
27 deletions
openbr/plugins/liblinear.cpp
| @@ -13,6 +13,39 @@ using namespace cv; | @@ -13,6 +13,39 @@ using namespace cv; | ||
| 13 | namespace br | 13 | namespace br |
| 14 | { | 14 | { |
| 15 | 15 | ||
| 16 | +static void storeModel(const model &m, QDataStream &stream) | ||
| 17 | +{ | ||
| 18 | + // Create local file | ||
| 19 | + QTemporaryFile tempFile; | ||
| 20 | + tempFile.open(); | ||
| 21 | + tempFile.close(); | ||
| 22 | + | ||
| 23 | + // Save MLP to local file | ||
| 24 | + save_model(qPrintable(tempFile.fileName()),&m); | ||
| 25 | + | ||
| 26 | + // Copy local file contents to stream | ||
| 27 | + tempFile.open(); | ||
| 28 | + QByteArray data = tempFile.readAll(); | ||
| 29 | + tempFile.close(); | ||
| 30 | + stream << data; | ||
| 31 | +} | ||
| 32 | + | ||
| 33 | +static void loadModel(model &m, QDataStream &stream) | ||
| 34 | +{ | ||
| 35 | + // Copy local file contents from stream | ||
| 36 | + QByteArray data; | ||
| 37 | + stream >> data; | ||
| 38 | + | ||
| 39 | + // Create local file | ||
| 40 | + QTemporaryFile tempFile(QDir::tempPath()+"/model"); | ||
| 41 | + tempFile.open(); | ||
| 42 | + tempFile.write(data); | ||
| 43 | + tempFile.close(); | ||
| 44 | + | ||
| 45 | + // Load MLP from local file | ||
| 46 | + m = *load_model(qPrintable(tempFile.fileName())); | ||
| 47 | +} | ||
| 48 | + | ||
| 16 | class Linear : public Transform | 49 | class Linear : public Transform |
| 17 | { | 50 | { |
| 18 | Q_OBJECT | 51 | Q_OBJECT |
| @@ -23,19 +56,20 @@ class Linear : public Transform | @@ -23,19 +56,20 @@ class Linear : public Transform | ||
| 23 | Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) | 56 | Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) |
| 24 | Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false) | 57 | Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false) |
| 25 | Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false) | 58 | Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false) |
| 59 | + Q_PROPERTY(bool weight READ get_weight WRITE set_weight RESET reset_weight STORED false) | ||
| 26 | 60 | ||
| 27 | public: | 61 | public: |
| 28 | - enum Solver { L2R_LR, | ||
| 29 | - L2R_L2LOSS_SVC_DUAL, | ||
| 30 | - L2R_L2LOSS_SVC, | ||
| 31 | - L2R_L1LOSS_SVC_DUAL, | ||
| 32 | - MCSVM_CS, | ||
| 33 | - L1R_L2LOSS_SVC, | ||
| 34 | - L1R_LR, | ||
| 35 | - L2R_LR_DUAL, | ||
| 36 | - L2R_L2LOSS_SVR, | ||
| 37 | - L2R_L2LOSS_SVR_DUAL, | ||
| 38 | - L2R_L1LOSS_SVR_DUAL }; | 62 | + enum Solver { L2R_LR = ::L2R_LR, |
| 63 | + L2R_L2LOSS_SVC_DUAL = ::L2R_L2LOSS_SVC_DUAL, | ||
| 64 | + L2R_L2LOSS_SVC = ::L2R_L2LOSS_SVC, | ||
| 65 | + L2R_L1LOSS_SVC_DUAL = ::L2R_L1LOSS_SVC_DUAL, | ||
| 66 | + MCSVM_CS = ::MCSVM_CS, | ||
| 67 | + L1R_L2LOSS_SVC = ::L1R_L2LOSS_SVC, | ||
| 68 | + L1R_LR = ::L1R_LR, | ||
| 69 | + L2R_LR_DUAL = ::L2R_LR_DUAL, | ||
| 70 | + L2R_L2LOSS_SVR = ::L2R_L2LOSS_SVR, | ||
| 71 | + L2R_L2LOSS_SVR_DUAL = ::L2R_L2LOSS_SVR_DUAL, | ||
| 72 | + L2R_L1LOSS_SVR_DUAL = ::L2R_L1LOSS_SVR_DUAL }; | ||
| 39 | 73 | ||
| 40 | private: | 74 | private: |
| 41 | BR_PROPERTY(Solver, solver, L2R_L2LOSS_SVC_DUAL) | 75 | BR_PROPERTY(Solver, solver, L2R_L2LOSS_SVC_DUAL) |
| @@ -44,8 +78,9 @@ private: | @@ -44,8 +78,9 @@ private: | ||
| 44 | BR_PROPERTY(QString, outputVariable, "") | 78 | BR_PROPERTY(QString, outputVariable, "") |
| 45 | BR_PROPERTY(bool, returnDFVal, false) | 79 | BR_PROPERTY(bool, returnDFVal, false) |
| 46 | BR_PROPERTY(bool, overwriteMat, true) | 80 | BR_PROPERTY(bool, overwriteMat, true) |
| 81 | + BR_PROPERTY(bool, weight, false) | ||
| 47 | 82 | ||
| 48 | - model *m; | 83 | + model m; |
| 49 | 84 | ||
| 50 | void train(const TemplateList &data) | 85 | void train(const TemplateList &data) |
| 51 | { | 86 | { |
| @@ -80,17 +115,30 @@ private: | @@ -80,17 +115,30 @@ private: | ||
| 80 | 115 | ||
| 81 | // TODO: Support grid search | 116 | // TODO: Support grid search |
| 82 | param.C = C; | 117 | param.C = C; |
| 118 | + param.p = 1; | ||
| 83 | param.eps = FLT_EPSILON; | 119 | param.eps = FLT_EPSILON; |
| 84 | param.solver_type = solver; | 120 | param.solver_type = solver; |
| 85 | 121 | ||
| 86 | - // TODO: Support weights | ||
| 87 | - param.nr_weight = 0; | ||
| 88 | - param.p = 1; | ||
| 89 | - param.weight_label = NULL; | ||
| 90 | - param.weight = NULL; | 122 | + if (weight) { |
| 123 | + param.nr_weight = 2; | ||
| 124 | + param.weight_label = new int[2]; | ||
| 125 | + param.weight = new double[2]; | ||
| 126 | + param.weight_label[0] = 0; | ||
| 127 | + param.weight_label[1] = 1; | ||
| 128 | + int nonZero = countNonZero(labels); | ||
| 129 | + param.weight[0] = 1; | ||
| 130 | + param.weight[1] = (double)(prob.l-nonZero)/nonZero; | ||
| 131 | + qDebug() << param.weight[0] << param.weight[1]; | ||
| 132 | + } else { | ||
| 133 | + param.nr_weight = 0; | ||
| 134 | + param.weight_label = NULL; | ||
| 135 | + param.weight = NULL; | ||
| 136 | + } | ||
| 91 | 137 | ||
| 92 | - m = train_svm(&prob, ¶m); | 138 | + m = *train_svm(&prob, ¶m); |
| 93 | 139 | ||
| 140 | + delete[] param.weight; | ||
| 141 | + delete[] param.weight_label; | ||
| 94 | delete[] prob.y; | 142 | delete[] prob.y; |
| 95 | delete[] prob.x; | 143 | delete[] prob.x; |
| 96 | delete[] x_space; | 144 | delete[] x_space; |
| @@ -110,7 +158,7 @@ private: | @@ -110,7 +158,7 @@ private: | ||
| 110 | x_space[sample.cols].index = -1; | 158 | x_space[sample.cols].index = -1; |
| 111 | 159 | ||
| 112 | float prediction; | 160 | float prediction; |
| 113 | - double prob_estimates[m->nr_class]; | 161 | + double prob_estimates[m.nr_class]; |
| 114 | 162 | ||
| 115 | if (solver == L2R_L2LOSS_SVR || | 163 | if (solver == L2R_L2LOSS_SVR || |
| 116 | solver == L2R_L1LOSS_SVR_DUAL || | 164 | solver == L2R_L1LOSS_SVR_DUAL || |
| @@ -121,13 +169,13 @@ private: | @@ -121,13 +169,13 @@ private: | ||
| 121 | solver == MCSVM_CS || | 169 | solver == MCSVM_CS || |
| 122 | solver == L1R_L2LOSS_SVC) | 170 | solver == L1R_L2LOSS_SVC) |
| 123 | { | 171 | { |
| 124 | - prediction = predict_values(m,x_space,prob_estimates); | 172 | + prediction = predict_values(&m,x_space,prob_estimates); |
| 125 | if (returnDFVal) prediction = prob_estimates[0]; | 173 | if (returnDFVal) prediction = prob_estimates[0]; |
| 126 | } else if (solver == L2R_LR || | 174 | } else if (solver == L2R_LR || |
| 127 | solver == L2R_LR_DUAL || | 175 | solver == L2R_LR_DUAL || |
| 128 | solver == L1R_LR) | 176 | solver == L1R_LR) |
| 129 | { | 177 | { |
| 130 | - prediction = predict_probability(m,x_space,prob_estimates); | 178 | + prediction = predict_probability(&m,x_space,prob_estimates); |
| 131 | if (returnDFVal) prediction = prob_estimates[0]; | 179 | if (returnDFVal) prediction = prob_estimates[0]; |
| 132 | } | 180 | } |
| 133 | 181 | ||
| @@ -143,16 +191,12 @@ private: | @@ -143,16 +191,12 @@ private: | ||
| 143 | 191 | ||
| 144 | void store(QDataStream &stream) const | 192 | void store(QDataStream &stream) const |
| 145 | { | 193 | { |
| 146 | - QString filename = QString::number(qrand()); | ||
| 147 | - stream << filename; | ||
| 148 | - save_model(filename.toStdString().c_str(),m); | 194 | + storeModel(m,stream); |
| 149 | } | 195 | } |
| 150 | 196 | ||
| 151 | void load(QDataStream &stream) | 197 | void load(QDataStream &stream) |
| 152 | { | 198 | { |
| 153 | - QString filename; | ||
| 154 | - stream >> filename; | ||
| 155 | - m = load_model(filename.toStdString().c_str()); | 199 | + loadModel(m,stream); |
| 156 | } | 200 | } |
| 157 | }; | 201 | }; |
| 158 | 202 |