diff --git a/openbr/plugins/tree.cpp b/openbr/plugins/tree.cpp index b75fa0f..9cba347 100644 --- a/openbr/plugins/tree.cpp +++ b/openbr/plugins/tree.cpp @@ -11,7 +11,7 @@ using namespace cv; namespace br { -static void storeForest(const CvRTrees &forest, QDataStream &stream) +static void storeModel(const CvStatModel &model, QDataStream &stream) { // Create local file QTemporaryFile tempFile; @@ -19,7 +19,7 @@ static void storeForest(const CvRTrees &forest, QDataStream &stream) tempFile.close(); // Save MLP to local file - forest.save(qPrintable(tempFile.fileName())); + model.save(qPrintable(tempFile.fileName())); // Copy local file contents to stream tempFile.open(); @@ -28,20 +28,20 @@ static void storeForest(const CvRTrees &forest, QDataStream &stream) stream << data; } -static void loadForest(CvRTrees &forest, QDataStream &stream) +static void loadModel(CvStatModel &model, QDataStream &stream) { // Copy local file contents from stream QByteArray data; stream >> data; // Create local file - QTemporaryFile tempFile(QDir::tempPath()+"/forest"); + QTemporaryFile tempFile(QDir::tempPath()+"/model"); tempFile.open(); tempFile.write(data); tempFile.close(); // Load MLP from local file - forest.load(qPrintable(tempFile.fileName())); + model.load(qPrintable(tempFile.fileName())); } /*! @@ -50,16 +50,16 @@ static void loadForest(CvRTrees &forest, QDataStream &stream) * \author Scott Klum \cite sklum * \brief http://docs.opencv.org/modules/ml/doc/random_trees.html */ -class ForestTransform : public MetaTransform +class ForestTransform : public Transform { Q_OBJECT - Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED true) - Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED true) - Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED true) - Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED true) - Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED true) - Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED true) - Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED true) + Q_PROPERTY(bool classification READ get_classification WRITE set_classification RESET reset_classification STORED false) + Q_PROPERTY(float splitPercentage READ get_splitPercentage WRITE set_splitPercentage RESET reset_splitPercentage STORED false) + Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED false) + Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED false) + Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED false) + Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED false) + Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false) Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) BR_PROPERTY(bool, classification, true) @@ -100,7 +100,7 @@ class ForestTransform : public MetaTransform 0, maxTrees, forestAccuracy, - CV_TERMCRIT_EPS)); + CV_TERMCRIT_ITER | CV_TERMCRIT_EPS)); qDebug() << "Number of trees:" << forest.get_tree_count(); } @@ -127,12 +127,12 @@ class ForestTransform : public MetaTransform void load(QDataStream &stream) { - loadForest(forest,stream); + loadModel(forest,stream); } void store(QDataStream &stream) const { - storeForest(forest,stream); + storeModel(forest,stream); } void init() @@ -144,6 +144,113 @@ class ForestTransform : public MetaTransform BR_REGISTER(Transform, ForestTransform) +/*! + * \ingroup transforms + * \brief Wraps OpenCV's Ada Boost framework + * \author Scott Klum \cite sklum + * \brief http://docs.opencv.org/modules/ml/doc/boosting.html + */ +class AdaBoostTransform : public Transform +{ + Q_OBJECT + Q_ENUMS(Type) + Q_ENUMS(SplitCriteria) + + Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false) + Q_PROPERTY(SplitCriteria splitCriteria READ get_splitCriteria WRITE set_splitCriteria RESET reset_splitCriteria STORED false) + Q_PROPERTY(int weakCount READ get_weakCount WRITE set_weakCount RESET reset_weakCount STORED false) + Q_PROPERTY(float trimRate READ get_trimRate WRITE set_trimRate RESET reset_trimRate STORED false) + Q_PROPERTY(int folds READ get_folds WRITE set_folds RESET reset_folds STORED false) + Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED false) + Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED false) + Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED false) + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) + Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) + +public: + enum Type { Discrete = CvBoost::DISCRETE, + Real = CvBoost::REAL, + Logit = CvBoost::LOGIT, + Gentle = CvBoost::GENTLE}; + + enum SplitCriteria { Default = CvBoost::DEFAULT, + Gini = CvBoost::GINI, + Misclass = CvBoost::MISCLASS, + Sqerr = CvBoost::SQERR}; + +private: + BR_PROPERTY(Type, type, Real) + BR_PROPERTY(SplitCriteria, splitCriteria, Default) + BR_PROPERTY(int, weakCount, 100) + BR_PROPERTY(float, trimRate, .95) + BR_PROPERTY(int, folds, 0) + BR_PROPERTY(int, maxDepth, 1) + BR_PROPERTY(bool, returnConfidence, true) + BR_PROPERTY(bool, overwriteMat, true) + BR_PROPERTY(QString, inputVariable, "Label") + BR_PROPERTY(QString, outputVariable, "") + + CvBoost boost; + + void train(const TemplateList &data) + { + Mat samples = OpenCVUtils::toMat(data.data()); + Mat labels = OpenCVUtils::toMat(File::get(data, inputVariable)); + + Mat types = Mat(samples.cols + 1, 1, CV_8U); + types.setTo(Scalar(CV_VAR_NUMERICAL)); + types.at(samples.cols, 0) = CV_VAR_CATEGORICAL; + + CvBoostParams params; + params.boost_type = type; + params.split_criteria = splitCriteria; + params.weak_count = weakCount; + params.weight_trim_rate = trimRate; + params.cv_folds = folds; + params.max_depth = maxDepth; + + boost.train( samples, CV_ROW_SAMPLE, labels, Mat(), Mat(), types, Mat(), + params); + } + + void project(const Template &src, Template &dst) const + { + dst = src; + float response; + if (returnConfidence) { + response = boost.predict(src.m().reshape(1,1),Mat(),Range::all(),false,true)/weakCount; + } else { + response = boost.predict(src.m().reshape(1,1)); + } + + if (overwriteMat) { + dst.m() = Mat(1, 1, CV_32F); + dst.m().at(0, 0) = response; + } else { + dst.file.set(outputVariable, response); + } + } + + void load(QDataStream &stream) + { + loadModel(boost,stream); + } + + void store(QDataStream &stream) const + { + storeModel(boost,stream); + } + + + void init() + { + if (outputVariable.isEmpty()) + outputVariable = inputVariable; + } +}; + +BR_REGISTER(Transform, AdaBoostTransform) + } // namespace br #include "tree.moc"