Commit f61b9ea59a17e96fde9684e9ea920357ca863437
Merge pull request #327 from biometrics/liblinear
Liblinear
Showing
4 changed files
with
391 additions
and
23 deletions
openbr/plugins/liblinear.cmake
0 โ 100644
openbr/plugins/liblinear.cpp
0 โ 100644
| 1 | +#include <QTemporaryFile> | ||
| 2 | +#include <opencv2/core/core.hpp> | ||
| 3 | +#include <opencv2/ml/ml.hpp> | ||
| 4 | + | ||
| 5 | +#include "openbr_internal.h" | ||
| 6 | +#include "openbr/core/opencvutils.h" | ||
| 7 | + | ||
| 8 | +#include <linear.h> | ||
| 9 | + | ||
| 10 | +using namespace cv; | ||
| 11 | + | ||
| 12 | +namespace br | ||
| 13 | +{ | ||
| 14 | + | ||
| 15 | +static void storeModel(const model &m, QDataStream &stream) | ||
| 16 | +{ | ||
| 17 | + // Create local file | ||
| 18 | + QTemporaryFile tempFile; | ||
| 19 | + tempFile.open(); | ||
| 20 | + tempFile.close(); | ||
| 21 | + | ||
| 22 | + // Save MLP to local file | ||
| 23 | + save_model(qPrintable(tempFile.fileName()),&m); | ||
| 24 | + | ||
| 25 | + // Copy local file contents to stream | ||
| 26 | + tempFile.open(); | ||
| 27 | + QByteArray data = tempFile.readAll(); | ||
| 28 | + tempFile.close(); | ||
| 29 | + stream << data; | ||
| 30 | +} | ||
| 31 | + | ||
| 32 | +static void loadModel(model &m, QDataStream &stream) | ||
| 33 | +{ | ||
| 34 | + // Copy local file contents from stream | ||
| 35 | + QByteArray data; | ||
| 36 | + stream >> data; | ||
| 37 | + | ||
| 38 | + // Create local file | ||
| 39 | + QTemporaryFile tempFile(QDir::tempPath()+"/model"); | ||
| 40 | + tempFile.open(); | ||
| 41 | + tempFile.write(data); | ||
| 42 | + tempFile.close(); | ||
| 43 | + | ||
| 44 | + // Load MLP from local file | ||
| 45 | + m = *load_model(qPrintable(tempFile.fileName())); | ||
| 46 | +} | ||
| 47 | + | ||
| 48 | +class Linear : public Transform | ||
| 49 | +{ | ||
| 50 | + Q_OBJECT | ||
| 51 | + Q_ENUMS(Solver) | ||
| 52 | + Q_PROPERTY(Solver solver READ get_solver WRITE set_solver RESET reset_solver STORED false) | ||
| 53 | + Q_PROPERTY(float C READ get_C WRITE set_C RESET reset_C STORED false) | ||
| 54 | + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | ||
| 55 | + Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) | ||
| 56 | + Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false) | ||
| 57 | + Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false) | ||
| 58 | + Q_PROPERTY(bool weight READ get_weight WRITE set_weight RESET reset_weight STORED false) | ||
| 59 | + | ||
| 60 | +public: | ||
| 61 | + enum Solver { L2R_LR = ::L2R_LR, | ||
| 62 | + L2R_L2LOSS_SVC_DUAL = ::L2R_L2LOSS_SVC_DUAL, | ||
| 63 | + L2R_L2LOSS_SVC = ::L2R_L2LOSS_SVC, | ||
| 64 | + L2R_L1LOSS_SVC_DUAL = ::L2R_L1LOSS_SVC_DUAL, | ||
| 65 | + MCSVM_CS = ::MCSVM_CS, | ||
| 66 | + L1R_L2LOSS_SVC = ::L1R_L2LOSS_SVC, | ||
| 67 | + L1R_LR = ::L1R_LR, | ||
| 68 | + L2R_LR_DUAL = ::L2R_LR_DUAL, | ||
| 69 | + L2R_L2LOSS_SVR = ::L2R_L2LOSS_SVR, | ||
| 70 | + L2R_L2LOSS_SVR_DUAL = ::L2R_L2LOSS_SVR_DUAL, | ||
| 71 | + L2R_L1LOSS_SVR_DUAL = ::L2R_L1LOSS_SVR_DUAL }; | ||
| 72 | + | ||
| 73 | +private: | ||
| 74 | + BR_PROPERTY(Solver, solver, L2R_L2LOSS_SVC_DUAL) | ||
| 75 | + BR_PROPERTY(float, C, 1) | ||
| 76 | + BR_PROPERTY(QString, inputVariable, "Label") | ||
| 77 | + BR_PROPERTY(QString, outputVariable, "") | ||
| 78 | + BR_PROPERTY(bool, returnDFVal, false) | ||
| 79 | + BR_PROPERTY(bool, overwriteMat, true) | ||
| 80 | + BR_PROPERTY(bool, weight, false) | ||
| 81 | + | ||
| 82 | + model m; | ||
| 83 | + | ||
| 84 | + void train(const TemplateList &data) | ||
| 85 | + { | ||
| 86 | + Mat samples = OpenCVUtils::toMat(data.data()); | ||
| 87 | + Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable)); | ||
| 88 | + | ||
| 89 | + problem prob; | ||
| 90 | + prob.n = samples.cols; | ||
| 91 | + prob.l = samples.rows; | ||
| 92 | + prob.bias = -1; | ||
| 93 | + prob.y = new double[prob.l]; | ||
| 94 | + | ||
| 95 | + for (int i=0; i<prob.l; i++) | ||
| 96 | + prob.y[i] = labels.at<float>(i,0); | ||
| 97 | + | ||
| 98 | + // Allocate enough memory for l feature_nodes pointers | ||
| 99 | + prob.x = new feature_node*[prob.l]; | ||
| 100 | + feature_node *x_space = new feature_node[(prob.n+1)*prob.l]; | ||
| 101 | + | ||
| 102 | + int k = 0; | ||
| 103 | + for (int i=0; i<prob.l; i++) { | ||
| 104 | + prob.x[i] = &x_space[k]; | ||
| 105 | + for (int j=0; j<prob.n; j++) { | ||
| 106 | + x_space[k].index = j+1; | ||
| 107 | + x_space[k].value = samples.at<float>(i,j); | ||
| 108 | + k++; | ||
| 109 | + } | ||
| 110 | + x_space[k++].index = -1; | ||
| 111 | + } | ||
| 112 | + | ||
| 113 | + parameter param; | ||
| 114 | + | ||
| 115 | + // TODO: Support grid search | ||
| 116 | + param.C = C; | ||
| 117 | + param.p = 1; | ||
| 118 | + param.eps = FLT_EPSILON; | ||
| 119 | + param.solver_type = solver; | ||
| 120 | + | ||
| 121 | + if (weight) { | ||
| 122 | + param.nr_weight = 2; | ||
| 123 | + param.weight_label = new int[2]; | ||
| 124 | + param.weight = new double[2]; | ||
| 125 | + param.weight_label[0] = 0; | ||
| 126 | + param.weight_label[1] = 1; | ||
| 127 | + int nonZero = countNonZero(labels); | ||
| 128 | + param.weight[0] = 1; | ||
| 129 | + param.weight[1] = (double)(prob.l-nonZero)/nonZero; | ||
| 130 | + qDebug() << param.weight[0] << param.weight[1]; | ||
| 131 | + } else { | ||
| 132 | + param.nr_weight = 0; | ||
| 133 | + param.weight_label = NULL; | ||
| 134 | + param.weight = NULL; | ||
| 135 | + } | ||
| 136 | + | ||
| 137 | + m = *train_svm(&prob, ¶m); | ||
| 138 | + | ||
| 139 | + delete[] param.weight; | ||
| 140 | + delete[] param.weight_label; | ||
| 141 | + delete[] prob.y; | ||
| 142 | + delete[] prob.x; | ||
| 143 | + delete[] x_space; | ||
| 144 | + } | ||
| 145 | + | ||
| 146 | + void project(const Template &src, Template &dst) const | ||
| 147 | + { | ||
| 148 | + dst = src; | ||
| 149 | + | ||
| 150 | + Mat sample = src.m().reshape(1,1); | ||
| 151 | + feature_node *x_space = new feature_node[sample.cols+1]; | ||
| 152 | + | ||
| 153 | + for (int j=0; j<sample.cols; j++) { | ||
| 154 | + x_space[j].index = j+1; | ||
| 155 | + x_space[j].value = sample.at<float>(0,j); | ||
| 156 | + } | ||
| 157 | + x_space[sample.cols].index = -1; | ||
| 158 | + | ||
| 159 | + float prediction; | ||
| 160 | + double prob_estimates[m.nr_class]; | ||
| 161 | + | ||
| 162 | + if (solver == L2R_L2LOSS_SVR || | ||
| 163 | + solver == L2R_L1LOSS_SVR_DUAL || | ||
| 164 | + solver == L2R_L2LOSS_SVR_DUAL || | ||
| 165 | + solver == L2R_L2LOSS_SVC_DUAL || | ||
| 166 | + solver == L2R_L2LOSS_SVC || | ||
| 167 | + solver == L2R_L1LOSS_SVC_DUAL || | ||
| 168 | + solver == MCSVM_CS || | ||
| 169 | + solver == L1R_L2LOSS_SVC) | ||
| 170 | + { | ||
| 171 | + prediction = predict_values(&m,x_space,prob_estimates); | ||
| 172 | + if (returnDFVal) prediction = prob_estimates[0]; | ||
| 173 | + } else if (solver == L2R_LR || | ||
| 174 | + solver == L2R_LR_DUAL || | ||
| 175 | + solver == L1R_LR) | ||
| 176 | + { | ||
| 177 | + prediction = predict_probability(&m,x_space,prob_estimates); | ||
| 178 | + if (returnDFVal) prediction = prob_estimates[0]; | ||
| 179 | + } | ||
| 180 | + | ||
| 181 | + if (overwriteMat) { | ||
| 182 | + dst.m() = Mat(1, 1, CV_32F); | ||
| 183 | + dst.m().at<float>(0, 0) = prediction; | ||
| 184 | + } else { | ||
| 185 | + dst.file.set(outputVariable,prediction); | ||
| 186 | + } | ||
| 187 | + | ||
| 188 | + delete[] x_space; | ||
| 189 | + } | ||
| 190 | + | ||
| 191 | + void store(QDataStream &stream) const | ||
| 192 | + { | ||
| 193 | + storeModel(m,stream); | ||
| 194 | + } | ||
| 195 | + | ||
| 196 | + void load(QDataStream &stream) | ||
| 197 | + { | ||
| 198 | + loadModel(m,stream); | ||
| 199 | + } | ||
| 200 | +}; | ||
| 201 | + | ||
| 202 | +BR_REGISTER(Transform, Linear) | ||
| 203 | + | ||
| 204 | +} // namespace br | ||
| 205 | + | ||
| 206 | +#include "liblinear.moc" |
openbr/plugins/tree.cpp
| @@ -35,7 +35,7 @@ static void loadModel(CvStatModel &model, QDataStream &stream) | @@ -35,7 +35,7 @@ static void loadModel(CvStatModel &model, QDataStream &stream) | ||
| 35 | stream >> data; | 35 | stream >> data; |
| 36 | 36 | ||
| 37 | // Create local file | 37 | // Create local file |
| 38 | - QTemporaryFile tempFile(QDir::tempPath()+"/model"); | 38 | + QTemporaryFile tempFile(QDir::tempPath()+"/"+QString::number(rand())); |
| 39 | tempFile.open(); | 39 | tempFile.open(); |
| 40 | tempFile.write(data); | 40 | tempFile.write(data); |
| 41 | tempFile.close(); | 41 | tempFile.close(); |
| @@ -53,6 +53,50 @@ static void loadModel(CvStatModel &model, QDataStream &stream) | @@ -53,6 +53,50 @@ static void loadModel(CvStatModel &model, QDataStream &stream) | ||
| 53 | class ForestTransform : public Transform | 53 | class ForestTransform : public Transform |
| 54 | { | 54 | { |
| 55 | Q_OBJECT | 55 | Q_OBJECT |
| 56 | + | ||
| 57 | + void train(const TemplateList &data) | ||
| 58 | + { | ||
| 59 | + trainForest(data); | ||
| 60 | + } | ||
| 61 | + | ||
| 62 | + void project(const Template &src, Template &dst) const | ||
| 63 | + { | ||
| 64 | + dst = src; | ||
| 65 | + | ||
| 66 | + float response; | ||
| 67 | + if (classification && returnConfidence) { | ||
| 68 | + // Fuzzy class label | ||
| 69 | + response = forest.predict_prob(src.m().reshape(1,1)); | ||
| 70 | + } else { | ||
| 71 | + response = forest.predict(src.m().reshape(1,1)); | ||
| 72 | + } | ||
| 73 | + | ||
| 74 | + if (overwriteMat) { | ||
| 75 | + dst.m() = Mat(1, 1, CV_32F); | ||
| 76 | + dst.m().at<float>(0, 0) = response; | ||
| 77 | + } else { | ||
| 78 | + dst.file.set(outputVariable, response); | ||
| 79 | + } | ||
| 80 | + } | ||
| 81 | + | ||
| 82 | + void load(QDataStream &stream) | ||
| 83 | + { | ||
| 84 | + loadModel(forest,stream); | ||
| 85 | + } | ||
| 86 | + | ||
| 87 | + void store(QDataStream &stream) const | ||
| 88 | + { | ||
| 89 | + storeModel(forest,stream); | ||
| 90 | + } | ||
| 91 | + | ||
| 92 | + void init() | ||
| 93 | + { | ||
| 94 | + if (outputVariable.isEmpty()) | ||
| 95 | + outputVariable = inputVariable; | ||
| 96 | + } | ||
| 97 | + | ||
| 98 | +protected: | ||
| 99 | + Q_ENUMS(TerminationCriteria) | ||
| 56 | Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED false) | 100 | Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED false) |
| 57 | Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED false) | 101 | Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED false) |
| 58 | Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED false) | 102 | Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED false) |
| @@ -62,6 +106,15 @@ class ForestTransform : public Transform | @@ -62,6 +106,15 @@ class ForestTransform : public Transform | ||
| 62 | Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false) | 106 | Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false) |
| 63 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | 107 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) |
| 64 | Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) | 108 | Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) |
| 109 | + Q_PROPERTY(bool weight READ get_weight WRITE set_weight RESET reset_weight STORED false) | ||
| 110 | + Q_PROPERTY(TerminationCriteria termCrit READ get_termCrit WRITE set_termCrit RESET reset_termCrit STORED false) | ||
| 111 | + | ||
| 112 | +public: | ||
| 113 | + enum TerminationCriteria { Iter = CV_TERMCRIT_ITER, | ||
| 114 | + EPS = CV_TERMCRIT_EPS, | ||
| 115 | + Both = CV_TERMCRIT_EPS | CV_TERMCRIT_ITER}; | ||
| 116 | + | ||
| 117 | +protected: | ||
| 65 | BR_PROPERTY(bool, classification, true) | 118 | BR_PROPERTY(bool, classification, true) |
| 66 | BR_PROPERTY(float, splitPercentage, .01) | 119 | BR_PROPERTY(float, splitPercentage, .01) |
| 67 | BR_PROPERTY(int, maxDepth, std::numeric_limits<int>::max()) | 120 | BR_PROPERTY(int, maxDepth, std::numeric_limits<int>::max()) |
| @@ -71,10 +124,12 @@ class ForestTransform : public Transform | @@ -71,10 +124,12 @@ class ForestTransform : public Transform | ||
| 71 | BR_PROPERTY(bool, overwriteMat, true) | 124 | BR_PROPERTY(bool, overwriteMat, true) |
| 72 | BR_PROPERTY(QString, inputVariable, "Label") | 125 | BR_PROPERTY(QString, inputVariable, "Label") |
| 73 | BR_PROPERTY(QString, outputVariable, "") | 126 | BR_PROPERTY(QString, outputVariable, "") |
| 127 | + BR_PROPERTY(bool, weight, false) | ||
| 128 | + BR_PROPERTY(TerminationCriteria, termCrit, Iter) | ||
| 74 | 129 | ||
| 75 | CvRTrees forest; | 130 | CvRTrees forest; |
| 76 | 131 | ||
| 77 | - void train(const TemplateList &data) | 132 | + void trainForest(const TemplateList &data) |
| 78 | { | 133 | { |
| 79 | Mat samples = OpenCVUtils::toMat(data.data()); | 134 | Mat samples = OpenCVUtils::toMat(data.data()); |
| 80 | Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable)); | 135 | Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable)); |
| @@ -88,6 +143,14 @@ class ForestTransform : public Transform | @@ -88,6 +143,14 @@ class ForestTransform : public Transform | ||
| 88 | types.at<char>(samples.cols, 0) = CV_VAR_NUMERICAL; | 143 | types.at<char>(samples.cols, 0) = CV_VAR_NUMERICAL; |
| 89 | } | 144 | } |
| 90 | 145 | ||
| 146 | + bool usePrior = classification && weight; | ||
| 147 | + float priors[2]; | ||
| 148 | + if (usePrior) { | ||
| 149 | + int nonZero = countNonZero(labels); | ||
| 150 | + priors[0] = 1; | ||
| 151 | + priors[1] = (float)(samples.rows-nonZero)/nonZero; | ||
| 152 | + } | ||
| 153 | + | ||
| 91 | int minSamplesForSplit = data.size()*splitPercentage; | 154 | int minSamplesForSplit = data.size()*splitPercentage; |
| 92 | forest.train( samples, CV_ROW_SAMPLE, labels, Mat(), Mat(), types, Mat(), | 155 | forest.train( samples, CV_ROW_SAMPLE, labels, Mat(), Mat(), types, Mat(), |
| 93 | CvRTParams(maxDepth, | 156 | CvRTParams(maxDepth, |
| @@ -95,54 +158,133 @@ class ForestTransform : public Transform | @@ -95,54 +158,133 @@ class ForestTransform : public Transform | ||
| 95 | 0, | 158 | 0, |
| 96 | false, | 159 | false, |
| 97 | 2, | 160 | 2, |
| 98 | - 0, // priors | 161 | + usePrior ? priors : 0, |
| 99 | false, | 162 | false, |
| 100 | 0, | 163 | 0, |
| 101 | maxTrees, | 164 | maxTrees, |
| 102 | forestAccuracy, | 165 | forestAccuracy, |
| 103 | - CV_TERMCRIT_ITER | CV_TERMCRIT_EPS)); | 166 | + termCrit)); |
| 167 | + | ||
| 168 | + if (Globals->verbose) { | ||
| 169 | + qDebug() << "Number of trees:" << forest.get_tree_count(); | ||
| 170 | + | ||
| 171 | + if (classification) { | ||
| 172 | + QTime timer; | ||
| 173 | + timer.start(); | ||
| 174 | + int correctClassification = 0; | ||
| 175 | + float regressionError = 0; | ||
| 176 | + for (int i=0; i<samples.rows; i++) { | ||
| 177 | + float prediction = forest.predict_prob(samples.row(i)); | ||
| 178 | + int label = forest.predict(samples.row(i)); | ||
| 179 | + if (label == labels.at<float>(i,0)) { | ||
| 180 | + correctClassification++; | ||
| 181 | + } | ||
| 182 | + regressionError += fabs(prediction-labels.at<float>(i,0)); | ||
| 183 | + } | ||
| 184 | + | ||
| 185 | + qDebug("Time to classify %d samples: %d ms\n \ | ||
| 186 | + Classification Accuracy: %f\n \ | ||
| 187 | + MAE: %f\n \ | ||
| 188 | + Sample dimensionality: %d", | ||
| 189 | + samples.rows,timer.elapsed(),(float)correctClassification/samples.rows,regressionError/samples.rows,samples.cols); | ||
| 190 | + } | ||
| 191 | + } | ||
| 192 | + } | ||
| 193 | +}; | ||
| 194 | + | ||
| 195 | +BR_REGISTER(Transform, ForestTransform) | ||
| 196 | + | ||
| 197 | +/*! | ||
| 198 | + * \ingroup transforms | ||
| 199 | + * \brief Wraps OpenCV's random trees framework to induce features | ||
| 200 | + * \author Scott Klum \cite sklum | ||
| 201 | + * \brief https://lirias.kuleuven.be/bitstream/123456789/316661/1/icdm11-camready.pdf | ||
| 202 | + */ | ||
| 203 | +class ForestInductionTransform : public ForestTransform | ||
| 204 | +{ | ||
| 205 | + Q_OBJECT | ||
| 206 | + Q_PROPERTY(bool useRegressionValue READ get_useRegressionValue WRITE set_useRegressionValue RESET reset_useRegressionValue STORED false) | ||
| 207 | + BR_PROPERTY(bool, useRegressionValue, false) | ||
| 208 | + | ||
| 209 | + int totalSize; | ||
| 210 | + QList< QList<const CvDTreeNode*> > nodes; | ||
| 211 | + | ||
| 212 | + void fillNodes() | ||
| 213 | + { | ||
| 214 | + for (int i=0; i<forest.get_tree_count(); i++) { | ||
| 215 | + nodes.append(QList<const CvDTreeNode*>()); | ||
| 216 | + const CvDTreeNode* node = forest.get_tree(i)->get_root(); | ||
| 217 | + | ||
| 218 | + // traverse the tree and save all the nodes in depth-first order | ||
| 219 | + for(;;) | ||
| 220 | + { | ||
| 221 | + CvDTreeNode* parent; | ||
| 222 | + for(;;) | ||
| 223 | + { | ||
| 224 | + if( !node->left ) | ||
| 225 | + break; | ||
| 226 | + node = node->left; | ||
| 227 | + } | ||
| 228 | + | ||
| 229 | + nodes.last().append(node); | ||
| 230 | + | ||
| 231 | + for( parent = node->parent; parent && parent->right == node; | ||
| 232 | + node = parent, parent = parent->parent ) | ||
| 233 | + ; | ||
| 234 | + | ||
| 235 | + if( !parent ) | ||
| 236 | + break; | ||
| 237 | + | ||
| 238 | + node = parent->right; | ||
| 239 | + } | ||
| 240 | + | ||
| 241 | + totalSize += nodes.last().size(); | ||
| 242 | + } | ||
| 243 | + } | ||
| 104 | 244 | ||
| 105 | - qDebug() << "Number of trees:" << forest.get_tree_count(); | 245 | + void train(const TemplateList &data) |
| 246 | + { | ||
| 247 | + trainForest(data); | ||
| 248 | + if (!useRegressionValue) fillNodes(); | ||
| 106 | } | 249 | } |
| 107 | 250 | ||
| 108 | void project(const Template &src, Template &dst) const | 251 | void project(const Template &src, Template &dst) const |
| 109 | { | 252 | { |
| 110 | dst = src; | 253 | dst = src; |
| 111 | 254 | ||
| 112 | - float response; | ||
| 113 | - if (classification && returnConfidence) { | ||
| 114 | - // Fuzzy class label | ||
| 115 | - response = forest.predict_prob(src.m().reshape(1,1)); | ||
| 116 | - } else { | ||
| 117 | - response = forest.predict(src.m().reshape(1,1)); | ||
| 118 | - } | 255 | + Mat responses; |
| 119 | 256 | ||
| 120 | - if (overwriteMat) { | ||
| 121 | - dst.m() = Mat(1, 1, CV_32F); | ||
| 122 | - dst.m().at<float>(0, 0) = response; | 257 | + if (useRegressionValue) { |
| 258 | + responses = Mat::zeros(forest.get_tree_count(),1,CV_32F); | ||
| 259 | + for (int i=0; i<forest.get_tree_count(); i++) { | ||
| 260 | + responses.at<float>(i,0) = forest.get_tree(i)->predict(src.m().reshape(1,1))->value; | ||
| 261 | + } | ||
| 123 | } else { | 262 | } else { |
| 124 | - dst.file.set(outputVariable, response); | 263 | + responses = Mat::zeros(totalSize,1,CV_32F); |
| 264 | + int offset = 0; | ||
| 265 | + for (int i=0; i<nodes.size(); i++) { | ||
| 266 | + int index = nodes[i].indexOf(forest.get_tree(i)->predict(src.m().reshape(1,1))); | ||
| 267 | + responses.at<float>(offset+index,0) = 1; | ||
| 268 | + offset += nodes[i].size(); | ||
| 269 | + } | ||
| 125 | } | 270 | } |
| 271 | + | ||
| 272 | + dst.m() = responses; | ||
| 126 | } | 273 | } |
| 127 | 274 | ||
| 128 | void load(QDataStream &stream) | 275 | void load(QDataStream &stream) |
| 129 | { | 276 | { |
| 130 | loadModel(forest,stream); | 277 | loadModel(forest,stream); |
| 278 | + if (!useRegressionValue) fillNodes(); | ||
| 131 | } | 279 | } |
| 132 | 280 | ||
| 133 | void store(QDataStream &stream) const | 281 | void store(QDataStream &stream) const |
| 134 | { | 282 | { |
| 135 | storeModel(forest,stream); | 283 | storeModel(forest,stream); |
| 136 | } | 284 | } |
| 137 | - | ||
| 138 | - void init() | ||
| 139 | - { | ||
| 140 | - if (outputVariable.isEmpty()) | ||
| 141 | - outputVariable = inputVariable; | ||
| 142 | - } | ||
| 143 | }; | 285 | }; |
| 144 | 286 | ||
| 145 | -BR_REGISTER(Transform, ForestTransform) | 287 | +BR_REGISTER(Transform, ForestInductionTransform) |
| 146 | 288 | ||
| 147 | /*! | 289 | /*! |
| 148 | * \ingroup transforms | 290 | * \ingroup transforms |
share/openbr/cmake/FindLibLinear.cmake
0 โ 100644
| 1 | +find_path(LibLinear_DIR linear.h ${CMAKE_SOURCE_DIR}/3rdparty/*) | ||
| 2 | + | ||
| 3 | +message(${LibLinear_DIR}) | ||
| 4 | +mark_as_advanced(LibLinear_DIR) | ||
| 5 | +include_directories(${LibLinear_DIR}) | ||
| 6 | +include_directories(${LibLinear_DIR}/blas) | ||
| 7 | + | ||
| 8 | +set(LibLinear_SRC ${LibLinear_DIR}/linear.cpp | ||
| 9 | + ${LibLinear_DIR}/tron.cpp | ||
| 10 | + ${LibLinear_DIR}/blas/daxpy.c | ||
| 11 | + ${LibLinear_DIR}/blas/ddot.c | ||
| 12 | + ${LibLinear_DIR}/blas/dnrm2.c | ||
| 13 | + ${LibLinear_DIR}/blas/dscal.c) |