Commit 9e32d3b1e65021d0d55ae55ff15ab5f44da5875e
Merge remote-tracking branch 'origin/master' into opencv_model_storage
Conflicts: openbr/plugins/tree.cpp
Showing
4 changed files
with
391 additions
and
22 deletions
openbr/plugins/liblinear.cmake
0 โ 100644
openbr/plugins/liblinear.cpp
0 โ 100644
| 1 | +#include <QTemporaryFile> | ||
| 2 | +#include <opencv2/core/core.hpp> | ||
| 3 | +#include <opencv2/ml/ml.hpp> | ||
| 4 | + | ||
| 5 | +#include "openbr_internal.h" | ||
| 6 | +#include "openbr/core/opencvutils.h" | ||
| 7 | + | ||
| 8 | +#include <linear.h> | ||
| 9 | + | ||
| 10 | +using namespace cv; | ||
| 11 | + | ||
| 12 | +namespace br | ||
| 13 | +{ | ||
| 14 | + | ||
| 15 | +static void storeModel(const model &m, QDataStream &stream) | ||
| 16 | +{ | ||
| 17 | + // Create local file | ||
| 18 | + QTemporaryFile tempFile; | ||
| 19 | + tempFile.open(); | ||
| 20 | + tempFile.close(); | ||
| 21 | + | ||
| 22 | + // Save MLP to local file | ||
| 23 | + save_model(qPrintable(tempFile.fileName()),&m); | ||
| 24 | + | ||
| 25 | + // Copy local file contents to stream | ||
| 26 | + tempFile.open(); | ||
| 27 | + QByteArray data = tempFile.readAll(); | ||
| 28 | + tempFile.close(); | ||
| 29 | + stream << data; | ||
| 30 | +} | ||
| 31 | + | ||
| 32 | +static void loadModel(model &m, QDataStream &stream) | ||
| 33 | +{ | ||
| 34 | + // Copy local file contents from stream | ||
| 35 | + QByteArray data; | ||
| 36 | + stream >> data; | ||
| 37 | + | ||
| 38 | + // Create local file | ||
| 39 | + QTemporaryFile tempFile(QDir::tempPath()+"/model"); | ||
| 40 | + tempFile.open(); | ||
| 41 | + tempFile.write(data); | ||
| 42 | + tempFile.close(); | ||
| 43 | + | ||
| 44 | + // Load MLP from local file | ||
| 45 | + m = *load_model(qPrintable(tempFile.fileName())); | ||
| 46 | +} | ||
| 47 | + | ||
| 48 | +class Linear : public Transform | ||
| 49 | +{ | ||
| 50 | + Q_OBJECT | ||
| 51 | + Q_ENUMS(Solver) | ||
| 52 | + Q_PROPERTY(Solver solver READ get_solver WRITE set_solver RESET reset_solver STORED false) | ||
| 53 | + Q_PROPERTY(float C READ get_C WRITE set_C RESET reset_C STORED false) | ||
| 54 | + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | ||
| 55 | + Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) | ||
| 56 | + Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false) | ||
| 57 | + Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false) | ||
| 58 | + Q_PROPERTY(bool weight READ get_weight WRITE set_weight RESET reset_weight STORED false) | ||
| 59 | + | ||
| 60 | +public: | ||
| 61 | + enum Solver { L2R_LR = ::L2R_LR, | ||
| 62 | + L2R_L2LOSS_SVC_DUAL = ::L2R_L2LOSS_SVC_DUAL, | ||
| 63 | + L2R_L2LOSS_SVC = ::L2R_L2LOSS_SVC, | ||
| 64 | + L2R_L1LOSS_SVC_DUAL = ::L2R_L1LOSS_SVC_DUAL, | ||
| 65 | + MCSVM_CS = ::MCSVM_CS, | ||
| 66 | + L1R_L2LOSS_SVC = ::L1R_L2LOSS_SVC, | ||
| 67 | + L1R_LR = ::L1R_LR, | ||
| 68 | + L2R_LR_DUAL = ::L2R_LR_DUAL, | ||
| 69 | + L2R_L2LOSS_SVR = ::L2R_L2LOSS_SVR, | ||
| 70 | + L2R_L2LOSS_SVR_DUAL = ::L2R_L2LOSS_SVR_DUAL, | ||
| 71 | + L2R_L1LOSS_SVR_DUAL = ::L2R_L1LOSS_SVR_DUAL }; | ||
| 72 | + | ||
| 73 | +private: | ||
| 74 | + BR_PROPERTY(Solver, solver, L2R_L2LOSS_SVC_DUAL) | ||
| 75 | + BR_PROPERTY(float, C, 1) | ||
| 76 | + BR_PROPERTY(QString, inputVariable, "Label") | ||
| 77 | + BR_PROPERTY(QString, outputVariable, "") | ||
| 78 | + BR_PROPERTY(bool, returnDFVal, false) | ||
| 79 | + BR_PROPERTY(bool, overwriteMat, true) | ||
| 80 | + BR_PROPERTY(bool, weight, false) | ||
| 81 | + | ||
| 82 | + model m; | ||
| 83 | + | ||
| 84 | + void train(const TemplateList &data) | ||
| 85 | + { | ||
| 86 | + Mat samples = OpenCVUtils::toMat(data.data()); | ||
| 87 | + Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable)); | ||
| 88 | + | ||
| 89 | + problem prob; | ||
| 90 | + prob.n = samples.cols; | ||
| 91 | + prob.l = samples.rows; | ||
| 92 | + prob.bias = -1; | ||
| 93 | + prob.y = new double[prob.l]; | ||
| 94 | + | ||
| 95 | + for (int i=0; i<prob.l; i++) | ||
| 96 | + prob.y[i] = labels.at<float>(i,0); | ||
| 97 | + | ||
| 98 | + // Allocate enough memory for l feature_nodes pointers | ||
| 99 | + prob.x = new feature_node*[prob.l]; | ||
| 100 | + feature_node *x_space = new feature_node[(prob.n+1)*prob.l]; | ||
| 101 | + | ||
| 102 | + int k = 0; | ||
| 103 | + for (int i=0; i<prob.l; i++) { | ||
| 104 | + prob.x[i] = &x_space[k]; | ||
| 105 | + for (int j=0; j<prob.n; j++) { | ||
| 106 | + x_space[k].index = j+1; | ||
| 107 | + x_space[k].value = samples.at<float>(i,j); | ||
| 108 | + k++; | ||
| 109 | + } | ||
| 110 | + x_space[k++].index = -1; | ||
| 111 | + } | ||
| 112 | + | ||
| 113 | + parameter param; | ||
| 114 | + | ||
| 115 | + // TODO: Support grid search | ||
| 116 | + param.C = C; | ||
| 117 | + param.p = 1; | ||
| 118 | + param.eps = FLT_EPSILON; | ||
| 119 | + param.solver_type = solver; | ||
| 120 | + | ||
| 121 | + if (weight) { | ||
| 122 | + param.nr_weight = 2; | ||
| 123 | + param.weight_label = new int[2]; | ||
| 124 | + param.weight = new double[2]; | ||
| 125 | + param.weight_label[0] = 0; | ||
| 126 | + param.weight_label[1] = 1; | ||
| 127 | + int nonZero = countNonZero(labels); | ||
| 128 | + param.weight[0] = 1; | ||
| 129 | + param.weight[1] = (double)(prob.l-nonZero)/nonZero; | ||
| 130 | + qDebug() << param.weight[0] << param.weight[1]; | ||
| 131 | + } else { | ||
| 132 | + param.nr_weight = 0; | ||
| 133 | + param.weight_label = NULL; | ||
| 134 | + param.weight = NULL; | ||
| 135 | + } | ||
| 136 | + | ||
| 137 | + m = *train_svm(&prob, ¶m); | ||
| 138 | + | ||
| 139 | + delete[] param.weight; | ||
| 140 | + delete[] param.weight_label; | ||
| 141 | + delete[] prob.y; | ||
| 142 | + delete[] prob.x; | ||
| 143 | + delete[] x_space; | ||
| 144 | + } | ||
| 145 | + | ||
| 146 | + void project(const Template &src, Template &dst) const | ||
| 147 | + { | ||
| 148 | + dst = src; | ||
| 149 | + | ||
| 150 | + Mat sample = src.m().reshape(1,1); | ||
| 151 | + feature_node *x_space = new feature_node[sample.cols+1]; | ||
| 152 | + | ||
| 153 | + for (int j=0; j<sample.cols; j++) { | ||
| 154 | + x_space[j].index = j+1; | ||
| 155 | + x_space[j].value = sample.at<float>(0,j); | ||
| 156 | + } | ||
| 157 | + x_space[sample.cols].index = -1; | ||
| 158 | + | ||
| 159 | + float prediction; | ||
| 160 | + double prob_estimates[m.nr_class]; | ||
| 161 | + | ||
| 162 | + if (solver == L2R_L2LOSS_SVR || | ||
| 163 | + solver == L2R_L1LOSS_SVR_DUAL || | ||
| 164 | + solver == L2R_L2LOSS_SVR_DUAL || | ||
| 165 | + solver == L2R_L2LOSS_SVC_DUAL || | ||
| 166 | + solver == L2R_L2LOSS_SVC || | ||
| 167 | + solver == L2R_L1LOSS_SVC_DUAL || | ||
| 168 | + solver == MCSVM_CS || | ||
| 169 | + solver == L1R_L2LOSS_SVC) | ||
| 170 | + { | ||
| 171 | + prediction = predict_values(&m,x_space,prob_estimates); | ||
| 172 | + if (returnDFVal) prediction = prob_estimates[0]; | ||
| 173 | + } else if (solver == L2R_LR || | ||
| 174 | + solver == L2R_LR_DUAL || | ||
| 175 | + solver == L1R_LR) | ||
| 176 | + { | ||
| 177 | + prediction = predict_probability(&m,x_space,prob_estimates); | ||
| 178 | + if (returnDFVal) prediction = prob_estimates[0]; | ||
| 179 | + } | ||
| 180 | + | ||
| 181 | + if (overwriteMat) { | ||
| 182 | + dst.m() = Mat(1, 1, CV_32F); | ||
| 183 | + dst.m().at<float>(0, 0) = prediction; | ||
| 184 | + } else { | ||
| 185 | + dst.file.set(outputVariable,prediction); | ||
| 186 | + } | ||
| 187 | + | ||
| 188 | + delete[] x_space; | ||
| 189 | + } | ||
| 190 | + | ||
| 191 | + void store(QDataStream &stream) const | ||
| 192 | + { | ||
| 193 | + storeModel(m,stream); | ||
| 194 | + } | ||
| 195 | + | ||
| 196 | + void load(QDataStream &stream) | ||
| 197 | + { | ||
| 198 | + loadModel(m,stream); | ||
| 199 | + } | ||
| 200 | +}; | ||
| 201 | + | ||
| 202 | +BR_REGISTER(Transform, Linear) | ||
| 203 | + | ||
| 204 | +} // namespace br | ||
| 205 | + | ||
| 206 | +#include "liblinear.moc" |
openbr/plugins/tree.cpp
| @@ -16,6 +16,50 @@ namespace br | @@ -16,6 +16,50 @@ namespace br | ||
| 16 | class ForestTransform : public Transform | 16 | class ForestTransform : public Transform |
| 17 | { | 17 | { |
| 18 | Q_OBJECT | 18 | Q_OBJECT |
| 19 | + | ||
| 20 | + void train(const TemplateList &data) | ||
| 21 | + { | ||
| 22 | + trainForest(data); | ||
| 23 | + } | ||
| 24 | + | ||
| 25 | + void project(const Template &src, Template &dst) const | ||
| 26 | + { | ||
| 27 | + dst = src; | ||
| 28 | + | ||
| 29 | + float response; | ||
| 30 | + if (classification && returnConfidence) { | ||
| 31 | + // Fuzzy class label | ||
| 32 | + response = forest.predict_prob(src.m().reshape(1,1)); | ||
| 33 | + } else { | ||
| 34 | + response = forest.predict(src.m().reshape(1,1)); | ||
| 35 | + } | ||
| 36 | + | ||
| 37 | + if (overwriteMat) { | ||
| 38 | + dst.m() = Mat(1, 1, CV_32F); | ||
| 39 | + dst.m().at<float>(0, 0) = response; | ||
| 40 | + } else { | ||
| 41 | + dst.file.set(outputVariable, response); | ||
| 42 | + } | ||
| 43 | + } | ||
| 44 | + | ||
| 45 | + void load(QDataStream &stream) | ||
| 46 | + { | ||
| 47 | + OpenCVUtils::loadModel(forest,stream); | ||
| 48 | + } | ||
| 49 | + | ||
| 50 | + void store(QDataStream &stream) const | ||
| 51 | + { | ||
| 52 | + OpenCVUtils::storeModel(forest,stream); | ||
| 53 | + } | ||
| 54 | + | ||
| 55 | + void init() | ||
| 56 | + { | ||
| 57 | + if (outputVariable.isEmpty()) | ||
| 58 | + outputVariable = inputVariable; | ||
| 59 | + } | ||
| 60 | + | ||
| 61 | +protected: | ||
| 62 | + Q_ENUMS(TerminationCriteria) | ||
| 19 | Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED false) | 63 | Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED false) |
| 20 | Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED false) | 64 | Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED false) |
| 21 | Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED false) | 65 | Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED false) |
| @@ -25,6 +69,15 @@ class ForestTransform : public Transform | @@ -25,6 +69,15 @@ class ForestTransform : public Transform | ||
| 25 | Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false) | 69 | Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false) |
| 26 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | 70 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) |
| 27 | Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) | 71 | Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) |
| 72 | + Q_PROPERTY(bool weight READ get_weight WRITE set_weight RESET reset_weight STORED false) | ||
| 73 | + Q_PROPERTY(TerminationCriteria termCrit READ get_termCrit WRITE set_termCrit RESET reset_termCrit STORED false) | ||
| 74 | + | ||
| 75 | +public: | ||
| 76 | + enum TerminationCriteria { Iter = CV_TERMCRIT_ITER, | ||
| 77 | + EPS = CV_TERMCRIT_EPS, | ||
| 78 | + Both = CV_TERMCRIT_EPS | CV_TERMCRIT_ITER}; | ||
| 79 | + | ||
| 80 | +protected: | ||
| 28 | BR_PROPERTY(bool, classification, true) | 81 | BR_PROPERTY(bool, classification, true) |
| 29 | BR_PROPERTY(float, splitPercentage, .01) | 82 | BR_PROPERTY(float, splitPercentage, .01) |
| 30 | BR_PROPERTY(int, maxDepth, std::numeric_limits<int>::max()) | 83 | BR_PROPERTY(int, maxDepth, std::numeric_limits<int>::max()) |
| @@ -34,10 +87,12 @@ class ForestTransform : public Transform | @@ -34,10 +87,12 @@ class ForestTransform : public Transform | ||
| 34 | BR_PROPERTY(bool, overwriteMat, true) | 87 | BR_PROPERTY(bool, overwriteMat, true) |
| 35 | BR_PROPERTY(QString, inputVariable, "Label") | 88 | BR_PROPERTY(QString, inputVariable, "Label") |
| 36 | BR_PROPERTY(QString, outputVariable, "") | 89 | BR_PROPERTY(QString, outputVariable, "") |
| 90 | + BR_PROPERTY(bool, weight, false) | ||
| 91 | + BR_PROPERTY(TerminationCriteria, termCrit, Iter) | ||
| 37 | 92 | ||
| 38 | CvRTrees forest; | 93 | CvRTrees forest; |
| 39 | 94 | ||
| 40 | - void train(const TemplateList &data) | 95 | + void trainForest(const TemplateList &data) |
| 41 | { | 96 | { |
| 42 | Mat samples = OpenCVUtils::toMat(data.data()); | 97 | Mat samples = OpenCVUtils::toMat(data.data()); |
| 43 | Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable)); | 98 | Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable)); |
| @@ -51,6 +106,14 @@ class ForestTransform : public Transform | @@ -51,6 +106,14 @@ class ForestTransform : public Transform | ||
| 51 | types.at<char>(samples.cols, 0) = CV_VAR_NUMERICAL; | 106 | types.at<char>(samples.cols, 0) = CV_VAR_NUMERICAL; |
| 52 | } | 107 | } |
| 53 | 108 | ||
| 109 | + bool usePrior = classification && weight; | ||
| 110 | + float priors[2]; | ||
| 111 | + if (usePrior) { | ||
| 112 | + int nonZero = countNonZero(labels); | ||
| 113 | + priors[0] = 1; | ||
| 114 | + priors[1] = (float)(samples.rows-nonZero)/nonZero; | ||
| 115 | + } | ||
| 116 | + | ||
| 54 | int minSamplesForSplit = data.size()*splitPercentage; | 117 | int minSamplesForSplit = data.size()*splitPercentage; |
| 55 | forest.train( samples, CV_ROW_SAMPLE, labels, Mat(), Mat(), types, Mat(), | 118 | forest.train( samples, CV_ROW_SAMPLE, labels, Mat(), Mat(), types, Mat(), |
| 56 | CvRTParams(maxDepth, | 119 | CvRTParams(maxDepth, |
| @@ -58,54 +121,134 @@ class ForestTransform : public Transform | @@ -58,54 +121,134 @@ class ForestTransform : public Transform | ||
| 58 | 0, | 121 | 0, |
| 59 | false, | 122 | false, |
| 60 | 2, | 123 | 2, |
| 61 | - 0, // priors | 124 | + usePrior ? priors : 0, |
| 62 | false, | 125 | false, |
| 63 | 0, | 126 | 0, |
| 64 | maxTrees, | 127 | maxTrees, |
| 65 | forestAccuracy, | 128 | forestAccuracy, |
| 66 | - CV_TERMCRIT_ITER | CV_TERMCRIT_EPS)); | 129 | + termCrit)); |
| 130 | + | ||
| 131 | + if (Globals->verbose) { | ||
| 132 | + qDebug() << "Number of trees:" << forest.get_tree_count(); | ||
| 133 | + | ||
| 134 | + if (classification) { | ||
| 135 | + QTime timer; | ||
| 136 | + timer.start(); | ||
| 137 | + int correctClassification = 0; | ||
| 138 | + float regressionError = 0; | ||
| 139 | + for (int i=0; i<samples.rows; i++) { | ||
| 140 | + float prediction = forest.predict_prob(samples.row(i)); | ||
| 141 | + int label = forest.predict(samples.row(i)); | ||
| 142 | + if (label == labels.at<float>(i,0)) { | ||
| 143 | + correctClassification++; | ||
| 144 | + } | ||
| 145 | + regressionError += fabs(prediction-labels.at<float>(i,0)); | ||
| 146 | + } | ||
| 147 | + | ||
| 148 | + qDebug("Time to classify %d samples: %d ms\n \ | ||
| 149 | + Classification Accuracy: %f\n \ | ||
| 150 | + MAE: %f\n \ | ||
| 151 | + Sample dimensionality: %d", | ||
| 152 | + samples.rows,timer.elapsed(),(float)correctClassification/samples.rows,regressionError/samples.rows,samples.cols); | ||
| 153 | + } | ||
| 154 | + } | ||
| 155 | + } | ||
| 156 | +}; | ||
| 157 | + | ||
| 158 | +BR_REGISTER(Transform, ForestTransform) | ||
| 159 | + | ||
| 160 | +/*! | ||
| 161 | + * \ingroup transforms | ||
| 162 | + * \brief Wraps OpenCV's random trees framework to induce features | ||
| 163 | + * \author Scott Klum \cite sklum | ||
| 164 | + * \brief https://lirias.kuleuven.be/bitstream/123456789/316661/1/icdm11-camready.pdf | ||
| 165 | + */ | ||
| 166 | +class ForestInductionTransform : public ForestTransform | ||
| 167 | +{ | ||
| 168 | + Q_OBJECT | ||
| 169 | + Q_PROPERTY(bool useRegressionValue READ get_useRegressionValue WRITE set_useRegressionValue RESET reset_useRegressionValue STORED false) | ||
| 170 | + BR_PROPERTY(bool, useRegressionValue, false) | ||
| 171 | + | ||
| 172 | + int totalSize; | ||
| 173 | + QList< QList<const CvDTreeNode*> > nodes; | ||
| 67 | 174 | ||
| 68 | - qDebug() << "Number of trees:" << forest.get_tree_count(); | 175 | + void fillNodes() |
| 176 | + { | ||
| 177 | + for (int i=0; i<forest.get_tree_count(); i++) { | ||
| 178 | + nodes.append(QList<const CvDTreeNode*>()); | ||
| 179 | + const CvDTreeNode* node = forest.get_tree(i)->get_root(); | ||
| 180 | + | ||
| 181 | + // traverse the tree and save all the nodes in depth-first order | ||
| 182 | + for(;;) | ||
| 183 | + { | ||
| 184 | + CvDTreeNode* parent; | ||
| 185 | + for(;;) | ||
| 186 | + { | ||
| 187 | + if( !node->left ) | ||
| 188 | + break; | ||
| 189 | + node = node->left; | ||
| 190 | + } | ||
| 191 | + | ||
| 192 | + nodes.last().append(node); | ||
| 193 | + | ||
| 194 | + for( parent = node->parent; parent && parent->right == node; | ||
| 195 | + node = parent, parent = parent->parent ) | ||
| 196 | + ; | ||
| 197 | + | ||
| 198 | + if( !parent ) | ||
| 199 | + break; | ||
| 200 | + | ||
| 201 | + node = parent->right; | ||
| 202 | + } | ||
| 203 | + | ||
| 204 | + totalSize += nodes.last().size(); | ||
| 205 | + } | ||
| 206 | + } | ||
| 207 | + | ||
| 208 | + void train(const TemplateList &data) | ||
| 209 | + { | ||
| 210 | + trainForest(data); | ||
| 211 | + if (!useRegressionValue) fillNodes(); | ||
| 69 | } | 212 | } |
| 70 | 213 | ||
| 71 | void project(const Template &src, Template &dst) const | 214 | void project(const Template &src, Template &dst) const |
| 72 | { | 215 | { |
| 73 | dst = src; | 216 | dst = src; |
| 74 | 217 | ||
| 75 | - float response; | ||
| 76 | - if (classification && returnConfidence) { | ||
| 77 | - // Fuzzy class label | ||
| 78 | - response = forest.predict_prob(src.m().reshape(1,1)); | ||
| 79 | - } else { | ||
| 80 | - response = forest.predict(src.m().reshape(1,1)); | ||
| 81 | - } | 218 | + Mat responses; |
| 82 | 219 | ||
| 83 | - if (overwriteMat) { | ||
| 84 | - dst.m() = Mat(1, 1, CV_32F); | ||
| 85 | - dst.m().at<float>(0, 0) = response; | 220 | + if (useRegressionValue) { |
| 221 | + responses = Mat::zeros(forest.get_tree_count(),1,CV_32F); | ||
| 222 | + for (int i=0; i<forest.get_tree_count(); i++) { | ||
| 223 | + responses.at<float>(i,0) = forest.get_tree(i)->predict(src.m().reshape(1,1))->value; | ||
| 224 | + } | ||
| 86 | } else { | 225 | } else { |
| 87 | - dst.file.set(outputVariable, response); | 226 | + responses = Mat::zeros(totalSize,1,CV_32F); |
| 227 | + int offset = 0; | ||
| 228 | + for (int i=0; i<nodes.size(); i++) { | ||
| 229 | + int index = nodes[i].indexOf(forest.get_tree(i)->predict(src.m().reshape(1,1))); | ||
| 230 | + responses.at<float>(offset+index,0) = 1; | ||
| 231 | + offset += nodes[i].size(); | ||
| 232 | + } | ||
| 88 | } | 233 | } |
| 234 | + | ||
| 235 | + dst.m() = responses; | ||
| 89 | } | 236 | } |
| 90 | 237 | ||
| 91 | void load(QDataStream &stream) | 238 | void load(QDataStream &stream) |
| 92 | { | 239 | { |
| 93 | OpenCVUtils::loadModel(forest,stream); | 240 | OpenCVUtils::loadModel(forest,stream); |
| 241 | + if (!useRegressionValue) fillNodes(); | ||
| 242 | + | ||
| 94 | } | 243 | } |
| 95 | 244 | ||
| 96 | void store(QDataStream &stream) const | 245 | void store(QDataStream &stream) const |
| 97 | { | 246 | { |
| 98 | OpenCVUtils::storeModel(forest,stream); | 247 | OpenCVUtils::storeModel(forest,stream); |
| 99 | } | 248 | } |
| 100 | - | ||
| 101 | - void init() | ||
| 102 | - { | ||
| 103 | - if (outputVariable.isEmpty()) | ||
| 104 | - outputVariable = inputVariable; | ||
| 105 | - } | ||
| 106 | }; | 249 | }; |
| 107 | 250 | ||
| 108 | -BR_REGISTER(Transform, ForestTransform) | 251 | +BR_REGISTER(Transform, ForestInductionTransform) |
| 109 | 252 | ||
| 110 | /*! | 253 | /*! |
| 111 | * \ingroup transforms | 254 | * \ingroup transforms |
share/openbr/cmake/FindLibLinear.cmake
0 โ 100644
| 1 | +find_path(LibLinear_DIR linear.h ${CMAKE_SOURCE_DIR}/3rdparty/*) | ||
| 2 | + | ||
| 3 | +message(${LibLinear_DIR}) | ||
| 4 | +mark_as_advanced(LibLinear_DIR) | ||
| 5 | +include_directories(${LibLinear_DIR}) | ||
| 6 | +include_directories(${LibLinear_DIR}/blas) | ||
| 7 | + | ||
| 8 | +set(LibLinear_SRC ${LibLinear_DIR}/linear.cpp | ||
| 9 | + ${LibLinear_DIR}/tron.cpp | ||
| 10 | + ${LibLinear_DIR}/blas/daxpy.c | ||
| 11 | + ${LibLinear_DIR}/blas/ddot.c | ||
| 12 | + ${LibLinear_DIR}/blas/dnrm2.c | ||
| 13 | + ${LibLinear_DIR}/blas/dscal.c) |