Commit c16487358f4ccf57d2e47174aa741df20c5639cd

Authored by Scott Klum
1 parent e1439a2c

Parameterized some functionality

Showing 1 changed file with 30 additions and 5 deletions
openbr/plugins/tree.cpp
... ... @@ -58,20 +58,28 @@ class ForestTransform : public MetaTransform
58 58 Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED true)
59 59 Q_PROPERTY(int maxTrees READ get_maxTrees WRITE set_maxTrees RESET reset_maxTrees STORED true)
60 60 Q_PROPERTY(float forestAccuracy READ get_forestAccuracy WRITE set_forestAccuracy RESET reset_forestAccuracy STORED true)
  61 + Q_PROPERTY(bool returnConfidence READ get_returnConfidence WRITE set_returnConfidence RESET reset_returnConfidence STORED true)
  62 + Q_PROPERTY(bool overwriteMat READ get_overwriteMat WRITE set_overwriteMat RESET reset_overwriteMat STORED true)
  63 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
  64 + Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)
61 65 BR_PROPERTY(bool, classification, true)
62 66 BR_PROPERTY(float, splitPercentage, .01)
63 67 BR_PROPERTY(int, maxDepth, std::numeric_limits<int>::max())
64 68 BR_PROPERTY(int, maxTrees, 10)
65 69 BR_PROPERTY(float, forestAccuracy, .1)
  70 + BR_PROPERTY(bool, returnConfidence, true)
  71 + BR_PROPERTY(bool, overwriteMat, true)
  72 + BR_PROPERTY(QString, inputVariable, "Label")
  73 + BR_PROPERTY(QString, outputVariable, "")
66 74  
67 75 CvRTrees forest;
68 76  
69 77 void train(const TemplateList &data)
70 78 {
71 79 Mat samples = OpenCVUtils::toMat(data.data());
72   - Mat labels = OpenCVUtils::toMat(File::get<float>(data, "Label"));
  80 + Mat labels = OpenCVUtils::toMat(File::get<float>(data, inputVariable));
73 81  
74   - Mat types = Mat(samples.cols + 1, 1, CV_8U );
  82 + Mat types = Mat(samples.cols + 1, 1, CV_8U);
75 83 types.setTo(Scalar(CV_VAR_NUMERICAL));
76 84  
77 85 if (classification) {
... ... @@ -101,9 +109,20 @@ class ForestTransform : public MetaTransform
101 109 {
102 110 dst = src;
103 111  
104   - float response = forest.predict_prob(src.m().reshape(1,1));
105   - dst.m() = Mat(1, 1, CV_32F);
106   - dst.m().at<float>(0, 0) = response;
  112 + float response;
  113 + if (classification && returnConfidence) {
  114 + // Fuzzy class label
  115 + response = forest.predict_prob(src.m().reshape(1,1));
  116 + } else {
  117 + response = forest.predict(src.m().reshape(1,1));
  118 + }
  119 +
  120 + if (overwriteMat) {
  121 + dst.m() = Mat(1, 1, CV_32F);
  122 + dst.m().at<float>(0, 0) = response;
  123 + } else {
  124 + dst.file.set(outputVariable, response);
  125 + }
107 126 }
108 127  
109 128 void load(QDataStream &stream)
... ... @@ -115,6 +134,12 @@ class ForestTransform : public MetaTransform
115 134 {
116 135 storeForest(forest,stream);
117 136 }
  137 +
  138 + void init()
  139 + {
  140 + if (outputVariable.isEmpty())
  141 + outputVariable = inputVariable;
  142 + }
118 143 };
119 144  
120 145 BR_REGISTER(Transform, ForestTransform)
... ...