Commit c16487358f4ccf57d2e47174aa741df20c5639cd

Authored by Scott Klum
1 parent e1439a2c

Parameterized some functionality

Showing 1 changed file with 30 additions and 5 deletions
openbr/plugins/tree.cpp
@@ -58,20 +58,28 @@ class ForestTransform : public MetaTransform @@ -58,20 +58,28 @@ class ForestTransform : public MetaTransform
58 Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED true) 58 Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED true)
59 Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED true) 59 Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED true)
60 Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED true) 60 Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED true)
  61 + Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED true)
  62 + Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED true)
  63 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
  64 + Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)
61 BR_PROPERTY(bool, classification, true) 65 BR_PROPERTY(bool, classification, true)
62 BR_PROPERTY(float, splitPercentage, .01) 66 BR_PROPERTY(float, splitPercentage, .01)
63 BR_PROPERTY(int, maxDepth, std::numeric_limits<int>::max()) 67 BR_PROPERTY(int, maxDepth, std::numeric_limits<int>::max())
64 BR_PROPERTY(int, maxTrees, 10) 68 BR_PROPERTY(int, maxTrees, 10)
65 BR_PROPERTY(float, forestAccuracy, .1) 69 BR_PROPERTY(float, forestAccuracy, .1)
  70 + BR_PROPERTY(bool, returnConfidence, true)
  71 + BR_PROPERTY(bool, overwriteMat, true)
  72 + BR_PROPERTY(QString, inputVariable, "Label")
  73 + BR_PROPERTY(QString, outputVariable, "")
66 74
67 CvRTrees forest; 75 CvRTrees forest;
68 76
69 void train(const TemplateList &data) 77 void train(const TemplateList &data)
70 { 78 {
71 Mat samples = OpenCVUtils::toMat(data.data()); 79 Mat samples = OpenCVUtils::toMat(data.data());
72 - Mat labels = OpenCVUtils::toMat(File::get<float>(data, "Label")); 80 + Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable));
73 81
74 - Mat types = Mat(samples.cols + 1, 1, CV_8U ); 82 + Mat types = Mat(samples.cols + 1, 1, CV_8U);
75 types.setTo(Scalar(CV_VAR_NUMERICAL)); 83 types.setTo(Scalar(CV_VAR_NUMERICAL));
76 84
77 if (classification) { 85 if (classification) {
@@ -101,9 +109,20 @@ class ForestTransform : public MetaTransform @@ -101,9 +109,20 @@ class ForestTransform : public MetaTransform
101 { 109 {
102 dst = src; 110 dst = src;
103 111
104 - float response = forest.predict_prob(src.m().reshape(1,1));  
105 - dst.m() = Mat(1, 1, CV_32F);  
106 - dst.m().at<float>(0, 0) = response; 112 + float response;
  113 + if (classification && returnConfidence) {
  114 + // Fuzzy class label
  115 + response = forest.predict_prob(src.m().reshape(1,1));
  116 + } else {
  117 + response = forest.predict(src.m().reshape(1,1));
  118 + }
  119 +
  120 + if (overwriteMat) {
  121 + dst.m() = Mat(1, 1, CV_32F);
  122 + dst.m().at<float>(0, 0) = response;
  123 + } else {
  124 + dst.file.set(outputVariable, response);
  125 + }
107 } 126 }
108 127
109 void load(QDataStream &stream) 128 void load(QDataStream &stream)
@@ -115,6 +134,12 @@ class ForestTransform : public MetaTransform @@ -115,6 +134,12 @@ class ForestTransform : public MetaTransform
115 { 134 {
116 storeForest(forest,stream); 135 storeForest(forest,stream);
117 } 136 }
  137 +
  138 + void init()
  139 + {
  140 + if (outputVariable.isEmpty())
  141 + outputVariable = inputVariable;
  142 + }
118 }; 143 };
119 144
120 BR_REGISTER(Transform, ForestTransform) 145 BR_REGISTER(Transform, ForestTransform)