Commit a3d504b5702877046da2bcb4237b5aa2280f1675

Authored by Josh Klontz
2 parents c08b5850 b0b6e5cd

Merge branch 'master' of https://github.com/biometrics/openbr

openbr/plugins/svm.cpp
... ... @@ -59,7 +59,7 @@ static void loadSVM(SVM &svm, QDataStream &stream)
59 59 svm.load(qPrintable(tempFile.fileName()));
60 60 }
61 61  
62   -static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma)
  62 +static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma, int folds, bool balanceFolds, int termCriteria)
63 63 {
64 64 if (data.type() != CV_32FC1)
65 65 qFatal("Expected single channel floating point training data.");
... ... @@ -69,9 +69,18 @@ static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C,
69 69 params.svm_type = type;
70 70 params.p = 0.1;
71 71 params.nu = 0.5;
  72 + params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON);
  73 +
72 74 if ((C == -1) || ((gamma == -1) && (kernel == CvSVM::RBF))) {
73 75 try {
74   - svm.train_auto(data, lab, Mat(), Mat(), params, 5);
  76 + svm.train_auto(data, lab, Mat(), Mat(), params, folds,
  77 + CvSVM::get_default_grid(CvSVM::C),
  78 + CvSVM::get_default_grid(CvSVM::GAMMA),
  79 + CvSVM::get_default_grid(CvSVM::P),
  80 + CvSVM::get_default_grid(CvSVM::NU),
  81 + CvSVM::get_default_grid(CvSVM::COEF),
  82 + CvSVM::get_default_grid(CvSVM::DEGREE),
  83 + balanceFolds);
75 84 } catch (...) {
76 85 qWarning("Some classes do not contain sufficient examples or are not discriminative enough for accurate SVM classification.");
77 86 svm.train(data, lab, Mat(), Mat(), params);
... ... @@ -104,6 +113,9 @@ class SVMTransform : public Transform
104 113 Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
105 114 Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)
106 115 Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false)
  116 + Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false)
  117 + Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false)
  118 + Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false)
107 119  
108 120 public:
109 121 enum Kernel { Linear = CvSVM::LINEAR,
... ... @@ -125,7 +137,9 @@ private:
125 137 BR_PROPERTY(QString, inputVariable, "Label")
126 138 BR_PROPERTY(QString, outputVariable, "")
127 139 BR_PROPERTY(bool, returnDFVal, false)
128   -
  140 + BR_PROPERTY(int, termCriteria, 1000)
  141 + BR_PROPERTY(int, folds, 5)
  142 + BR_PROPERTY(bool, balanceFolds, false)
129 143  
130 144 SVM svm;
131 145 QHash<QString, int> labelMap;
... ... @@ -146,7 +160,8 @@ private:
146 160 QList<int> dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup);
147 161 lab = OpenCVUtils::toMat(dataLabels);
148 162 }
149   - trainSVM(svm, data, lab, kernel, type, C, gamma);
  163 +
  164 + trainSVM(svm, data, lab, kernel, type, C, gamma, folds, balanceFolds, termCriteria);
150 165 }
151 166  
152 167 void project(const Template &src, Template &dst) const
... ... @@ -207,7 +222,9 @@ class SVMDistance : public Distance
207 222 Q_PROPERTY(Kernel kernel READ get_kernel WRITE set_kernel RESET reset_kernel STORED false)
208 223 Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false)
209 224 Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
210   -
  225 + Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false)
  226 + Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false)
  227 + Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false)
211 228  
212 229 public:
213 230 enum Kernel { Linear = CvSVM::LINEAR,
... ... @@ -225,6 +242,9 @@ private:
225 242 BR_PROPERTY(Kernel, kernel, Linear)
226 243 BR_PROPERTY(Type, type, EPS_SVR)
227 244 BR_PROPERTY(QString, inputVariable, "Label")
  245 + BR_PROPERTY(int, termCriteria, 1000)
  246 + BR_PROPERTY(int, folds, 5)
  247 + BR_PROPERTY(bool, balanceFolds, false)
228 248  
229 249 SVM svm;
230 250  
... ... @@ -249,7 +269,7 @@ private:
249 269 deltaData = deltaData.rowRange(0, index);
250 270 deltaLab = deltaLab.rowRange(0, index);
251 271  
252   - trainSVM(svm, deltaData, deltaLab, kernel, type, -1, -1);
  272 + trainSVM(svm, deltaData, deltaLab, kernel, type, -1, -1, folds, balanceFolds, termCriteria);
253 273 }
254 274  
255 275 float compare(const Mat &a, const Mat &b) const
... ...
openbr/plugins/tree.cpp 0 → 100644
  1 +#include <opencv2/ml/ml.hpp>
  2 +
  3 +#include "openbr_internal.h"
  4 +#include "openbr/core/opencvutils.h"
  5 +#include <QString>
  6 +#include <QTemporaryFile>
  7 +
  8 +using namespace std;
  9 +using namespace cv;
  10 +
  11 +namespace br
  12 +{
  13 +
  14 +static void storeForest(const CvRTrees &forest, QDataStream &stream)
  15 +{
  16 + // Create local file
  17 + QTemporaryFile tempFile;
  18 + tempFile.open();
  19 + tempFile.close();
  20 +
  21 + // Save MLP to local file
  22 + forest.save(qPrintable(tempFile.fileName()));
  23 +
  24 + // Copy local file contents to stream
  25 + tempFile.open();
  26 + QByteArray data = tempFile.readAll();
  27 + tempFile.close();
  28 + stream << data;
  29 +}
  30 +
  31 +static void loadForest(CvRTrees &forest, QDataStream &stream)
  32 +{
  33 + // Copy local file contents from stream
  34 + QByteArray data;
  35 + stream >> data;
  36 +
  37 + // Create local file
  38 + QTemporaryFile tempFile(QDir::tempPath()+"/forest");
  39 + tempFile.open();
  40 + tempFile.write(data);
  41 + tempFile.close();
  42 +
  43 + // Load MLP from local file
  44 + forest.load(qPrintable(tempFile.fileName()));
  45 +}
  46 +
  47 +/*!
  48 + * \ingroup transforms
  49 + * \brief Wraps OpenCV's random trees framework
  50 + * \author Scott Klum \cite sklum
  51 + * \brief http://docs.opencv.org/modules/ml/doc/random_trees.html
  52 + */
  53 +class ForestTransform : public MetaTransform
  54 +{
  55 + Q_OBJECT
  56 + Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED true)
  57 + Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED true)
  58 + Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED true)
  59 + Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED true)
  60 + Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED true)
  61 + Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED true)
  62 + Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED true)
  63 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
  64 + Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)
  65 + BR_PROPERTY(bool, classification, true)
  66 + BR_PROPERTY(float, splitPercentage, .01)
  67 + BR_PROPERTY(int, maxDepth, std::numeric_limits<int>::max())
  68 + BR_PROPERTY(int, maxTrees, 10)
  69 + BR_PROPERTY(float, forestAccuracy, .1)
  70 + BR_PROPERTY(bool, returnConfidence, true)
  71 + BR_PROPERTY(bool, overwriteMat, true)
  72 + BR_PROPERTY(QString, inputVariable, "Label")
  73 + BR_PROPERTY(QString, outputVariable, "")
  74 +
  75 + CvRTrees forest;
  76 +
  77 + void train(const TemplateList &data)
  78 + {
  79 + Mat samples = OpenCVUtils::toMat(data.data());
  80 + Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable));
  81 +
  82 + Mat types = Mat(samples.cols + 1, 1, CV_8U);
  83 + types.setTo(Scalar(CV_VAR_NUMERICAL));
  84 +
  85 + if (classification) {
  86 + types.at<char>(samples.cols, 0) = CV_VAR_CATEGORICAL;
  87 + } else {
  88 + types.at<char>(samples.cols, 0) = CV_VAR_NUMERICAL;
  89 + }
  90 +
  91 + int minSamplesForSplit = data.size()*splitPercentage;
  92 + forest.train( samples, CV_ROW_SAMPLE, labels, Mat(), Mat(), types, Mat(),
  93 + CvRTParams(maxDepth,
  94 + minSamplesForSplit,
  95 + 0,
  96 + false,
  97 + 2,
  98 + 0, // priors
  99 + false,
  100 + 0,
  101 + maxTrees,
  102 + forestAccuracy,
  103 + CV_TERMCRIT_EPS));
  104 +
  105 + qDebug() << "Number of trees:" << forest.get_tree_count();
  106 + }
  107 +
  108 + void project(const Template &src, Template &dst) const
  109 + {
  110 + dst = src;
  111 +
  112 + float response;
  113 + if (classification && returnConfidence) {
  114 + // Fuzzy class label
  115 + response = forest.predict_prob(src.m().reshape(1,1));
  116 + } else {
  117 + response = forest.predict(src.m().reshape(1,1));
  118 + }
  119 +
  120 + if (overwriteMat) {
  121 + dst.m() = Mat(1, 1, CV_32F);
  122 + dst.m().at<float>(0, 0) = response;
  123 + } else {
  124 + dst.file.set(outputVariable, response);
  125 + }
  126 + }
  127 +
  128 + void load(QDataStream &stream)
  129 + {
  130 + loadForest(forest,stream);
  131 + }
  132 +
  133 + void store(QDataStream &stream) const
  134 + {
  135 + storeForest(forest,stream);
  136 + }
  137 +
  138 + void init()
  139 + {
  140 + if (outputVariable.isEmpty())
  141 + outputVariable = inputVariable;
  142 + }
  143 +};
  144 +
  145 +BR_REGISTER(Transform, ForestTransform)
  146 +
  147 +} // namespace br
  148 +
  149 +#include "tree.moc"
... ...