diff --git a/openbr/plugins/svm.cpp b/openbr/plugins/svm.cpp
index ad4594f..b80901e 100644
--- a/openbr/plugins/svm.cpp
+++ b/openbr/plugins/svm.cpp
@@ -59,7 +59,7 @@ static void loadSVM(SVM &svm, QDataStream &stream)
     svm.load(qPrintable(tempFile.fileName()));
 }
 
-static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma)
+static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma, int folds, bool balanceFolds, int termCriteria)
 {
     if (data.type() != CV_32FC1)
         qFatal("Expected single channel floating point training data.");
@@ -69,9 +69,18 @@ static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C,
     params.svm_type = type;
     params.p = 0.1;
     params.nu = 0.5;
+    params.term_crit = cvTermCriteria(CV_TERMCRIT_ITER+CV_TERMCRIT_EPS, termCriteria, FLT_EPSILON);
+
     if ((C == -1) || ((gamma == -1) && (kernel == CvSVM::RBF))) {
         try {
-            svm.train_auto(data, lab, Mat(), Mat(), params, 5);
+            svm.train_auto(data, lab, Mat(), Mat(), params, folds,
+                           CvSVM::get_default_grid(CvSVM::C),
+                           CvSVM::get_default_grid(CvSVM::GAMMA),
+                           CvSVM::get_default_grid(CvSVM::P),
+                           CvSVM::get_default_grid(CvSVM::NU),
+                           CvSVM::get_default_grid(CvSVM::COEF),
+                           CvSVM::get_default_grid(CvSVM::DEGREE),
+                           balanceFolds);
         } catch (...) {
             qWarning("Some classes do not contain sufficient examples or are not discriminative enough for accurate SVM classification.");
             svm.train(data, lab, Mat(), Mat(), params);
@@ -104,6 +113,9 @@ class SVMTransform : public Transform
     Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
     Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)
     Q_PROPERTY(bool returnDFVal READ get_returnDFVal WRITE set_returnDFVal RESET reset_returnDFVal STORED false)
+    Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false)
+    Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false)
+    Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false)
 
 public:
     enum Kernel { Linear = CvSVM::LINEAR,
@@ -125,7 +137,9 @@ private:
     BR_PROPERTY(QString, inputVariable, "Label")
     BR_PROPERTY(QString, outputVariable, "")
     BR_PROPERTY(bool, returnDFVal, false)
-
+    BR_PROPERTY(int, termCriteria, 1000)
+    BR_PROPERTY(int, folds, 5)
+    BR_PROPERTY(bool, balanceFolds, false)
 
     SVM svm;
     QHash<QString, int> labelMap;
@@ -146,7 +160,8 @@ private:
             QList<int> dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup);
             lab = OpenCVUtils::toMat(dataLabels);
         }
-        trainSVM(svm, data, lab, kernel, type, C, gamma);
+
+        trainSVM(svm, data, lab, kernel, type, C, gamma, folds, balanceFolds, termCriteria);
     }
 
     void project(const Template &src, Template &dst) const
@@ -207,7 +222,9 @@ class SVMDistance : public Distance
     Q_PROPERTY(Kernel kernel READ get_kernel WRITE set_kernel RESET reset_kernel STORED false)
     Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false)
     Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
-
+    Q_PROPERTY(int termCriteria READ get_termCriteria WRITE set_termCriteria RESET reset_termCriteria STORED false)
+    Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false)
+    Q_PROPERTY(bool balanceFolds READ get_balanceFolds WRITE set_balanceFolds RESET reset_balanceFolds STORED false)
 
 public:
     enum Kernel { Linear = CvSVM::LINEAR,
@@ -225,6 +242,9 @@ private:
     BR_PROPERTY(Kernel, kernel, Linear)
     BR_PROPERTY(Type, type, EPS_SVR)
     BR_PROPERTY(QString, inputVariable, "Label")
+    BR_PROPERTY(int, termCriteria, 1000)
+    BR_PROPERTY(int, folds, 5)
+    BR_PROPERTY(bool, balanceFolds, false)
 
     SVM svm;
 
@@ -249,7 +269,7 @@ private:
         deltaData = deltaData.rowRange(0, index);
         deltaLab = deltaLab.rowRange(0, index);
 
-        trainSVM(svm, deltaData, deltaLab, kernel, type, -1, -1);
+        trainSVM(svm, deltaData, deltaLab, kernel, type, -1, -1, folds, balanceFolds, termCriteria);
     }
 
     float compare(const Mat &a, const Mat &b) const
diff --git a/openbr/plugins/tree.cpp b/openbr/plugins/tree.cpp
new file mode 100644
index 0000000..b75fa0f
--- /dev/null
+++ b/openbr/plugins/tree.cpp
@@ -0,0 +1,170 @@
+// NOTE(review): the three <...> include targets were missing from the reviewed
+// text and are reconstructed from usage (CvRTrees, QDir, QTemporaryFile) - confirm.
+#include <limits>
+#include <opencv2/ml/ml.hpp>
+
+#include "openbr_internal.h"
+#include "openbr/core/opencvutils.h"
+#include <QDir>
+#include <QTemporaryFile>
+
+using namespace std;
+using namespace cv;
+
+namespace br
+{
+
+/*!
+ * \brief Serializes \p forest into \p stream.
+ *
+ * The forest is saved to a temporary file, whose bytes are then written to
+ * \p stream as a single QByteArray.
+ */
+static void storeForest(const CvRTrees &forest, QDataStream &stream)
+{
+    // Create local file
+    QTemporaryFile tempFile;
+    tempFile.open();
+    tempFile.close();
+
+    // Save forest to local file
+    forest.save(qPrintable(tempFile.fileName()));
+
+    // Copy local file contents to stream
+    tempFile.open();
+    QByteArray data = tempFile.readAll();
+    tempFile.close();
+    stream << data;
+}
+
+/*!
+ * \brief Restores \p forest from a QByteArray read from \p stream.
+ *
+ * The bytes are written to a temporary file, which is then handed to
+ * CvRTrees::load().
+ */
+static void loadForest(CvRTrees &forest, QDataStream &stream)
+{
+    // Copy local file contents from stream
+    QByteArray data;
+    stream >> data;
+
+    // Create local file
+    QTemporaryFile tempFile(QDir::tempPath()+"/forest");
+    tempFile.open();
+    tempFile.write(data);
+    tempFile.close();
+
+    // Load forest from local file
+    forest.load(qPrintable(tempFile.fileName()));
+}
+
+/*!
+ * \ingroup transforms
+ * \brief Wraps OpenCV's random trees framework
+ * \author Scott Klum \cite sklum
+ * \see http://docs.opencv.org/modules/ml/doc/random_trees.html
+ */
+class ForestTransform : public MetaTransform
+{
+    Q_OBJECT
+    Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED true)
+    Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED true)
+    Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED true)
+    Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED true)
+    Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED true)
+    Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED true)
+    Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED true)
+    Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
+    Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)
+    BR_PROPERTY(bool, classification, true)
+    BR_PROPERTY(float, splitPercentage, .01)
+    BR_PROPERTY(int, maxDepth, std::numeric_limits<int>::max())
+    BR_PROPERTY(int, maxTrees, 10)
+    BR_PROPERTY(float, forestAccuracy, .1)
+    BR_PROPERTY(bool, returnConfidence, true)
+    BR_PROPERTY(bool, overwriteMat, true)
+    BR_PROPERTY(QString, inputVariable, "Label")
+    BR_PROPERTY(QString, outputVariable, "")
+
+    CvRTrees forest;
+
+    // Trains the forest on the template matrices; the responses are read from
+    // the inputVariable entry of the templates via File::get.
+    void train(const TemplateList &data)
+    {
+        Mat samples = OpenCVUtils::toMat(data.data());
+        Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable)); // NOTE(review): <float> restored by inference - confirm
+
+        // One type entry per sample column plus a final entry (presumably the response type, per OpenCV's var_type layout).
+        Mat types = Mat(samples.cols + 1, 1, CV_8U);
+        types.setTo(Scalar(CV_VAR_NUMERICAL));
+
+        if (classification) {
+            types.at<uchar>(samples.cols, 0) = CV_VAR_CATEGORICAL;
+        } else {
+            types.at<uchar>(samples.cols, 0) = CV_VAR_NUMERICAL;
+        }
+
+        int minSamplesForSplit = data.size()*splitPercentage;
+        forest.train( samples, CV_ROW_SAMPLE, labels, Mat(), Mat(), types, Mat(),
+                    CvRTParams(maxDepth,
+                               minSamplesForSplit,
+                               0,
+                               false,
+                               2,
+                               0, // priors
+                               false,
+                               0,
+                               maxTrees,
+                               forestAccuracy,
+                               CV_TERMCRIT_EPS));
+
+        qDebug() << "Number of trees:" << forest.get_tree_count();
+    }
+
+    // Predicts a response for src and writes it either into a 1x1 CV_32F
+    // matrix or into the outputVariable metadata, depending on overwriteMat.
+    void project(const Template &src, Template &dst) const
+    {
+        dst = src;
+
+        float response;
+        if (classification && returnConfidence) {
+            // Fuzzy class label
+            response = forest.predict_prob(src.m().reshape(1,1));
+        } else {
+            response = forest.predict(src.m().reshape(1,1));
+        }
+
+        if (overwriteMat) {
+            dst.m() = Mat(1, 1, CV_32F);
+            dst.m().at<float>(0, 0) = response;
+        } else {
+            dst.file.set(outputVariable, response);
+        }
+    }
+
+    void load(QDataStream &stream)
+    {
+        loadForest(forest,stream);
+    }
+
+    void store(QDataStream &stream) const
+    {
+        storeForest(forest,stream);
+    }
+
+    // Defaults outputVariable to inputVariable when none was given.
+    void init()
+    {
+        if (outputVariable.isEmpty())
+            outputVariable = inputVariable;
+    }
+};
+
+BR_REGISTER(Transform, ForestTransform)
+
+} // namespace br
+
+#include "tree.moc"