Commit 22de0ca461d6cb90cfda0781d5d797e87504d486

Authored by Charles Otto
1 parent aebd450c

Preliminar changes towards refering to class labels etc. more specifically

Change default label name from Subjet to Label (since label is a more general
term).
Use different default variable names for classification (label), regression
(regressor/regressand), and clustering (ClusterID)
Update some (far from all) transforms to accept arguments specifying their
input/output variables.
Update eval classification to optionally take target variable names as
arguments
app/br/br.cpp
@@ -130,7 +130,7 @@ public: @@ -130,7 +130,7 @@ public:
130 br_convert(parv[0], parv[1], parv[2]); 130 br_convert(parv[0], parv[1], parv[2]);
131 } else if (!strcmp(fun, "evalClassification")) { 131 } else if (!strcmp(fun, "evalClassification")) {
132 check(parc == 2, "Incorrect parameter count for 'evalClassification'."); 132 check(parc == 2, "Incorrect parameter count for 'evalClassification'.");
133 - br_eval_classification(parv[0], parv[1]); 133 + br_eval_classification(parv[0], parv[1], parc >= 3 ? parv[2] : NULL, parc >= 4 ? parv[3] : NULL);
134 } else if (!strcmp(fun, "evalRegression")) { 134 } else if (!strcmp(fun, "evalRegression")) {
135 check(parc == 2, "Incorrect parameter count for 'evalRegression'."); 135 check(parc == 2, "Incorrect parameter count for 'evalRegression'.");
136 br_eval_regression(parv[0], parv[1]); 136 br_eval_regression(parv[0], parv[1]);
app/examples/age_estimation.cpp
@@ -29,7 +29,8 @@ @@ -29,7 +29,8 @@
29 29
30 static void printTemplate(const br::Template &t) 30 static void printTemplate(const br::Template &t)
31 { 31 {
32 - printf("%s age: %d\n", qPrintable(t.file.fileName()), int(t.file.get<float>("Subject"))); 32 + // may use age directly -cao
  33 + printf("%s age: %d\n", qPrintable(t.file.fileName()), int(t.file.get<float>("Regressand")));
33 } 34 }
34 35
35 int main(int argc, char *argv[]) 36 int main(int argc, char *argv[])
app/examples/gender_estimation.cpp
@@ -29,7 +29,8 @@ @@ -29,7 +29,8 @@
29 29
30 static void printTemplate(const br::Template &t) 30 static void printTemplate(const br::Template &t)
31 { 31 {
32 - printf("%s gender: %s\n", qPrintable(t.file.fileName()), qPrintable(t.file.get<QString>("Subject"))); 32 + // may use gender directly -cao
  33 + printf("%s gender: %s\n", qPrintable(t.file.fileName()), qPrintable(t.file.get<QString>("Label")));
33 } 34 }
34 35
35 int main(int argc, char *argv[]) 36 int main(int argc, char *argv[])
openbr/core/bee.cpp
@@ -93,10 +93,10 @@ void BEE::writeSigset(const QString &amp;sigset, const br::FileList &amp;files, bool ign @@ -93,10 +93,10 @@ void BEE::writeSigset(const QString &amp;sigset, const br::FileList &amp;files, bool ign
93 QStringList metadata; 93 QStringList metadata;
94 if (!ignoreMetadata) 94 if (!ignoreMetadata)
95 foreach (const QString &key, file.localKeys()) { 95 foreach (const QString &key, file.localKeys()) {
96 - if ((key == "Index") || (key == "Subject")) continue; 96 + if ((key == "Index") || (key == "Label")) continue;
97 metadata.append(key+"=\""+QtUtils::toString(file.value(key))+"\""); 97 metadata.append(key+"=\""+QtUtils::toString(file.value(key))+"\"");
98 } 98 }
99 - lines.append("\t<biometric-signature name=\"" + file.get<QString>("Subject") +"\">"); 99 + lines.append("\t<biometric-signature name=\"" + file.get<QString>("Label") +"\">");
100 lines.append("\t\t<presentation file-name=\"" + file.name + "\" " + metadata.join(" ") + "/>"); 100 lines.append("\t\t<presentation file-name=\"" + file.name + "\" " + metadata.join(" ") + "/>");
101 lines.append("\t</biometric-signature>"); 101 lines.append("\t</biometric-signature>");
102 } 102 }
@@ -260,10 +260,10 @@ void BEE::makeMask(const QString &amp;targetInput, const QString &amp;queryInput, const @@ -260,10 +260,10 @@ void BEE::makeMask(const QString &amp;targetInput, const QString &amp;queryInput, const
260 260
261 cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries, int partition) 261 cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries, int partition)
262 { 262 {
263 - // Would like to use indexProperty for this, but didn't make a version of that for Filelist yet  
264 - // -cao  
265 - QList<QString> targetLabels = targets.get<QString>("Subject", "-1");  
266 - QList<QString> queryLabels = queries.get<QString>("Subject", "-1"); 263 + // Direct use of "Label" isn't general, also would prefer to use indexProperty, rather than
  264 + // doing string comparisons (but that isn't implemented yet for FileList) -cao
  265 + QList<QString> targetLabels = targets.get<QString>("Label", "-1");
  266 + QList<QString> queryLabels = queries.get<QString>("Label", "-1");
267 QList<int> targetPartitions = targets.crossValidationPartitions(); 267 QList<int> targetPartitions = targets.crossValidationPartitions();
268 QList<int> queryPartitions = queries.crossValidationPartitions(); 268 QList<int> queryPartitions = queries.crossValidationPartitions();
269 269
openbr/core/classify.cpp
@@ -31,7 +31,7 @@ struct Counter @@ -31,7 +31,7 @@ struct Counter
31 } 31 }
32 }; 32 };
33 33
34 -void br::EvalClassification(const QString &predictedInput, const QString &truthInput) 34 +void br::EvalClassification(const QString &predictedInput, const QString &truthInput, const QString & predictedProperty, const QString & truthProperty)
35 { 35 {
36 qDebug("Evaluating classification of %s against %s", qPrintable(predictedInput), qPrintable(truthInput)); 36 qDebug("Evaluating classification of %s against %s", qPrintable(predictedInput), qPrintable(truthInput));
37 37
@@ -44,9 +44,8 @@ void br::EvalClassification(const QString &amp;predictedInput, const QString &amp;truthI @@ -44,9 +44,8 @@ void br::EvalClassification(const QString &amp;predictedInput, const QString &amp;truthI
44 if (predicted[i].file.name != truth[i].file.name) 44 if (predicted[i].file.name != truth[i].file.name)
45 qFatal("Input order mismatch."); 45 qFatal("Input order mismatch.");
46 46
47 - // Typically these lists will be of length one, but this generalization allows measuring multi-class labeling accuracy.  
48 - QString predictedSubject = predicted[i].file.get<QString>("Subject");  
49 - QString trueSubject = truth[i].file.get<QString>("Subject"); 47 + QString predictedSubject = predicted[i].file.get<QString>(predictedProperty);
  48 + QString trueSubject = truth[i].file.get<QString>(truthProperty);
50 49
51 QStringList predictedSubjects(predictedSubject); 50 QStringList predictedSubjects(predictedSubject);
52 QStringList trueSubjects(trueSubject); 51 QStringList trueSubjects(trueSubject);
@@ -99,13 +98,19 @@ void br::EvalRegression(const QString &amp;predictedInput, const QString &amp;truthInput @@ -99,13 +98,19 @@ void br::EvalRegression(const QString &amp;predictedInput, const QString &amp;truthInput
99 if (predicted.size() != truth.size()) qFatal("Input size mismatch."); 98 if (predicted.size() != truth.size()) qFatal("Input size mismatch.");
100 99
101 float rmsError = 0; 100 float rmsError = 0;
  101 + float maeError = 0;
  102 + // Direct use of Regressor/Regressand is not general -cao
102 QStringList truthValues, predictedValues; 103 QStringList truthValues, predictedValues;
103 for (int i=0; i<predicted.size(); i++) { 104 for (int i=0; i<predicted.size(); i++) {
104 if (predicted[i].file.name != truth[i].file.name) 105 if (predicted[i].file.name != truth[i].file.name)
105 qFatal("Input order mismatch."); 106 qFatal("Input order mismatch.");
106 - rmsError += pow(predicted[i].file.get<float>("Subject")-truth[i].file.get<float>("Subject"), 2.f);  
107 - truthValues.append(QString::number(truth[i].file.get<float>("Subject")));  
108 - predictedValues.append(QString::number(predicted[i].file.get<float>("Subject"))); 107 +
  108 + float difference = predicted[i].file.get<float>("Regressand") - truth[i].file.get<float>("Regressor");
  109 +
  110 + rmsError += pow(difference, 2.f);
  111 + maeError += fabsf(difference);
  112 + truthValues.append(QString::number(truth[i].file.get<float>("Regressor")));
  113 + predictedValues.append(QString::number(predicted[i].file.get<float>("Regressand")));
109 } 114 }
110 115
111 QStringList rSource; 116 QStringList rSource;
@@ -125,4 +130,6 @@ void br::EvalRegression(const QString &amp;predictedInput, const QString &amp;truthInput @@ -125,4 +130,6 @@ void br::EvalRegression(const QString &amp;predictedInput, const QString &amp;truthInput
125 if (success) QtUtils::showFile("EvalRegression.pdf"); 130 if (success) QtUtils::showFile("EvalRegression.pdf");
126 131
127 qDebug("RMS Error = %f", sqrt(rmsError/predicted.size())); 132 qDebug("RMS Error = %f", sqrt(rmsError/predicted.size()));
  133 + qDebug("MAE = %f", maeError/predicted.size());
  134 +
128 } 135 }
openbr/core/classify.h
@@ -22,7 +22,7 @@ @@ -22,7 +22,7 @@
22 22
23 namespace br 23 namespace br
24 { 24 {
25 - void EvalClassification(const QString &predictedInput, const QString &truthInput); 25 + void EvalClassification(const QString &predictedInput, const QString &truthInput, const QString & predictedProperty="Label", const QString & truthProperty="Label");
26 void EvalRegression(const QString &predictedInput, const QString &truthInput); 26 void EvalRegression(const QString &predictedInput, const QString &truthInput);
27 } 27 }
28 28
openbr/core/cluster.cpp
@@ -279,8 +279,8 @@ void br::EvalClustering(const QString &amp;csv, const QString &amp;input) @@ -279,8 +279,8 @@ void br::EvalClustering(const QString &amp;csv, const QString &amp;input)
279 qDebug("Evaluating %s against %s", qPrintable(csv), qPrintable(input)); 279 qDebug("Evaluating %s against %s", qPrintable(csv), qPrintable(input));
280 280
281 // We assume clustering algorithms store assigned cluster labels as integers (since the clusters are 281 // We assume clustering algorithms store assigned cluster labels as integers (since the clusters are
282 - // not named).  
283 - QList<int> labels = TemplateList::fromGallery(input).files().get<int>("Subject"); 282 + // not named). Direct use of ClusterID is not general -cao
  283 + QList<int> labels = TemplateList::fromGallery(input).files().get<int>("ClusterID");
284 284
285 QHash<int, int> labelToIndex; 285 QHash<int, int> labelToIndex;
286 int nClusters = 0; 286 int nClusters = 0;
openbr/frvt2012.cpp
@@ -132,7 +132,8 @@ int32_t SdkEstimator::estimate_age(const ONEFACE &amp;input_face, int32_t &amp;age) @@ -132,7 +132,8 @@ int32_t SdkEstimator::estimate_age(const ONEFACE &amp;input_face, int32_t &amp;age)
132 TemplateList templates; 132 TemplateList templates;
133 templates.append(templateFromONEFACE(input_face)); 133 templates.append(templateFromONEFACE(input_face));
134 templates >> *frvt2012_age_transform.data(); 134 templates >> *frvt2012_age_transform.data();
135 - age = templates.first().file.get<float>("Subject"); 135 + // should maybe use "Age" directly -cao
  136 + age = templates.first().file.get<float>("Regressand");
136 return templates.first().file.failed() ? 4 : 0; 137 return templates.first().file.failed() ? 4 : 0;
137 } 138 }
138 139
@@ -141,6 +142,7 @@ int32_t SdkEstimator::estimate_gender(const ONEFACE &amp;input_face, int8_t &amp;gender, @@ -141,6 +142,7 @@ int32_t SdkEstimator::estimate_gender(const ONEFACE &amp;input_face, int8_t &amp;gender,
141 TemplateList templates; 142 TemplateList templates;
142 templates.append(templateFromONEFACE(input_face)); 143 templates.append(templateFromONEFACE(input_face));
143 templates >> *frvt2012_gender_transform.data(); 144 templates >> *frvt2012_gender_transform.data();
144 - mf = gender = templates.first().file.get<QString>("Subject") == "Male" ? 0 : 1; 145 + // Should maybe use "Gender" directly -cao
  146 + mf = gender = templates.first().file.get<QString>("Label") == "Male" ? 0 : 1;
145 return templates.first().file.failed() ? 4 : 0; 147 return templates.first().file.failed() ? 4 : 0;
146 } 148 }
openbr/gui/classifier.cpp
@@ -43,14 +43,16 @@ void Classifier::_classify(File file) @@ -43,14 +43,16 @@ void Classifier::_classify(File file)
43 continue; 43 continue;
44 44
45 if (algorithm == "GenderClassification") { 45 if (algorithm == "GenderClassification") {
  46 + // Should maybe use gender directly -cao
46 key = "Gender"; 47 key = "Gender";
47 - value = (f.get<QString>("Subject")); 48 + value = f.get<QString>("Label");
48 } else if (algorithm == "AgeRegression") { 49 } else if (algorithm == "AgeRegression") {
49 key = "Age"; 50 key = "Age";
50 - value = QString::number(int(f.get<float>("Subject")+0.5)) + " Years"; 51 + // similarly, age -cao
  52 + value = QString::number(int(f.get<float>("Regressand")+0.5)) + " Years";
51 } else { 53 } else {
52 key = algorithm; 54 key = algorithm;
53 - value = f.get<QString>("Subject"); 55 + value = f.get<QString>("Label");
54 } 56 }
55 break; 57 break;
56 } 58 }
openbr/openbr.cpp
@@ -77,9 +77,14 @@ float br_eval(const char *simmat, const char *mask, const char *csv) @@ -77,9 +77,14 @@ float br_eval(const char *simmat, const char *mask, const char *csv)
77 return Evaluate(simmat, mask, csv); 77 return Evaluate(simmat, mask, csv);
78 } 78 }
79 79
80 -void br_eval_classification(const char *predicted_input, const char *truth_input)  
81 -{  
82 - EvalClassification(predicted_input, truth_input); 80 +void br_eval_classification(const char *predicted_input, const char *truth_input, const char *predicted_property, const char * truth_property)
  81 +{
  82 + if (predicted_property && truth_property)
  83 + EvalClassification(predicted_input, truth_input, predicted_property, truth_property);
  84 + else if (predicted_property)
  85 + EvalClassification(predicted_input, truth_input, predicted_property);
  86 + else
  87 + EvalClassification(predicted_input, truth_input);
83 } 88 }
84 89
85 void br_eval_clustering(const char *csv, const char *input) 90 void br_eval_clustering(const char *csv, const char *input)
openbr/openbr.h
@@ -164,7 +164,7 @@ BR_EXPORT float br_eval(const char *simmat, const char *mask, const char *csv = @@ -164,7 +164,7 @@ BR_EXPORT float br_eval(const char *simmat, const char *mask, const char *csv =
164 * \param truth_input The ground truth br::Input. 164 * \param truth_input The ground truth br::Input.
165 * \see br_enroll 165 * \see br_enroll
166 */ 166 */
167 -BR_EXPORT void br_eval_classification(const char *predicted_input, const char *truth_input); 167 +BR_EXPORT void br_eval_classification(const char *predicted_input, const char *truth_input, const char * predicted_property, const char * truth_property);
168 168
169 /*! 169 /*!
170 * \brief Evaluates and prints clustering accuracy to the terminal. 170 * \brief Evaluates and prints clustering accuracy to the terminal.
openbr/openbr_plugin.cpp
@@ -412,6 +412,7 @@ TemplateList TemplateList::fromGallery(const br::File &amp;gallery) @@ -412,6 +412,7 @@ TemplateList TemplateList::fromGallery(const br::File &amp;gallery)
412 // of target images to every partition 412 // of target images to every partition
413 newTemplates[i].file.set("Partition", -1); 413 newTemplates[i].file.set("Partition", -1);
414 } else { 414 } else {
  415 + // Direct use of "Subject" is not general -cao
415 const QByteArray md5 = QCryptographicHash::hash(newTemplates[i].file.get<QString>("Subject").toLatin1(), QCryptographicHash::Md5); 416 const QByteArray md5 = QCryptographicHash::hash(newTemplates[i].file.get<QString>("Subject").toLatin1(), QCryptographicHash::Md5);
416 // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow 417 // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow
417 newTemplates[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); 418 newTemplates[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate);
openbr/plugins/cluster.cpp
@@ -89,10 +89,12 @@ class KNNTransform : public Transform @@ -89,10 +89,12 @@ class KNNTransform : public Transform
89 Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance STORED false) 89 Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance STORED false)
90 Q_PROPERTY(bool weighted READ get_weighted WRITE set_weighted RESET reset_weighted STORED false) 90 Q_PROPERTY(bool weighted READ get_weighted WRITE set_weighted RESET reset_weighted STORED false)
91 Q_PROPERTY(int numSubjects READ get_numSubjects WRITE set_numSubjects RESET reset_numSubjects STORED false) 91 Q_PROPERTY(int numSubjects READ get_numSubjects WRITE set_numSubjects RESET reset_numSubjects STORED false)
  92 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
92 BR_PROPERTY(int, k, 1) 93 BR_PROPERTY(int, k, 1)
93 BR_PROPERTY(br::Distance*, distance, NULL) 94 BR_PROPERTY(br::Distance*, distance, NULL)
94 BR_PROPERTY(bool, weighted, false) 95 BR_PROPERTY(bool, weighted, false)
95 BR_PROPERTY(int, numSubjects, 1) 96 BR_PROPERTY(int, numSubjects, 1)
  97 + BR_PROPERTY(QString, inputVariable, "Label")
96 98
97 TemplateList gallery; 99 TemplateList gallery;
98 100
@@ -111,13 +113,13 @@ class KNNTransform : public Transform @@ -111,13 +113,13 @@ class KNNTransform : public Transform
111 QHash<QString, float> votes; 113 QHash<QString, float> votes;
112 const int max = (k < 1) ? sortedScores.size() : std::min(k, sortedScores.size()); 114 const int max = (k < 1) ? sortedScores.size() : std::min(k, sortedScores.size());
113 for (int j=0; j<max; j++) 115 for (int j=0; j<max; j++)
114 - votes[gallery[sortedScores[j].second].file.get<QString>("Subject")] += (weighted ? sortedScores[j].first : 1); 116 + votes[gallery[sortedScores[j].second].file.get<QString>(inputVariable)] += (weighted ? sortedScores[j].first : 1);
115 subjects.append(votes.keys()[votes.values().indexOf(Common::Max(votes.values()))]); 117 subjects.append(votes.keys()[votes.values().indexOf(Common::Max(votes.values()))]);
116 118
117 // Remove subject from consideration 119 // Remove subject from consideration
118 if (subjects.size() < numSubjects) 120 if (subjects.size() < numSubjects)
119 for (int j=sortedScores.size()-1; j>=0; j--) 121 for (int j=sortedScores.size()-1; j>=0; j--)
120 - if (gallery[sortedScores[j].second].file.get<QString>("Subject") == subjects.last()) 122 + if (gallery[sortedScores[j].second].file.get<QString>(inputVariable) == subjects.last())
121 sortedScores.removeAt(j); 123 sortedScores.removeAt(j);
122 } 124 }
123 125
openbr/plugins/eigen3.cpp
@@ -318,10 +318,12 @@ class LDATransform : public Transform @@ -318,10 +318,12 @@ class LDATransform : public Transform
318 Q_PROPERTY(bool pcaWhiten READ get_pcaWhiten WRITE set_pcaWhiten RESET reset_pcaWhiten STORED false) 318 Q_PROPERTY(bool pcaWhiten READ get_pcaWhiten WRITE set_pcaWhiten RESET reset_pcaWhiten STORED false)
319 Q_PROPERTY(int directLDA READ get_directLDA WRITE set_directLDA RESET reset_directLDA STORED false) 319 Q_PROPERTY(int directLDA READ get_directLDA WRITE set_directLDA RESET reset_directLDA STORED false)
320 Q_PROPERTY(float directDrop READ get_directDrop WRITE set_directDrop RESET reset_directDrop STORED false) 320 Q_PROPERTY(float directDrop READ get_directDrop WRITE set_directDrop RESET reset_directDrop STORED false)
  321 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
321 BR_PROPERTY(float, pcaKeep, 0.98) 322 BR_PROPERTY(float, pcaKeep, 0.98)
322 BR_PROPERTY(bool, pcaWhiten, false) 323 BR_PROPERTY(bool, pcaWhiten, false)
323 BR_PROPERTY(int, directLDA, 0) 324 BR_PROPERTY(int, directLDA, 0)
324 BR_PROPERTY(float, directDrop, 0.1) 325 BR_PROPERTY(float, directDrop, 0.1)
  326 + BR_PROPERTY(QString, inputVariable, "Label")
325 327
326 int dimsOut; 328 int dimsOut;
327 Eigen::VectorXf mean; 329 Eigen::VectorXf mean;
@@ -330,7 +332,7 @@ class LDATransform : public Transform @@ -330,7 +332,7 @@ class LDATransform : public Transform
330 void train(const TemplateList &_trainingSet) 332 void train(const TemplateList &_trainingSet)
331 { 333 {
332 // creates "Label" 334 // creates "Label"
333 - TemplateList trainingSet = TemplateList::relabel(_trainingSet, "Subject"); 335 + TemplateList trainingSet = TemplateList::relabel(_trainingSet, inputVariable);
334 336
335 int instances = trainingSet.size(); 337 int instances = trainingSet.size();
336 338
openbr/plugins/gallery.cpp
@@ -71,7 +71,7 @@ class arffGallery : public Gallery @@ -71,7 +71,7 @@ class arffGallery : public Gallery
71 } 71 }
72 72
73 arffFile.write(qPrintable(OpenCVUtils::matrixToStringList(t).join(','))); 73 arffFile.write(qPrintable(OpenCVUtils::matrixToStringList(t).join(',')));
74 - arffFile.write(qPrintable(",'" + t.file.get<QString>("Subject") + "'\n")); 74 + arffFile.write(qPrintable(",'" + t.file.get<QString>("Label") + "'\n"));
75 } 75 }
76 }; 76 };
77 77
@@ -874,7 +874,7 @@ class statGallery : public Gallery @@ -874,7 +874,7 @@ class statGallery : public Gallery
874 874
875 void write(const Template &t) 875 void write(const Template &t)
876 { 876 {
877 - subjects.insert(t.file.get<QString>("Subject")); 877 + subjects.insert(t.file.get<QString>("Label"));
878 bytes.append(t.bytes()); 878 bytes.append(t.bytes());
879 } 879 }
880 }; 880 };
openbr/plugins/independent.cpp
@@ -20,11 +20,11 @@ static TemplateList Downsample(const TemplateList &amp;templates, const Transform *t @@ -20,11 +20,11 @@ static TemplateList Downsample(const TemplateList &amp;templates, const Transform *t
20 const bool atLeast = transform->instances < 0; 20 const bool atLeast = transform->instances < 0;
21 const int instances = abs(transform->instances); 21 const int instances = abs(transform->instances);
22 22
23 - QList<QString> allLabels = templates.get<QString>("Subject"); 23 + QList<QString> allLabels = templates.get<QString>("Label");
24 QList<QString> uniqueLabels = allLabels.toSet().toList(); 24 QList<QString> uniqueLabels = allLabels.toSet().toList();
25 qSort(uniqueLabels); 25 qSort(uniqueLabels);
26 26
27 - QMap<QString,int> counts = templates.countValues<QString>("Subject", instances != std::numeric_limits<int>::max()); 27 + QMap<QString,int> counts = templates.countValues<QString>("Label", instances != std::numeric_limits<int>::max());
28 28
29 if ((instances != std::numeric_limits<int>::max()) && (transform->classes != std::numeric_limits<int>::max())) 29 if ((instances != std::numeric_limits<int>::max()) && (transform->classes != std::numeric_limits<int>::max()))
30 foreach (const QString & label, counts.keys()) 30 foreach (const QString & label, counts.keys())
openbr/plugins/normalize.cpp
@@ -127,7 +127,7 @@ private: @@ -127,7 +127,7 @@ private:
127 Mat m; 127 Mat m;
128 OpenCVUtils::toMat(data.data()).convertTo(m, CV_64F); 128 OpenCVUtils::toMat(data.data()).convertTo(m, CV_64F);
129 129
130 - const QList<int> labels = data.indexProperty("Subject"); 130 + const QList<int> labels = data.indexProperty("Label");
131 const int dims = m.cols; 131 const int dims = m.cols;
132 132
133 vector<Mat> mv, av, bv; 133 vector<Mat> mv, av, bv;
openbr/plugins/output.cpp
@@ -146,8 +146,8 @@ class meltOutput : public MatrixOutput @@ -146,8 +146,8 @@ class meltOutput : public MatrixOutput
146 QStringList lines; 146 QStringList lines;
147 if (file.baseName() != "terminal") lines.append(QString("Query,Target,Mask,Similarity%1").arg(keys)); 147 if (file.baseName() != "terminal") lines.append(QString("Query,Target,Mask,Similarity%1").arg(keys));
148 148
149 - QList<QString> queryLabels = queryFiles.get<QString>("Subject");  
150 - QList<QString> targetLabels = targetFiles.get<QString>("Subject"); 149 + QList<QString> queryLabels = queryFiles.get<QString>("Label");
  150 + QList<QString> targetLabels = targetFiles.get<QString>("Label");
151 151
152 for (int i=0; i<queryFiles.size(); i++) { 152 for (int i=0; i<queryFiles.size(); i++) {
153 for (int j=(selfSimilar ? i+1 : 0); j<targetFiles.size(); j++) { 153 for (int j=(selfSimilar ? i+1 : 0); j<targetFiles.size(); j++) {
@@ -298,7 +298,7 @@ class txtOutput : public MatrixOutput @@ -298,7 +298,7 @@ class txtOutput : public MatrixOutput
298 if (file.isNull() || targetFiles.isEmpty() || queryFiles.isEmpty()) return; 298 if (file.isNull() || targetFiles.isEmpty() || queryFiles.isEmpty()) return;
299 QStringList lines; 299 QStringList lines;
300 foreach (const File &file, queryFiles) 300 foreach (const File &file, queryFiles)
301 - lines.append(file.name + " " + file.get<QString>("Subject")); 301 + lines.append(file.name + " " + file.get<QString>("Label"));
302 QtUtils::writeFile(file, lines); 302 QtUtils::writeFile(file, lines);
303 } 303 }
304 }; 304 };
@@ -426,7 +426,7 @@ class rankOutput : public MatrixOutput @@ -426,7 +426,7 @@ class rankOutput : public MatrixOutput
426 typedef QPair<float,int> Pair; 426 typedef QPair<float,int> Pair;
427 int rank = 1; 427 int rank = 1;
428 foreach (const Pair &pair, Common::Sort(OpenCVUtils::matrixToVector<float>(data.row(i)), true)) { 428 foreach (const Pair &pair, Common::Sort(OpenCVUtils::matrixToVector<float>(data.row(i)), true)) {
429 - if (targetFiles[pair.second].get<QString>("Subject") == queryFiles[i].get<QString>("Subject")) { 429 + if (targetFiles[pair.second].get<QString>("Label") == queryFiles[i].get<QString>("Label")) {
430 ranks.append(rank); 430 ranks.append(rank);
431 positions.append(pair.second); 431 positions.append(pair.second);
432 scores.append(pair.first); 432 scores.append(pair.first);
openbr/plugins/quality.cpp
@@ -26,10 +26,10 @@ class ImpostorUniquenessMeasureTransform : public Transform @@ -26,10 +26,10 @@ class ImpostorUniquenessMeasureTransform : public Transform
26 26
27 float calculateIUM(const Template &probe, const TemplateList &gallery) const 27 float calculateIUM(const Template &probe, const TemplateList &gallery) const
28 { 28 {
29 - const QString probeLabel = probe.file.get<QString>("Subject"); 29 + const QString probeLabel = probe.file.get<QString>("Label");
30 TemplateList subset = gallery; 30 TemplateList subset = gallery;
31 for (int j=subset.size()-1; j>=0; j--) 31 for (int j=subset.size()-1; j>=0; j--)
32 - if (subset[j].file.get<QString>("Subject") == probeLabel) 32 + if (subset[j].file.get<QString>("Label") == probeLabel)
33 subset.removeAt(j); 33 subset.removeAt(j);
34 34
35 QList<float> scores = distance->compare(subset, probe); 35 QList<float> scores = distance->compare(subset, probe);
@@ -158,7 +158,7 @@ class MatchProbabilityDistance : public Distance @@ -158,7 +158,7 @@ class MatchProbabilityDistance : public Distance
158 { 158 {
159 distance->train(src); 159 distance->train(src);
160 160
161 - const QList<int> labels = src.indexProperty("Subject"); 161 + const QList<int> labels = src.indexProperty("Label");
162 QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(src.size()), FileList(src.size()))); 162 QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(src.size()), FileList(src.size())));
163 distance->compare(src, src, matrixOutput.data()); 163 distance->compare(src, src, matrixOutput.data());
164 164
@@ -228,7 +228,7 @@ class HeatMapDistance : public Distance @@ -228,7 +228,7 @@ class HeatMapDistance : public Distance
228 { 228 {
229 distance->train(src); 229 distance->train(src);
230 230
231 - const QList<int> labels = src.indexProperty("Subject"); 231 + const QList<int> labels = src.indexProperty("Label");
232 232
233 QList<TemplateList> patches; 233 QList<TemplateList> patches;
234 234
@@ -316,7 +316,7 @@ class UnitDistance : public Distance @@ -316,7 +316,7 @@ class UnitDistance : public Distance
316 void train(const TemplateList &templates) 316 void train(const TemplateList &templates)
317 { 317 {
318 const TemplateList samples = templates.mid(0, 2000); 318 const TemplateList samples = templates.mid(0, 2000);
319 - const QList<int> sampleLabels = samples.indexProperty("Subject"); 319 + const QList<int> sampleLabels = samples.indexProperty("Label");
320 QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(samples.size()), FileList(samples.size()))); 320 QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(samples.size()), FileList(samples.size())));
321 Distance::compare(samples, samples, matrixOutput.data()); 321 Distance::compare(samples, samples, matrixOutput.data());
322 322
openbr/plugins/quantize.cpp
@@ -150,7 +150,7 @@ class BayesianQuantizationDistance : public Distance @@ -150,7 +150,7 @@ class BayesianQuantizationDistance : public Distance
150 qFatal("Expected sigle matrix templates of type CV_8UC1!"); 150 qFatal("Expected sigle matrix templates of type CV_8UC1!");
151 151
152 const Mat data = OpenCVUtils::toMat(src.data()); 152 const Mat data = OpenCVUtils::toMat(src.data());
153 - const QList<int> templateLabels = src.indexProperty("Subject"); 153 + const QList<int> templateLabels = src.indexProperty("Label");
154 loglikelihoods = QVector<float>(data.cols*256, 0); 154 loglikelihoods = QVector<float>(data.cols*256, 0);
155 155
156 QFutureSynchronizer<void> futures; 156 QFutureSynchronizer<void> futures;
@@ -474,7 +474,7 @@ private: @@ -474,7 +474,7 @@ private:
474 Mat data = OpenCVUtils::toMat(src.data()); 474 Mat data = OpenCVUtils::toMat(src.data());
475 const int step = getStep(data.cols); 475 const int step = getStep(data.cols);
476 476
477 - const QList<int> labels = src.indexProperty("Subject"); 477 + const QList<int> labels = src.indexProperty("Label");
478 478
479 Mat &lut = ProductQuantizationLUTs[index]; 479 Mat &lut = ProductQuantizationLUTs[index];
480 lut = Mat(getDims(data.cols), 256*(256+1)/2, CV_32FC1); 480 lut = Mat(getDims(data.cols), 256*(256+1)/2, CV_32FC1);
openbr/plugins/quantize2.cpp
@@ -77,7 +77,7 @@ class BayesianQuantizationTransform : public Transform @@ -77,7 +77,7 @@ class BayesianQuantizationTransform : public Transform
77 void train(const TemplateList &src) 77 void train(const TemplateList &src)
78 { 78 {
79 const Mat data = OpenCVUtils::toMat(src.data()); 79 const Mat data = OpenCVUtils::toMat(src.data());
80 - const QList<int> labels = src.indexProperty("Subject"); 80 + const QList<int> labels = src.indexProperty("Label");
81 81
82 thresholds = QVector<float>(256*data.cols); 82 thresholds = QVector<float>(256*data.cols);
83 83
openbr/plugins/svm.cpp
@@ -101,6 +101,8 @@ class SVMTransform : public Transform @@ -101,6 +101,8 @@ class SVMTransform : public Transform
101 Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false) 101 Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false)
102 Q_PROPERTY(float C READ get_C WRITE set_C RESET reset_C STORED false) 102 Q_PROPERTY(float C READ get_C WRITE set_C RESET reset_C STORED false)
103 Q_PROPERTY(float gamma READ get_gamma WRITE set_gamma RESET reset_gamma STORED false) 103 Q_PROPERTY(float gamma READ get_gamma WRITE set_gamma RESET reset_gamma STORED false)
  104 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
  105 + Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)
104 106
105 public: 107 public:
106 enum Kernel { Linear = CvSVM::LINEAR, 108 enum Kernel { Linear = CvSVM::LINEAR,
@@ -119,6 +121,9 @@ private: @@ -119,6 +121,9 @@ private:
119 BR_PROPERTY(Type, type, C_SVC) 121 BR_PROPERTY(Type, type, C_SVC)
120 BR_PROPERTY(float, C, -1) 122 BR_PROPERTY(float, C, -1)
121 BR_PROPERTY(float, gamma, -1) 123 BR_PROPERTY(float, gamma, -1)
  124 + BR_PROPERTY(QString, inputVariable, "")
  125 + BR_PROPERTY(QString, outputVariable, "")
  126 +
122 127
123 SVM svm; 128 SVM svm;
124 QHash<QString, int> labelMap; 129 QHash<QString, int> labelMap;
@@ -128,14 +133,14 @@ private: @@ -128,14 +133,14 @@ private:
128 { 133 {
129 Mat data = OpenCVUtils::toMat(_data.data()); 134 Mat data = OpenCVUtils::toMat(_data.data());
130 Mat lab; 135 Mat lab;
131 - // If we are doing regression, assume subject has float values 136 + // If we are doing regression, the input variable should have float values
132 if (type == EPS_SVR || type == NU_SVR) { 137 if (type == EPS_SVR || type == NU_SVR) {
133 - lab = OpenCVUtils::toMat(_data.get<float>("Subject")); 138 + lab = OpenCVUtils::toMat(_data.get<float>(inputVariable));
134 } 139 }
135 - // If we are doing classification, assume subject has discrete values, map them 140 + // If we are doing classification, we should be dealing with discrete values. Map them
136 // and store the mapping data 141 // and store the mapping data
137 else { 142 else {
138 - QList<int> dataLabels = _data.indexProperty("Subject", labelMap, reverseLookup); 143 + QList<int> dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup);
139 lab = OpenCVUtils::toMat(dataLabels); 144 lab = OpenCVUtils::toMat(dataLabels);
140 } 145 }
141 trainSVM(svm, data, lab, kernel, type, C, gamma); 146 trainSVM(svm, data, lab, kernel, type, C, gamma);
@@ -146,9 +151,9 @@ private: @@ -146,9 +151,9 @@ private:
146 dst = src; 151 dst = src;
147 float prediction = svm.predict(src.m().reshape(1, 1)); 152 float prediction = svm.predict(src.m().reshape(1, 1));
148 if (type == EPS_SVR || type == NU_SVR) 153 if (type == EPS_SVR || type == NU_SVR)
149 - dst.file.set("Subject", prediction); 154 + dst.file.set(outputVariable, prediction);
150 else 155 else
151 - dst.file.set("Subject", reverseLookup[prediction]); 156 + dst.file.set(outputVariable, reverseLookup[prediction]);
152 } 157 }
153 158
154 void store(QDataStream &stream) const 159 void store(QDataStream &stream) const
@@ -162,6 +167,24 @@ private: @@ -162,6 +167,24 @@ private:
162 loadSVM(svm, stream); 167 loadSVM(svm, stream);
163 stream >> labelMap >> reverseLookup; 168 stream >> labelMap >> reverseLookup;
164 } 169 }
  170 +
  171 + void init()
  172 + {
  173 + // Since SVM can do regression or classification, we have to check the problem type before
  174 + // specifying target variable names
  175 + if (inputVariable.isEmpty())
  176 + {
  177 + if (type == EPS_SVR || type == NU_SVR) {
  178 + inputVariable = "Regressor";
  179 + if (outputVariable.isEmpty())
  180 + outputVariable = "Regressand";
  181 + }
  182 + else
  183 + inputVariable = "Label";
  184 + }
  185 + if (outputVariable.isEmpty())
  186 + outputVariable = inputVariable;
  187 + }
165 }; 188 };
166 189
167 BR_REGISTER(Transform, SVMTransform) 190 BR_REGISTER(Transform, SVMTransform)
@@ -178,6 +201,8 @@ class SVMDistance : public Distance @@ -178,6 +201,8 @@ class SVMDistance : public Distance
178 Q_ENUMS(Type) 201 Q_ENUMS(Type)
179 Q_PROPERTY(Kernel kernel READ get_kernel WRITE set_kernel RESET reset_kernel STORED false) 202 Q_PROPERTY(Kernel kernel READ get_kernel WRITE set_kernel RESET reset_kernel STORED false)
180 Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false) 203 Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false)
  204 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
  205 +
181 206
182 public: 207 public:
183 enum Kernel { Linear = CvSVM::LINEAR, 208 enum Kernel { Linear = CvSVM::LINEAR,
@@ -194,13 +219,14 @@ public: @@ -194,13 +219,14 @@ public:
194 private: 219 private:
195 BR_PROPERTY(Kernel, kernel, Linear) 220 BR_PROPERTY(Kernel, kernel, Linear)
196 BR_PROPERTY(Type, type, EPS_SVR) 221 BR_PROPERTY(Type, type, EPS_SVR)
  222 + BR_PROPERTY(QString, inputVariable, "Label")
197 223
198 SVM svm; 224 SVM svm;
199 225
200 void train(const TemplateList &src) 226 void train(const TemplateList &src)
201 { 227 {
202 const Mat data = OpenCVUtils::toMat(src.data()); 228 const Mat data = OpenCVUtils::toMat(src.data());
203 - const QList<int> lab = src.indexProperty("Subject"); 229 + const QList<int> lab = src.indexProperty(inputVariable);
204 230
205 const int instances = data.rows * (data.rows+1) / 2; 231 const int instances = data.rows * (data.rows+1) / 2;
206 Mat deltaData(instances, data.cols, data.type()); 232 Mat deltaData(instances, data.cols, data.type());