From 4691ca211f37134b7ba7a111a466e5f36ef1a4fb Mon Sep 17 00:00:00 2001 From: Scott Klum Date: Mon, 9 Feb 2015 13:15:14 -0500 Subject: [PATCH] Liblinear cleanup --- openbr/plugins/liblinear.cpp | 98 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--------------------------- 1 file changed, 71 insertions(+), 27 deletions(-) diff --git a/openbr/plugins/liblinear.cpp b/openbr/plugins/liblinear.cpp index 7de2485..c421908 100644 --- a/openbr/plugins/liblinear.cpp +++ b/openbr/plugins/liblinear.cpp @@ -13,6 +13,39 @@ using namespace cv; namespace br { +static void storeModel(const model &m, QDataStream &stream) +{ + // Create local file + QTemporaryFile tempFile; + tempFile.open(); + tempFile.close(); + + // Save MLP to local file + save_model(qPrintable(tempFile.fileName()),&m); + + // Copy local file contents to stream + tempFile.open(); + QByteArray data = tempFile.readAll(); + tempFile.close(); + stream << data; +} + +static void loadModel(model &m, QDataStream &stream) +{ + // Copy local file contents from stream + QByteArray data; + stream >> data; + + // Create local file + QTemporaryFile tempFile(QDir::tempPath()+"/model"); + tempFile.open(); + tempFile.write(data); + tempFile.close(); + + // Load MLP from local file + m = *load_model(qPrintable(tempFile.fileName())); +} + class Linear : public Transform { Q_OBJECT @@ -23,19 +56,20 @@ class Linear : public Transform Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false) Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false) + Q_PROPERTY(bool weight READ get_weight WRITE set_weight RESET reset_weight STORED false) public: - enum Solver { L2R_LR, - L2R_L2LOSS_SVC_DUAL, - L2R_L2LOSS_SVC, - L2R_L1LOSS_SVC_DUAL, - MCSVM_CS, - L1R_L2LOSS_SVC, - L1R_LR, - L2R_LR_DUAL, - L2R_L2LOSS_SVR, - L2R_L2LOSS_SVR_DUAL, - L2R_L1LOSS_SVR_DUAL }; + enum Solver { L2R_LR = ::L2R_LR, + L2R_L2LOSS_SVC_DUAL = ::L2R_L2LOSS_SVC_DUAL, + L2R_L2LOSS_SVC = ::L2R_L2LOSS_SVC, + L2R_L1LOSS_SVC_DUAL = ::L2R_L1LOSS_SVC_DUAL, + MCSVM_CS = ::MCSVM_CS, + L1R_L2LOSS_SVC = ::L1R_L2LOSS_SVC, + L1R_LR = ::L1R_LR, + L2R_LR_DUAL = ::L2R_LR_DUAL, + L2R_L2LOSS_SVR = ::L2R_L2LOSS_SVR, + L2R_L2LOSS_SVR_DUAL = ::L2R_L2LOSS_SVR_DUAL, + L2R_L1LOSS_SVR_DUAL = ::L2R_L1LOSS_SVR_DUAL }; private: BR_PROPERTY(Solver, solver, L2R_L2LOSS_SVC_DUAL) @@ -44,8 +78,9 @@ private: BR_PROPERTY(QString, outputVariable, "") BR_PROPERTY(bool, returnDFVal, false) BR_PROPERTY(bool, overwriteMat, true) + BR_PROPERTY(bool, weight, false) - model *m; + model m; void train(const TemplateList &data) { @@ -80,17 +115,30 @@ private: // TODO: Support grid search param.C = C; + param.p = 1; param.eps = FLT_EPSILON; param.solver_type = solver; - // TODO: Support weights - param.nr_weight = 0; - param.p = 1; - param.weight_label = NULL; - param.weight = NULL; + if (weight) { + param.nr_weight = 2; + param.weight_label = new int[2]; + param.weight = new double[2]; + param.weight_label[0] = 0; + param.weight_label[1] = 1; + int nonZero = countNonZero(labels); + param.weight[0] = 1; + param.weight[1] = (double)(prob.l-nonZero)/nonZero; + qDebug() << param.weight[0] << param.weight[1]; + } else { + param.nr_weight = 0; + param.weight_label = NULL; + param.weight = NULL; + } - m = train_svm(&prob, ¶m); + m = *train_svm(&prob, ¶m); + delete[] param.weight; + delete[] param.weight_label; delete[] prob.y; delete[] prob.x; delete[] x_space; @@ -110,7 +158,7 @@ private: x_space[sample.cols].index = -1; float prediction; - double prob_estimates[m->nr_class]; + double prob_estimates[m.nr_class]; if (solver == L2R_L2LOSS_SVR || solver == L2R_L1LOSS_SVR_DUAL || @@ -121,13 +169,13 @@ private: solver == MCSVM_CS || solver == L1R_L2LOSS_SVC) { - prediction = predict_values(m,x_space,prob_estimates); + prediction = predict_values(&m,x_space,prob_estimates); if (returnDFVal) prediction = prob_estimates[0]; } else if (solver == L2R_LR || solver == L2R_LR_DUAL || solver == L1R_LR) { - prediction = predict_probability(m,x_space,prob_estimates); + prediction = predict_probability(&m,x_space,prob_estimates); if (returnDFVal) prediction = prob_estimates[0]; } @@ -143,16 +191,12 @@ private: void store(QDataStream &stream) const { - QString filename = QString::number(qrand()); - stream << filename; - save_model(filename.toStdString().c_str(),m); + storeModel(m,stream); } void load(QDataStream &stream) { - QString filename; - stream >> filename; - m = load_model(filename.toStdString().c_str()); + loadModel(m,stream); } }; -- libgit2 0.21.4