Commit 22de0ca461d6cb90cfda0781d5d797e87504d486

Authored by Charles Otto
1 parent aebd450c

Preliminar changes towards refering to class labels etc. more specifically

Change default label name from Subjet to Label (since label is a more general
term).
Use different default variable names for classification (label), regression
(regressor/regressand), and clustering (ClusterID)
Update some (far from all) transforms to accept arguments specifying their
input/output variables.
Update eval classification to optionally take target variable names as
arguments
app/br/br.cpp
... ... @@ -130,7 +130,7 @@ public:
130 130 br_convert(parv[0], parv[1], parv[2]);
131 131 } else if (!strcmp(fun, "evalClassification")) {
132 132 check(parc == 2, "Incorrect parameter count for 'evalClassification'.");
133   - br_eval_classification(parv[0], parv[1]);
  133 + br_eval_classification(parv[0], parv[1], parc >= 3 ? parv[2] : NULL, parc >= 4 ? parv[3] : NULL);
134 134 } else if (!strcmp(fun, "evalRegression")) {
135 135 check(parc == 2, "Incorrect parameter count for 'evalRegression'.");
136 136 br_eval_regression(parv[0], parv[1]);
... ...
app/examples/age_estimation.cpp
... ... @@ -29,7 +29,8 @@
29 29  
30 30 static void printTemplate(const br::Template &t)
31 31 {
32   - printf("%s age: %d\n", qPrintable(t.file.fileName()), int(t.file.get<float>("Subject")));
  32 + // may use age directly -cao
  33 + printf("%s age: %d\n", qPrintable(t.file.fileName()), int(t.file.get<float>("Regressand")));
33 34 }
34 35  
35 36 int main(int argc, char *argv[])
... ...
app/examples/gender_estimation.cpp
... ... @@ -29,7 +29,8 @@
29 29  
30 30 static void printTemplate(const br::Template &t)
31 31 {
32   - printf("%s gender: %s\n", qPrintable(t.file.fileName()), qPrintable(t.file.get<QString>("Subject")));
  32 + // may use gender directly -cao
  33 + printf("%s gender: %s\n", qPrintable(t.file.fileName()), qPrintable(t.file.get<QString>("Label")));
33 34 }
34 35  
35 36 int main(int argc, char *argv[])
... ...
openbr/core/bee.cpp
... ... @@ -93,10 +93,10 @@ void BEE::writeSigset(const QString &amp;sigset, const br::FileList &amp;files, bool ign
93 93 QStringList metadata;
94 94 if (!ignoreMetadata)
95 95 foreach (const QString &key, file.localKeys()) {
96   - if ((key == "Index") || (key == "Subject")) continue;
  96 + if ((key == "Index") || (key == "Label")) continue;
97 97 metadata.append(key+"=\""+QtUtils::toString(file.value(key))+"\"");
98 98 }
99   - lines.append("\t<biometric-signature name=\"" + file.get<QString>("Subject") +"\">");
  99 + lines.append("\t<biometric-signature name=\"" + file.get<QString>("Label") +"\">");
100 100 lines.append("\t\t<presentation file-name=\"" + file.name + "\" " + metadata.join(" ") + "/>");
101 101 lines.append("\t</biometric-signature>");
102 102 }
... ... @@ -260,10 +260,10 @@ void BEE::makeMask(const QString &amp;targetInput, const QString &amp;queryInput, const
260 260  
261 261 cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries, int partition)
262 262 {
263   - // Would like to use indexProperty for this, but didn't make a version of that for Filelist yet
264   - // -cao
265   - QList<QString> targetLabels = targets.get<QString>("Subject", "-1");
266   - QList<QString> queryLabels = queries.get<QString>("Subject", "-1");
  263 + // Direct use of "Label" isn't general, also would prefer to use indexProperty, rather than
  264 + // doing string comparisons (but that isn't implemented yet for FileList) -cao
  265 + QList<QString> targetLabels = targets.get<QString>("Label", "-1");
  266 + QList<QString> queryLabels = queries.get<QString>("Label", "-1");
267 267 QList<int> targetPartitions = targets.crossValidationPartitions();
268 268 QList<int> queryPartitions = queries.crossValidationPartitions();
269 269  
... ...
openbr/core/classify.cpp
... ... @@ -31,7 +31,7 @@ struct Counter
31 31 }
32 32 };
33 33  
34   -void br::EvalClassification(const QString &predictedInput, const QString &truthInput)
  34 +void br::EvalClassification(const QString &predictedInput, const QString &truthInput, const QString & predictedProperty, const QString & truthProperty)
35 35 {
36 36 qDebug("Evaluating classification of %s against %s", qPrintable(predictedInput), qPrintable(truthInput));
37 37  
... ... @@ -44,9 +44,8 @@ void br::EvalClassification(const QString &amp;predictedInput, const QString &amp;truthI
44 44 if (predicted[i].file.name != truth[i].file.name)
45 45 qFatal("Input order mismatch.");
46 46  
47   - // Typically these lists will be of length one, but this generalization allows measuring multi-class labeling accuracy.
48   - QString predictedSubject = predicted[i].file.get<QString>("Subject");
49   - QString trueSubject = truth[i].file.get<QString>("Subject");
  47 + QString predictedSubject = predicted[i].file.get<QString>(predictedProperty);
  48 + QString trueSubject = truth[i].file.get<QString>(truthProperty);
50 49  
51 50 QStringList predictedSubjects(predictedSubject);
52 51 QStringList trueSubjects(trueSubject);
... ... @@ -99,13 +98,19 @@ void br::EvalRegression(const QString &amp;predictedInput, const QString &amp;truthInput
99 98 if (predicted.size() != truth.size()) qFatal("Input size mismatch.");
100 99  
101 100 float rmsError = 0;
  101 + float maeError = 0;
  102 + // Direct use of Regressor/Regressand is not general -cao
102 103 QStringList truthValues, predictedValues;
103 104 for (int i=0; i<predicted.size(); i++) {
104 105 if (predicted[i].file.name != truth[i].file.name)
105 106 qFatal("Input order mismatch.");
106   - rmsError += pow(predicted[i].file.get<float>("Subject")-truth[i].file.get<float>("Subject"), 2.f);
107   - truthValues.append(QString::number(truth[i].file.get<float>("Subject")));
108   - predictedValues.append(QString::number(predicted[i].file.get<float>("Subject")));
  107 +
  108 + float difference = predicted[i].file.get<float>("Regressand") - truth[i].file.get<float>("Regressor");
  109 +
  110 + rmsError += pow(difference, 2.f);
  111 + maeError += fabsf(difference);
  112 + truthValues.append(QString::number(truth[i].file.get<float>("Regressor")));
  113 + predictedValues.append(QString::number(predicted[i].file.get<float>("Regressand")));
109 114 }
110 115  
111 116 QStringList rSource;
... ... @@ -125,4 +130,6 @@ void br::EvalRegression(const QString &amp;predictedInput, const QString &amp;truthInput
125 130 if (success) QtUtils::showFile("EvalRegression.pdf");
126 131  
127 132 qDebug("RMS Error = %f", sqrt(rmsError/predicted.size()));
  133 + qDebug("MAE = %f", maeError/predicted.size());
  134 +
128 135 }
... ...
openbr/core/classify.h
... ... @@ -22,7 +22,7 @@
22 22  
23 23 namespace br
24 24 {
25   - void EvalClassification(const QString &predictedInput, const QString &truthInput);
  25 + void EvalClassification(const QString &predictedInput, const QString &truthInput, const QString & predictedProperty="Label", const QString & truthProperty="Label");
26 26 void EvalRegression(const QString &predictedInput, const QString &truthInput);
27 27 }
28 28  
... ...
openbr/core/cluster.cpp
... ... @@ -279,8 +279,8 @@ void br::EvalClustering(const QString &amp;csv, const QString &amp;input)
279 279 qDebug("Evaluating %s against %s", qPrintable(csv), qPrintable(input));
280 280  
281 281 // We assume clustering algorithms store assigned cluster labels as integers (since the clusters are
282   - // not named).
283   - QList<int> labels = TemplateList::fromGallery(input).files().get<int>("Subject");
  282 + // not named). Direct use of ClusterID is not general -cao
  283 + QList<int> labels = TemplateList::fromGallery(input).files().get<int>("ClusterID");
284 284  
285 285 QHash<int, int> labelToIndex;
286 286 int nClusters = 0;
... ...
openbr/frvt2012.cpp
... ... @@ -132,7 +132,8 @@ int32_t SdkEstimator::estimate_age(const ONEFACE &amp;input_face, int32_t &amp;age)
132 132 TemplateList templates;
133 133 templates.append(templateFromONEFACE(input_face));
134 134 templates >> *frvt2012_age_transform.data();
135   - age = templates.first().file.get<float>("Subject");
  135 + // should maybe use "Age" directly -cao
  136 + age = templates.first().file.get<float>("Regressand");
136 137 return templates.first().file.failed() ? 4 : 0;
137 138 }
138 139  
... ... @@ -141,6 +142,7 @@ int32_t SdkEstimator::estimate_gender(const ONEFACE &amp;input_face, int8_t &amp;gender,
141 142 TemplateList templates;
142 143 templates.append(templateFromONEFACE(input_face));
143 144 templates >> *frvt2012_gender_transform.data();
144   - mf = gender = templates.first().file.get<QString>("Subject") == "Male" ? 0 : 1;
  145 + // Should maybe use "Gender" directly -cao
  146 + mf = gender = templates.first().file.get<QString>("Label") == "Male" ? 0 : 1;
145 147 return templates.first().file.failed() ? 4 : 0;
146 148 }
... ...
openbr/gui/classifier.cpp
... ... @@ -43,14 +43,16 @@ void Classifier::_classify(File file)
43 43 continue;
44 44  
45 45 if (algorithm == "GenderClassification") {
  46 + // Should maybe use gender directly -cao
46 47 key = "Gender";
47   - value = (f.get<QString>("Subject"));
  48 + value = f.get<QString>("Label");
48 49 } else if (algorithm == "AgeRegression") {
49 50 key = "Age";
50   - value = QString::number(int(f.get<float>("Subject")+0.5)) + " Years";
  51 + // similarly, age -cao
  52 + value = QString::number(int(f.get<float>("Regressand")+0.5)) + " Years";
51 53 } else {
52 54 key = algorithm;
53   - value = f.get<QString>("Subject");
  55 + value = f.get<QString>("Label");
54 56 }
55 57 break;
56 58 }
... ...
openbr/openbr.cpp
... ... @@ -77,9 +77,14 @@ float br_eval(const char *simmat, const char *mask, const char *csv)
77 77 return Evaluate(simmat, mask, csv);
78 78 }
79 79  
80   -void br_eval_classification(const char *predicted_input, const char *truth_input)
81   -{
82   - EvalClassification(predicted_input, truth_input);
  80 +void br_eval_classification(const char *predicted_input, const char *truth_input, const char *predicted_property, const char * truth_property)
  81 +{
  82 + if (predicted_property && truth_property)
  83 + EvalClassification(predicted_input, truth_input, predicted_property, truth_property);
  84 + else if (predicted_property)
  85 + EvalClassification(predicted_input, truth_input, predicted_property);
  86 + else
  87 + EvalClassification(predicted_input, truth_input);
83 88 }
84 89  
85 90 void br_eval_clustering(const char *csv, const char *input)
... ...
openbr/openbr.h
... ... @@ -164,7 +164,7 @@ BR_EXPORT float br_eval(const char *simmat, const char *mask, const char *csv =
164 164 * \param truth_input The ground truth br::Input.
165 165 * \see br_enroll
166 166 */
167   -BR_EXPORT void br_eval_classification(const char *predicted_input, const char *truth_input);
  167 +BR_EXPORT void br_eval_classification(const char *predicted_input, const char *truth_input, const char * predicted_property, const char * truth_property);
168 168  
169 169 /*!
170 170 * \brief Evaluates and prints clustering accuracy to the terminal.
... ...
openbr/openbr_plugin.cpp
... ... @@ -412,6 +412,7 @@ TemplateList TemplateList::fromGallery(const br::File &amp;gallery)
412 412 // of target images to every partition
413 413 newTemplates[i].file.set("Partition", -1);
414 414 } else {
  415 + // Direct use of "Subject" is not general -cao
415 416 const QByteArray md5 = QCryptographicHash::hash(newTemplates[i].file.get<QString>("Subject").toLatin1(), QCryptographicHash::Md5);
416 417 // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow
417 418 newTemplates[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate);
... ...
openbr/plugins/cluster.cpp
... ... @@ -89,10 +89,12 @@ class KNNTransform : public Transform
89 89 Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance STORED false)
90 90 Q_PROPERTY(bool weighted READ get_weighted WRITE set_weighted RESET reset_weighted STORED false)
91 91 Q_PROPERTY(int numSubjects READ get_numSubjects WRITE set_numSubjects RESET reset_numSubjects STORED false)
  92 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
92 93 BR_PROPERTY(int, k, 1)
93 94 BR_PROPERTY(br::Distance*, distance, NULL)
94 95 BR_PROPERTY(bool, weighted, false)
95 96 BR_PROPERTY(int, numSubjects, 1)
  97 + BR_PROPERTY(QString, inputVariable, "Label")
96 98  
97 99 TemplateList gallery;
98 100  
... ... @@ -111,13 +113,13 @@ class KNNTransform : public Transform
111 113 QHash<QString, float> votes;
112 114 const int max = (k < 1) ? sortedScores.size() : std::min(k, sortedScores.size());
113 115 for (int j=0; j<max; j++)
114   - votes[gallery[sortedScores[j].second].file.get<QString>("Subject")] += (weighted ? sortedScores[j].first : 1);
  116 + votes[gallery[sortedScores[j].second].file.get<QString>(inputVariable)] += (weighted ? sortedScores[j].first : 1);
115 117 subjects.append(votes.keys()[votes.values().indexOf(Common::Max(votes.values()))]);
116 118  
117 119 // Remove subject from consideration
118 120 if (subjects.size() < numSubjects)
119 121 for (int j=sortedScores.size()-1; j>=0; j--)
120   - if (gallery[sortedScores[j].second].file.get<QString>("Subject") == subjects.last())
  122 + if (gallery[sortedScores[j].second].file.get<QString>(inputVariable) == subjects.last())
121 123 sortedScores.removeAt(j);
122 124 }
123 125  
... ...
openbr/plugins/eigen3.cpp
... ... @@ -318,10 +318,12 @@ class LDATransform : public Transform
318 318 Q_PROPERTY(bool pcaWhiten READ get_pcaWhiten WRITE set_pcaWhiten RESET reset_pcaWhiten STORED false)
319 319 Q_PROPERTY(int directLDA READ get_directLDA WRITE set_directLDA RESET reset_directLDA STORED false)
320 320 Q_PROPERTY(float directDrop READ get_directDrop WRITE set_directDrop RESET reset_directDrop STORED false)
  321 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
321 322 BR_PROPERTY(float, pcaKeep, 0.98)
322 323 BR_PROPERTY(bool, pcaWhiten, false)
323 324 BR_PROPERTY(int, directLDA, 0)
324 325 BR_PROPERTY(float, directDrop, 0.1)
  326 + BR_PROPERTY(QString, inputVariable, "Label")
325 327  
326 328 int dimsOut;
327 329 Eigen::VectorXf mean;
... ... @@ -330,7 +332,7 @@ class LDATransform : public Transform
330 332 void train(const TemplateList &_trainingSet)
331 333 {
332 334 // creates "Label"
333   - TemplateList trainingSet = TemplateList::relabel(_trainingSet, "Subject");
  335 + TemplateList trainingSet = TemplateList::relabel(_trainingSet, inputVariable);
334 336  
335 337 int instances = trainingSet.size();
336 338  
... ...
openbr/plugins/gallery.cpp
... ... @@ -71,7 +71,7 @@ class arffGallery : public Gallery
71 71 }
72 72  
73 73 arffFile.write(qPrintable(OpenCVUtils::matrixToStringList(t).join(',')));
74   - arffFile.write(qPrintable(",'" + t.file.get<QString>("Subject") + "'\n"));
  74 + arffFile.write(qPrintable(",'" + t.file.get<QString>("Label") + "'\n"));
75 75 }
76 76 };
77 77  
... ... @@ -874,7 +874,7 @@ class statGallery : public Gallery
874 874  
875 875 void write(const Template &t)
876 876 {
877   - subjects.insert(t.file.get<QString>("Subject"));
  877 + subjects.insert(t.file.get<QString>("Label"));
878 878 bytes.append(t.bytes());
879 879 }
880 880 };
... ...
openbr/plugins/independent.cpp
... ... @@ -20,11 +20,11 @@ static TemplateList Downsample(const TemplateList &amp;templates, const Transform *t
20 20 const bool atLeast = transform->instances < 0;
21 21 const int instances = abs(transform->instances);
22 22  
23   - QList<QString> allLabels = templates.get<QString>("Subject");
  23 + QList<QString> allLabels = templates.get<QString>("Label");
24 24 QList<QString> uniqueLabels = allLabels.toSet().toList();
25 25 qSort(uniqueLabels);
26 26  
27   - QMap<QString,int> counts = templates.countValues<QString>("Subject", instances != std::numeric_limits<int>::max());
  27 + QMap<QString,int> counts = templates.countValues<QString>("Label", instances != std::numeric_limits<int>::max());
28 28  
29 29 if ((instances != std::numeric_limits<int>::max()) && (transform->classes != std::numeric_limits<int>::max()))
30 30 foreach (const QString & label, counts.keys())
... ...
openbr/plugins/normalize.cpp
... ... @@ -127,7 +127,7 @@ private:
127 127 Mat m;
128 128 OpenCVUtils::toMat(data.data()).convertTo(m, CV_64F);
129 129  
130   - const QList<int> labels = data.indexProperty("Subject");
  130 + const QList<int> labels = data.indexProperty("Label");
131 131 const int dims = m.cols;
132 132  
133 133 vector<Mat> mv, av, bv;
... ...
openbr/plugins/output.cpp
... ... @@ -146,8 +146,8 @@ class meltOutput : public MatrixOutput
146 146 QStringList lines;
147 147 if (file.baseName() != "terminal") lines.append(QString("Query,Target,Mask,Similarity%1").arg(keys));
148 148  
149   - QList<QString> queryLabels = queryFiles.get<QString>("Subject");
150   - QList<QString> targetLabels = targetFiles.get<QString>("Subject");
  149 + QList<QString> queryLabels = queryFiles.get<QString>("Label");
  150 + QList<QString> targetLabels = targetFiles.get<QString>("Label");
151 151  
152 152 for (int i=0; i<queryFiles.size(); i++) {
153 153 for (int j=(selfSimilar ? i+1 : 0); j<targetFiles.size(); j++) {
... ... @@ -298,7 +298,7 @@ class txtOutput : public MatrixOutput
298 298 if (file.isNull() || targetFiles.isEmpty() || queryFiles.isEmpty()) return;
299 299 QStringList lines;
300 300 foreach (const File &file, queryFiles)
301   - lines.append(file.name + " " + file.get<QString>("Subject"));
  301 + lines.append(file.name + " " + file.get<QString>("Label"));
302 302 QtUtils::writeFile(file, lines);
303 303 }
304 304 };
... ... @@ -426,7 +426,7 @@ class rankOutput : public MatrixOutput
426 426 typedef QPair<float,int> Pair;
427 427 int rank = 1;
428 428 foreach (const Pair &pair, Common::Sort(OpenCVUtils::matrixToVector<float>(data.row(i)), true)) {
429   - if (targetFiles[pair.second].get<QString>("Subject") == queryFiles[i].get<QString>("Subject")) {
  429 + if (targetFiles[pair.second].get<QString>("Label") == queryFiles[i].get<QString>("Label")) {
430 430 ranks.append(rank);
431 431 positions.append(pair.second);
432 432 scores.append(pair.first);
... ...
openbr/plugins/quality.cpp
... ... @@ -26,10 +26,10 @@ class ImpostorUniquenessMeasureTransform : public Transform
26 26  
27 27 float calculateIUM(const Template &probe, const TemplateList &gallery) const
28 28 {
29   - const QString probeLabel = probe.file.get<QString>("Subject");
  29 + const QString probeLabel = probe.file.get<QString>("Label");
30 30 TemplateList subset = gallery;
31 31 for (int j=subset.size()-1; j>=0; j--)
32   - if (subset[j].file.get<QString>("Subject") == probeLabel)
  32 + if (subset[j].file.get<QString>("Label") == probeLabel)
33 33 subset.removeAt(j);
34 34  
35 35 QList<float> scores = distance->compare(subset, probe);
... ... @@ -158,7 +158,7 @@ class MatchProbabilityDistance : public Distance
158 158 {
159 159 distance->train(src);
160 160  
161   - const QList<int> labels = src.indexProperty("Subject");
  161 + const QList<int> labels = src.indexProperty("Label");
162 162 QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(src.size()), FileList(src.size())));
163 163 distance->compare(src, src, matrixOutput.data());
164 164  
... ... @@ -228,7 +228,7 @@ class HeatMapDistance : public Distance
228 228 {
229 229 distance->train(src);
230 230  
231   - const QList<int> labels = src.indexProperty("Subject");
  231 + const QList<int> labels = src.indexProperty("Label");
232 232  
233 233 QList<TemplateList> patches;
234 234  
... ... @@ -316,7 +316,7 @@ class UnitDistance : public Distance
316 316 void train(const TemplateList &templates)
317 317 {
318 318 const TemplateList samples = templates.mid(0, 2000);
319   - const QList<int> sampleLabels = samples.indexProperty("Subject");
  319 + const QList<int> sampleLabels = samples.indexProperty("Label");
320 320 QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(samples.size()), FileList(samples.size())));
321 321 Distance::compare(samples, samples, matrixOutput.data());
322 322  
... ...
openbr/plugins/quantize.cpp
... ... @@ -150,7 +150,7 @@ class BayesianQuantizationDistance : public Distance
150 150 qFatal("Expected sigle matrix templates of type CV_8UC1!");
151 151  
152 152 const Mat data = OpenCVUtils::toMat(src.data());
153   - const QList<int> templateLabels = src.indexProperty("Subject");
  153 + const QList<int> templateLabels = src.indexProperty("Label");
154 154 loglikelihoods = QVector<float>(data.cols*256, 0);
155 155  
156 156 QFutureSynchronizer<void> futures;
... ... @@ -474,7 +474,7 @@ private:
474 474 Mat data = OpenCVUtils::toMat(src.data());
475 475 const int step = getStep(data.cols);
476 476  
477   - const QList<int> labels = src.indexProperty("Subject");
  477 + const QList<int> labels = src.indexProperty("Label");
478 478  
479 479 Mat &lut = ProductQuantizationLUTs[index];
480 480 lut = Mat(getDims(data.cols), 256*(256+1)/2, CV_32FC1);
... ...
openbr/plugins/quantize2.cpp
... ... @@ -77,7 +77,7 @@ class BayesianQuantizationTransform : public Transform
77 77 void train(const TemplateList &src)
78 78 {
79 79 const Mat data = OpenCVUtils::toMat(src.data());
80   - const QList<int> labels = src.indexProperty("Subject");
  80 + const QList<int> labels = src.indexProperty("Label");
81 81  
82 82 thresholds = QVector<float>(256*data.cols);
83 83  
... ...
openbr/plugins/svm.cpp
... ... @@ -101,6 +101,8 @@ class SVMTransform : public Transform
101 101 Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false)
102 102 Q_PROPERTY(float C READ get_C WRITE set_C RESET reset_C STORED false)
103 103 Q_PROPERTY(float gamma READ get_gamma WRITE set_gamma RESET reset_gamma STORED false)
  104 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
  105 + Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)
104 106  
105 107 public:
106 108 enum Kernel { Linear = CvSVM::LINEAR,
... ... @@ -119,6 +121,9 @@ private:
119 121 BR_PROPERTY(Type, type, C_SVC)
120 122 BR_PROPERTY(float, C, -1)
121 123 BR_PROPERTY(float, gamma, -1)
  124 + BR_PROPERTY(QString, inputVariable, "")
  125 + BR_PROPERTY(QString, outputVariable, "")
  126 +
122 127  
123 128 SVM svm;
124 129 QHash<QString, int> labelMap;
... ... @@ -128,14 +133,14 @@ private:
128 133 {
129 134 Mat data = OpenCVUtils::toMat(_data.data());
130 135 Mat lab;
131   - // If we are doing regression, assume subject has float values
  136 + // If we are doing regression, the input variable should have float values
132 137 if (type == EPS_SVR || type == NU_SVR) {
133   - lab = OpenCVUtils::toMat(_data.get<float>("Subject"));
  138 + lab = OpenCVUtils::toMat(_data.get<float>(inputVariable));
134 139 }
135   - // If we are doing classification, assume subject has discrete values, map them
  140 + // If we are doing classification, we should be dealing with discrete values. Map them
136 141 // and store the mapping data
137 142 else {
138   - QList<int> dataLabels = _data.indexProperty("Subject", labelMap, reverseLookup);
  143 + QList<int> dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup);
139 144 lab = OpenCVUtils::toMat(dataLabels);
140 145 }
141 146 trainSVM(svm, data, lab, kernel, type, C, gamma);
... ... @@ -146,9 +151,9 @@ private:
146 151 dst = src;
147 152 float prediction = svm.predict(src.m().reshape(1, 1));
148 153 if (type == EPS_SVR || type == NU_SVR)
149   - dst.file.set("Subject", prediction);
  154 + dst.file.set(outputVariable, prediction);
150 155 else
151   - dst.file.set("Subject", reverseLookup[prediction]);
  156 + dst.file.set(outputVariable, reverseLookup[prediction]);
152 157 }
153 158  
154 159 void store(QDataStream &stream) const
... ... @@ -162,6 +167,24 @@ private:
162 167 loadSVM(svm, stream);
163 168 stream >> labelMap >> reverseLookup;
164 169 }
  170 +
  171 + void init()
  172 + {
  173 + // Since SVM can do regression or classification, we have to check the problem type before
  174 + // specifying target variable names
  175 + if (inputVariable.isEmpty())
  176 + {
  177 + if (type == EPS_SVR || type == NU_SVR) {
  178 + inputVariable = "Regressor";
  179 + if (outputVariable.isEmpty())
  180 + outputVariable = "Regressand";
  181 + }
  182 + else
  183 + inputVariable = "Label";
  184 + }
  185 + if (outputVariable.isEmpty())
  186 + outputVariable = inputVariable;
  187 + }
165 188 };
166 189  
167 190 BR_REGISTER(Transform, SVMTransform)
... ... @@ -178,6 +201,8 @@ class SVMDistance : public Distance
178 201 Q_ENUMS(Type)
179 202 Q_PROPERTY(Kernel kernel READ get_kernel WRITE set_kernel RESET reset_kernel STORED false)
180 203 Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false)
  204 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
  205 +
181 206  
182 207 public:
183 208 enum Kernel { Linear = CvSVM::LINEAR,
... ... @@ -194,13 +219,14 @@ public:
194 219 private:
195 220 BR_PROPERTY(Kernel, kernel, Linear)
196 221 BR_PROPERTY(Type, type, EPS_SVR)
  222 + BR_PROPERTY(QString, inputVariable, "Label")
197 223  
198 224 SVM svm;
199 225  
200 226 void train(const TemplateList &src)
201 227 {
202 228 const Mat data = OpenCVUtils::toMat(src.data());
203   - const QList<int> lab = src.indexProperty("Subject");
  229 + const QList<int> lab = src.indexProperty(inputVariable);
204 230  
205 231 const int instances = data.rows * (data.rows+1) / 2;
206 232 Mat deltaData(instances, data.cols, data.type());
... ...