Commit a3d504b5702877046da2bcb4237b5aa2280f1675
Merge branch 'master' of https://github.com/biometrics/openbr
Showing
2 changed files
with
175 additions
and
6 deletions
openbr/plugins/svm.cpp
| @@ -59,7 +59,7 @@ static void loadSVM(SVM &svm, QDataStream &stream) | @@ -59,7 +59,7 @@ static void loadSVM(SVM &svm, QDataStream &stream) | ||
| 59 | svm.load(qPrintable(tempFile.fileName())); | 59 | svm.load(qPrintable(tempFile.fileName())); |
| 60 | } | 60 | } |
| 61 | 61 | ||
| 62 | -static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma) | 62 | +static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma, int folds, bool balanceFolds, int termCriteria) |
| 63 | { | 63 | { |
| 64 | if (data.type() != CV_32FC1) | 64 | if (data.type() != CV_32FC1) |
| 65 | qFatal("Expected single channel floating point training data."); | 65 | qFatal("Expected single channel floating point training data."); |
| @@ -69,9 +69,18 @@ static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, | @@ -69,9 +69,18 @@ static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, | ||
| 69 | params.svm_type = type; | 69 | params.svm_type = type; |
| 70 | params.p = 0.1; | 70 | params.p = 0.1; |
| 71 | params.nu = 0.5; | 71 | params.nu = 0.5; |
| 72 | + params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON); | ||
| 73 | + | ||
| 72 | if ((C == -1) || ((gamma == -1) && (kernel == CvSVM::RBF))) { | 74 | if ((C == -1) || ((gamma == -1) && (kernel == CvSVM::RBF))) { |
| 73 | try { | 75 | try { |
| 74 | - svm.train_auto(data, lab, Mat(), Mat(), params, 5); | 76 | + svm.train_auto(data, lab, Mat(), Mat(), params, folds, |
| 77 | + CvSVM::get_default_grid(CvSVM::C), | ||
| 78 | + CvSVM::get_default_grid(CvSVM::GAMMA), | ||
| 79 | + CvSVM::get_default_grid(CvSVM::P), | ||
| 80 | + CvSVM::get_default_grid(CvSVM::NU), | ||
| 81 | + CvSVM::get_default_grid(CvSVM::COEF), | ||
| 82 | + CvSVM::get_default_grid(CvSVM::DEGREE), | ||
| 83 | + balanceFolds); | ||
| 75 | } catch (...) { | 84 | } catch (...) { |
| 76 | qWarning("Some classes do not contain sufficient examples or are not discriminative enough for accurate SVM classification."); | 85 | qWarning("Some classes do not contain sufficient examples or are not discriminative enough for accurate SVM classification."); |
| 77 | svm.train(data, lab, Mat(), Mat(), params); | 86 | svm.train(data, lab, Mat(), Mat(), params); |
| @@ -104,6 +113,9 @@ class SVMTransform : public Transform | @@ -104,6 +113,9 @@ class SVMTransform : public Transform | ||
| 104 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | 113 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) |
| 105 | Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) | 114 | Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) |
| 106 | Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false) | 115 | Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false) |
| 116 | + Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) | ||
| 117 | + Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) | ||
| 118 | + Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) | ||
| 107 | 119 | ||
| 108 | public: | 120 | public: |
| 109 | enum Kernel { Linear = CvSVM::LINEAR, | 121 | enum Kernel { Linear = CvSVM::LINEAR, |
| @@ -125,7 +137,9 @@ private: | @@ -125,7 +137,9 @@ private: | ||
| 125 | BR_PROPERTY(QString, inputVariable, "Label") | 137 | BR_PROPERTY(QString, inputVariable, "Label") |
| 126 | BR_PROPERTY(QString, outputVariable, "") | 138 | BR_PROPERTY(QString, outputVariable, "") |
| 127 | BR_PROPERTY(bool, returnDFVal, false) | 139 | BR_PROPERTY(bool, returnDFVal, false) |
| 128 | - | 140 | + BR_PROPERTY(int, termCriteria, 1000) |
| 141 | + BR_PROPERTY(int, folds, 5) | ||
| 142 | + BR_PROPERTY(bool, balanceFolds, false) | ||
| 129 | 143 | ||
| 130 | SVM svm; | 144 | SVM svm; |
| 131 | QHash<QString, int> labelMap; | 145 | QHash<QString, int> labelMap; |
| @@ -146,7 +160,8 @@ private: | @@ -146,7 +160,8 @@ private: | ||
| 146 | QList<int> dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup); | 160 | QList<int> dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup); |
| 147 | lab = OpenCVUtils::toMat(dataLabels); | 161 | lab = OpenCVUtils::toMat(dataLabels); |
| 148 | } | 162 | } |
| 149 | - trainSVM(svm, data, lab, kernel, type, C, gamma); | 163 | + |
| 164 | + trainSVM(svm, data, lab, kernel, type, C, gamma, folds, balanceFolds, termCriteria); | ||
| 150 | } | 165 | } |
| 151 | 166 | ||
| 152 | void project(const Template &src, Template &dst) const | 167 | void project(const Template &src, Template &dst) const |
| @@ -207,7 +222,9 @@ class SVMDistance : public Distance | @@ -207,7 +222,9 @@ class SVMDistance : public Distance | ||
| 207 | Q_PROPERTY(Kernel kernel READ get_kernel WRITE set_kernel RESET reset_kernel STORED false) | 222 | Q_PROPERTY(Kernel kernel READ get_kernel WRITE set_kernel RESET reset_kernel STORED false) |
| 208 | Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false) | 223 | Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false) |
| 209 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | 224 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) |
| 210 | - | 225 | + Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) |
| 226 | + Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) | ||
| 227 | + Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) | ||
| 211 | 228 | ||
| 212 | public: | 229 | public: |
| 213 | enum Kernel { Linear = CvSVM::LINEAR, | 230 | enum Kernel { Linear = CvSVM::LINEAR, |
| @@ -225,6 +242,9 @@ private: | @@ -225,6 +242,9 @@ private: | ||
| 225 | BR_PROPERTY(Kernel, kernel, Linear) | 242 | BR_PROPERTY(Kernel, kernel, Linear) |
| 226 | BR_PROPERTY(Type, type, EPS_SVR) | 243 | BR_PROPERTY(Type, type, EPS_SVR) |
| 227 | BR_PROPERTY(QString, inputVariable, "Label") | 244 | BR_PROPERTY(QString, inputVariable, "Label") |
| 245 | + BR_PROPERTY(int, termCriteria, 1000) | ||
| 246 | + BR_PROPERTY(int, folds, 5) | ||
| 247 | + BR_PROPERTY(bool, balanceFolds, false) | ||
| 228 | 248 | ||
| 229 | SVM svm; | 249 | SVM svm; |
| 230 | 250 | ||
| @@ -249,7 +269,7 @@ private: | @@ -249,7 +269,7 @@ private: | ||
| 249 | deltaData = deltaData.rowRange(0, index); | 269 | deltaData = deltaData.rowRange(0, index); |
| 250 | deltaLab = deltaLab.rowRange(0, index); | 270 | deltaLab = deltaLab.rowRange(0, index); |
| 251 | 271 | ||
| 252 | - trainSVM(svm, deltaData, deltaLab, kernel, type, -1, -1); | 272 | + trainSVM(svm, deltaData, deltaLab, kernel, type, -1, -1, folds, balanceFolds, termCriteria); |
| 253 | } | 273 | } |
| 254 | 274 | ||
| 255 | float compare(const Mat &a, const Mat &b) const | 275 | float compare(const Mat &a, const Mat &b) const |
openbr/plugins/tree.cpp
0 โ 100644
| 1 | +#include <opencv2/ml/ml.hpp> | ||
| 2 | + | ||
| 3 | +#include "openbr_internal.h" | ||
| 4 | +#include "openbr/core/opencvutils.h" | ||
| 5 | +#include <QString> | ||
| 6 | +#include <QTemporaryFile> | ||
| 7 | + | ||
| 8 | +using namespace std; | ||
| 9 | +using namespace cv; | ||
| 10 | + | ||
| 11 | +namespace br | ||
| 12 | +{ | ||
| 13 | + | ||
| 14 | +static void storeForest(const CvRTrees &forest, QDataStream &stream) | ||
| 15 | +{ | ||
| 16 | + // Create local file | ||
| 17 | + QTemporaryFile tempFile; | ||
| 18 | + tempFile.open(); | ||
| 19 | + tempFile.close(); | ||
| 20 | + | ||
| 21 | + // Save MLP to local file | ||
| 22 | + forest.save(qPrintable(tempFile.fileName())); | ||
| 23 | + | ||
| 24 | + // Copy local file contents to stream | ||
| 25 | + tempFile.open(); | ||
| 26 | + QByteArray data = tempFile.readAll(); | ||
| 27 | + tempFile.close(); | ||
| 28 | + stream << data; | ||
| 29 | +} | ||
| 30 | + | ||
| 31 | +static void loadForest(CvRTrees &forest, QDataStream &stream) | ||
| 32 | +{ | ||
| 33 | + // Copy local file contents from stream | ||
| 34 | + QByteArray data; | ||
| 35 | + stream >> data; | ||
| 36 | + | ||
| 37 | + // Create local file | ||
| 38 | + QTemporaryFile tempFile(QDir::tempPath()+"/forest"); | ||
| 39 | + tempFile.open(); | ||
| 40 | + tempFile.write(data); | ||
| 41 | + tempFile.close(); | ||
| 42 | + | ||
| 43 | + // Load MLP from local file | ||
| 44 | + forest.load(qPrintable(tempFile.fileName())); | ||
| 45 | +} | ||
| 46 | + | ||
| 47 | +/*! | ||
| 48 | + * \ingroup transforms | ||
| 49 | + * \brief Wraps OpenCV's random trees framework | ||
| 50 | + * \author Scott Klum \cite sklum | ||
| 51 | + * \brief http://docs.opencv.org/modules/ml/doc/random_trees.html | ||
| 52 | + */ | ||
| 53 | +class ForestTransform : public MetaTransform | ||
| 54 | +{ | ||
| 55 | + Q_OBJECT | ||
| 56 | + Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED true) | ||
| 57 | + Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED true) | ||
| 58 | + Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED true) | ||
| 59 | + Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED true) | ||
| 60 | + Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED true) | ||
| 61 | + Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED true) | ||
| 62 | + Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED true) | ||
| 63 | + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | ||
| 64 | + Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) | ||
| 65 | + BR_PROPERTY(bool, classification, true) | ||
| 66 | + BR_PROPERTY(float, splitPercentage, .01) | ||
| 67 | + BR_PROPERTY(int, maxDepth, std::numeric_limits<int>::max()) | ||
| 68 | + BR_PROPERTY(int, maxTrees, 10) | ||
| 69 | + BR_PROPERTY(float, forestAccuracy, .1) | ||
| 70 | + BR_PROPERTY(bool, returnConfidence, true) | ||
| 71 | + BR_PROPERTY(bool, overwriteMat, true) | ||
| 72 | + BR_PROPERTY(QString, inputVariable, "Label") | ||
| 73 | + BR_PROPERTY(QString, outputVariable, "") | ||
| 74 | + | ||
| 75 | + CvRTrees forest; | ||
| 76 | + | ||
| 77 | + void train(const TemplateList &data) | ||
| 78 | + { | ||
| 79 | + Mat samples = OpenCVUtils::toMat(data.data()); | ||
| 80 | + Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable)); | ||
| 81 | + | ||
| 82 | + Mat types = Mat(samples.cols + 1, 1, CV_8U); | ||
| 83 | + types.setTo(Scalar(CV_VAR_NUMERICAL)); | ||
| 84 | + | ||
| 85 | + if (classification) { | ||
| 86 | + types.at<char>(samples.cols, 0) = CV_VAR_CATEGORICAL; | ||
| 87 | + } else { | ||
| 88 | + types.at<char>(samples.cols, 0) = CV_VAR_NUMERICAL; | ||
| 89 | + } | ||
| 90 | + | ||
| 91 | + int minSamplesForSplit = data.size()*splitPercentage; | ||
| 92 | + forest.train( samples, CV_ROW_SAMPLE, labels, Mat(), Mat(), types, Mat(), | ||
| 93 | + CvRTParams(maxDepth, | ||
| 94 | + minSamplesForSplit, | ||
| 95 | + 0, | ||
| 96 | + false, | ||
| 97 | + 2, | ||
| 98 | + 0, // priors | ||
| 99 | + false, | ||
| 100 | + 0, | ||
| 101 | + maxTrees, | ||
| 102 | + forestAccuracy, | ||
| 103 | + CV_TERMCRIT_EPS)); | ||
| 104 | + | ||
| 105 | + qDebug() << "Number of trees:" << forest.get_tree_count(); | ||
| 106 | + } | ||
| 107 | + | ||
| 108 | + void project(const Template &src, Template &dst) const | ||
| 109 | + { | ||
| 110 | + dst = src; | ||
| 111 | + | ||
| 112 | + float response; | ||
| 113 | + if (classification && returnConfidence) { | ||
| 114 | + // Fuzzy class label | ||
| 115 | + response = forest.predict_prob(src.m().reshape(1,1)); | ||
| 116 | + } else { | ||
| 117 | + response = forest.predict(src.m().reshape(1,1)); | ||
| 118 | + } | ||
| 119 | + | ||
| 120 | + if (overwriteMat) { | ||
| 121 | + dst.m() = Mat(1, 1, CV_32F); | ||
| 122 | + dst.m().at<float>(0, 0) = response; | ||
| 123 | + } else { | ||
| 124 | + dst.file.set(outputVariable, response); | ||
| 125 | + } | ||
| 126 | + } | ||
| 127 | + | ||
| 128 | + void load(QDataStream &stream) | ||
| 129 | + { | ||
| 130 | + loadForest(forest,stream); | ||
| 131 | + } | ||
| 132 | + | ||
| 133 | + void store(QDataStream &stream) const | ||
| 134 | + { | ||
| 135 | + storeForest(forest,stream); | ||
| 136 | + } | ||
| 137 | + | ||
| 138 | + void init() | ||
| 139 | + { | ||
| 140 | + if (outputVariable.isEmpty()) | ||
| 141 | + outputVariable = inputVariable; | ||
| 142 | + } | ||
| 143 | +}; | ||
| 144 | + | ||
| 145 | +BR_REGISTER(Transform, ForestTransform) | ||
| 146 | + | ||
| 147 | +} // namespace br | ||
| 148 | + | ||
| 149 | +#include "tree.moc" |