Commit c16487358f4ccf57d2e47174aa741df20c5639cd
1 parent
e1439a2c
Parameterized some functionality
Showing
1 changed file
with
30 additions
and
5 deletions
openbr/plugins/tree.cpp
| ... | ... | @@ -58,20 +58,28 @@ class ForestTransform : public MetaTransform |
| 58 | 58 | Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED true) |
| 59 | 59 | Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED true) |
| 60 | 60 | Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED true) |
| 61 | + Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED true) | |
| 62 | + Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED true) | |
| 63 | + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | |
| 64 | + Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) | |
| 61 | 65 | BR_PROPERTY(bool, classification, true) |
| 62 | 66 | BR_PROPERTY(float, splitPercentage, .01) |
| 63 | 67 | BR_PROPERTY(int, maxDepth, std::numeric_limits<int>::max()) |
| 64 | 68 | BR_PROPERTY(int, maxTrees, 10) |
| 65 | 69 | BR_PROPERTY(float, forestAccuracy, .1) |
| 70 | + BR_PROPERTY(bool, returnConfidence, true) | |
| 71 | + BR_PROPERTY(bool, overwriteMat, true) | |
| 72 | + BR_PROPERTY(QString, inputVariable, "Label") | |
| 73 | + BR_PROPERTY(QString, outputVariable, "") | |
| 66 | 74 | |
| 67 | 75 | CvRTrees forest; |
| 68 | 76 | |
| 69 | 77 | void train(const TemplateList &data) |
| 70 | 78 | { |
| 71 | 79 | Mat samples = OpenCVUtils::toMat(data.data()); |
| 72 | - Mat labels = OpenCVUtils::toMat(File::get<float>(data, "Label")); | |
| 80 | + Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable)); | |
| 73 | 81 | |
| 74 | - Mat types = Mat(samples.cols + 1, 1, CV_8U ); | |
| 82 | + Mat types = Mat(samples.cols + 1, 1, CV_8U); | |
| 75 | 83 | types.setTo(Scalar(CV_VAR_NUMERICAL)); |
| 76 | 84 | |
| 77 | 85 | if (classification) { |
| ... | ... | @@ -101,9 +109,20 @@ class ForestTransform : public MetaTransform |
| 101 | 109 | { |
| 102 | 110 | dst = src; |
| 103 | 111 | |
| 104 | - float response = forest.predict_prob(src.m().reshape(1,1)); | |
| 105 | - dst.m() = Mat(1, 1, CV_32F); | |
| 106 | - dst.m().at<float>(0, 0) = response; | |
| 112 | + float response; | |
| 113 | + if (classification && returnConfidence) { | |
| 114 | + // Fuzzy class label | |
| 115 | + response = forest.predict_prob(src.m().reshape(1,1)); | |
| 116 | + } else { | |
| 117 | + response = forest.predict(src.m().reshape(1,1)); | |
| 118 | + } | |
| 119 | + | |
| 120 | + if (overwriteMat) { | |
| 121 | + dst.m() = Mat(1, 1, CV_32F); | |
| 122 | + dst.m().at<float>(0, 0) = response; | |
| 123 | + } else { | |
| 124 | + dst.file.set(outputVariable, response); | |
| 125 | + } | |
| 107 | 126 | } |
| 108 | 127 | |
| 109 | 128 | void load(QDataStream &stream) |
| ... | ... | @@ -115,6 +134,12 @@ class ForestTransform : public MetaTransform |
| 115 | 134 | { |
| 116 | 135 | storeForest(forest,stream); |
| 117 | 136 | } |
| 137 | + | |
| 138 | + void init() | |
| 139 | + { | |
| 140 | + if (outputVariable.isEmpty()) | |
| 141 | + outputVariable = inputVariable; | |
| 142 | + } | |
| 118 | 143 | }; |
| 119 | 144 | |
| 120 | 145 | BR_REGISTER(Transform, ForestTransform) | ... | ... |