Commit a3d504b5702877046da2bcb4237b5aa2280f1675
Merge branch 'master' of https://github.com/biometrics/openbr
Showing
2 changed files
with
175 additions
and
6 deletions
openbr/plugins/svm.cpp
| ... | ... | @@ -59,7 +59,7 @@ static void loadSVM(SVM &svm, QDataStream &stream) |
| 59 | 59 | svm.load(qPrintable(tempFile.fileName())); |
| 60 | 60 | } |
| 61 | 61 | |
| 62 | -static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma) | |
| 62 | +static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma, int folds, bool balanceFolds, int termCriteria) | |
| 63 | 63 | { |
| 64 | 64 | if (data.type() != CV_32FC1) |
| 65 | 65 | qFatal("Expected single channel floating point training data."); |
| ... | ... | @@ -69,9 +69,18 @@ static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, |
| 69 | 69 | params.svm_type = type; |
| 70 | 70 | params.p = 0.1; |
| 71 | 71 | params.nu = 0.5; |
| 72 | + params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON); | |
| 73 | + | |
| 72 | 74 | if ((C == -1) || ((gamma == -1) && (kernel == CvSVM::RBF))) { |
| 73 | 75 | try { |
| 74 | - svm.train_auto(data, lab, Mat(), Mat(), params, 5); | |
| 76 | + svm.train_auto(data, lab, Mat(), Mat(), params, folds, | |
| 77 | + CvSVM::get_default_grid(CvSVM::C), | |
| 78 | + CvSVM::get_default_grid(CvSVM::GAMMA), | |
| 79 | + CvSVM::get_default_grid(CvSVM::P), | |
| 80 | + CvSVM::get_default_grid(CvSVM::NU), | |
| 81 | + CvSVM::get_default_grid(CvSVM::COEF), | |
| 82 | + CvSVM::get_default_grid(CvSVM::DEGREE), | |
| 83 | + balanceFolds); | |
| 75 | 84 | } catch (...) { |
| 76 | 85 | qWarning("Some classes do not contain sufficient examples or are not discriminative enough for accurate SVM classification."); |
| 77 | 86 | svm.train(data, lab, Mat(), Mat(), params); |
| ... | ... | @@ -104,6 +113,9 @@ class SVMTransform : public Transform |
| 104 | 113 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) |
| 105 | 114 | Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) |
| 106 | 115 | Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false) |
| 116 | + Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) | |
| 117 | + Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) | |
| 118 | + Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) | |
| 107 | 119 | |
| 108 | 120 | public: |
| 109 | 121 | enum Kernel { Linear = CvSVM::LINEAR, |
| ... | ... | @@ -125,7 +137,9 @@ private: |
| 125 | 137 | BR_PROPERTY(QString, inputVariable, "Label") |
| 126 | 138 | BR_PROPERTY(QString, outputVariable, "") |
| 127 | 139 | BR_PROPERTY(bool, returnDFVal, false) |
| 128 | - | |
| 140 | + BR_PROPERTY(int, termCriteria, 1000) | |
| 141 | + BR_PROPERTY(int, folds, 5) | |
| 142 | + BR_PROPERTY(bool, balanceFolds, false) | |
| 129 | 143 | |
| 130 | 144 | SVM svm; |
| 131 | 145 | QHash<QString, int> labelMap; |
| ... | ... | @@ -146,7 +160,8 @@ private: |
| 146 | 160 | QList<int> dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup); |
| 147 | 161 | lab = OpenCVUtils::toMat(dataLabels); |
| 148 | 162 | } |
| 149 | - trainSVM(svm, data, lab, kernel, type, C, gamma); | |
| 163 | + | |
| 164 | + trainSVM(svm, data, lab, kernel, type, C, gamma, folds, balanceFolds, termCriteria); | |
| 150 | 165 | } |
| 151 | 166 | |
| 152 | 167 | void project(const Template &src, Template &dst) const |
| ... | ... | @@ -207,7 +222,9 @@ class SVMDistance : public Distance |
| 207 | 222 | Q_PROPERTY(Kernel kernel READ get_kernel WRITE set_kernel RESET reset_kernel STORED false) |
| 208 | 223 | Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false) |
| 209 | 224 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) |
| 210 | - | |
| 225 | + Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false) | |
| 226 | + Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) | |
| 227 | + Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false) | |
| 211 | 228 | |
| 212 | 229 | public: |
| 213 | 230 | enum Kernel { Linear = CvSVM::LINEAR, |
| ... | ... | @@ -225,6 +242,9 @@ private: |
| 225 | 242 | BR_PROPERTY(Kernel, kernel, Linear) |
| 226 | 243 | BR_PROPERTY(Type, type, EPS_SVR) |
| 227 | 244 | BR_PROPERTY(QString, inputVariable, "Label") |
| 245 | + BR_PROPERTY(int, termCriteria, 1000) | |
| 246 | + BR_PROPERTY(int, folds, 5) | |
| 247 | + BR_PROPERTY(bool, balanceFolds, false) | |
| 228 | 248 | |
| 229 | 249 | SVM svm; |
| 230 | 250 | |
| ... | ... | @@ -249,7 +269,7 @@ private: |
| 249 | 269 | deltaData = deltaData.rowRange(0, index); |
| 250 | 270 | deltaLab = deltaLab.rowRange(0, index); |
| 251 | 271 | |
| 252 | - trainSVM(svm, deltaData, deltaLab, kernel, type, -1, -1); | |
| 272 | + trainSVM(svm, deltaData, deltaLab, kernel, type, -1, -1, folds, balanceFolds, termCriteria); | |
| 253 | 273 | } |
| 254 | 274 | |
| 255 | 275 | float compare(const Mat &a, const Mat &b) const | ... | ... |
openbr/plugins/tree.cpp
0 → 100644
| 1 | +#include <opencv2/ml/ml.hpp> | |
| 2 | + | |
| 3 | +#include "openbr_internal.h" | |
| 4 | +#include "openbr/core/opencvutils.h" | |
| 5 | +#include <QString> | |
| 6 | +#include <QTemporaryFile> | |
| 7 | + | |
| 8 | +using namespace std; | |
| 9 | +using namespace cv; | |
| 10 | + | |
| 11 | +namespace br | |
| 12 | +{ | |
| 13 | + | |
| 14 | +static void storeForest(const CvRTrees &forest, QDataStream &stream) | |
| 15 | +{ | |
| 16 | + // Create local file | |
| 17 | + QTemporaryFile tempFile; | |
| 18 | + tempFile.open(); | |
| 19 | + tempFile.close(); | |
| 20 | + | |
| 21 | + // Save MLP to local file | |
| 22 | + forest.save(qPrintable(tempFile.fileName())); | |
| 23 | + | |
| 24 | + // Copy local file contents to stream | |
| 25 | + tempFile.open(); | |
| 26 | + QByteArray data = tempFile.readAll(); | |
| 27 | + tempFile.close(); | |
| 28 | + stream << data; | |
| 29 | +} | |
| 30 | + | |
| 31 | +static void loadForest(CvRTrees &forest, QDataStream &stream) | |
| 32 | +{ | |
| 33 | + // Copy local file contents from stream | |
| 34 | + QByteArray data; | |
| 35 | + stream >> data; | |
| 36 | + | |
| 37 | + // Create local file | |
| 38 | + QTemporaryFile tempFile(QDir::tempPath()+"/forest"); | |
| 39 | + tempFile.open(); | |
| 40 | + tempFile.write(data); | |
| 41 | + tempFile.close(); | |
| 42 | + | |
| 43 | + // Load MLP from local file | |
| 44 | + forest.load(qPrintable(tempFile.fileName())); | |
| 45 | +} | |
| 46 | + | |
| 47 | +/*! | |
| 48 | + * \ingroup transforms | |
| 49 | + * \brief Wraps OpenCV's random trees framework | |
| 50 | + * \author Scott Klum \cite sklum | |
| 51 | + * \brief http://docs.opencv.org/modules/ml/doc/random_trees.html | |
| 52 | + */ | |
| 53 | +class ForestTransform : public MetaTransform | |
| 54 | +{ | |
| 55 | + Q_OBJECT | |
| 56 | + Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED true) | |
| 57 | + Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED true) | |
| 58 | + Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED true) | |
| 59 | + Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED true) | |
| 60 | + Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED true) | |
| 61 | + Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED true) | |
| 62 | + Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED true) | |
| 63 | + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | |
| 64 | + Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) | |
| 65 | + BR_PROPERTY(bool, classification, true) | |
| 66 | + BR_PROPERTY(float, splitPercentage, .01) | |
| 67 | + BR_PROPERTY(int, maxDepth, std::numeric_limits<int>::max()) | |
| 68 | + BR_PROPERTY(int, maxTrees, 10) | |
| 69 | + BR_PROPERTY(float, forestAccuracy, .1) | |
| 70 | + BR_PROPERTY(bool, returnConfidence, true) | |
| 71 | + BR_PROPERTY(bool, overwriteMat, true) | |
| 72 | + BR_PROPERTY(QString, inputVariable, "Label") | |
| 73 | + BR_PROPERTY(QString, outputVariable, "") | |
| 74 | + | |
| 75 | + CvRTrees forest; | |
| 76 | + | |
| 77 | + void train(const TemplateList &data) | |
| 78 | + { | |
| 79 | + Mat samples = OpenCVUtils::toMat(data.data()); | |
| 80 | + Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable)); | |
| 81 | + | |
| 82 | + Mat types = Mat(samples.cols + 1, 1, CV_8U); | |
| 83 | + types.setTo(Scalar(CV_VAR_NUMERICAL)); | |
| 84 | + | |
| 85 | + if (classification) { | |
| 86 | + types.at<char>(samples.cols, 0) = CV_VAR_CATEGORICAL; | |
| 87 | + } else { | |
| 88 | + types.at<char>(samples.cols, 0) = CV_VAR_NUMERICAL; | |
| 89 | + } | |
| 90 | + | |
| 91 | + int minSamplesForSplit = data.size()*splitPercentage; | |
| 92 | + forest.train( samples, CV_ROW_SAMPLE, labels, Mat(), Mat(), types, Mat(), | |
| 93 | + CvRTParams(maxDepth, | |
| 94 | + minSamplesForSplit, | |
| 95 | + 0, | |
| 96 | + false, | |
| 97 | + 2, | |
| 98 | + 0, // priors | |
| 99 | + false, | |
| 100 | + 0, | |
| 101 | + maxTrees, | |
| 102 | + forestAccuracy, | |
| 103 | + CV_TERMCRIT_EPS)); | |
| 104 | + | |
| 105 | + qDebug() << "Number of trees:" << forest.get_tree_count(); | |
| 106 | + } | |
| 107 | + | |
| 108 | + void project(const Template &src, Template &dst) const | |
| 109 | + { | |
| 110 | + dst = src; | |
| 111 | + | |
| 112 | + float response; | |
| 113 | + if (classification && returnConfidence) { | |
| 114 | + // Fuzzy class label | |
| 115 | + response = forest.predict_prob(src.m().reshape(1,1)); | |
| 116 | + } else { | |
| 117 | + response = forest.predict(src.m().reshape(1,1)); | |
| 118 | + } | |
| 119 | + | |
| 120 | + if (overwriteMat) { | |
| 121 | + dst.m() = Mat(1, 1, CV_32F); | |
| 122 | + dst.m().at<float>(0, 0) = response; | |
| 123 | + } else { | |
| 124 | + dst.file.set(outputVariable, response); | |
| 125 | + } | |
| 126 | + } | |
| 127 | + | |
| 128 | + void load(QDataStream &stream) | |
| 129 | + { | |
| 130 | + loadForest(forest,stream); | |
| 131 | + } | |
| 132 | + | |
| 133 | + void store(QDataStream &stream) const | |
| 134 | + { | |
| 135 | + storeForest(forest,stream); | |
| 136 | + } | |
| 137 | + | |
| 138 | + void init() | |
| 139 | + { | |
| 140 | + if (outputVariable.isEmpty()) | |
| 141 | + outputVariable = inputVariable; | |
| 142 | + } | |
| 143 | +}; | |
| 144 | + | |
| 145 | +BR_REGISTER(Transform, ForestTransform) | |
| 146 | + | |
| 147 | +} // namespace br | |
| 148 | + | |
| 149 | +#include "tree.moc" | ... | ... |