Commit 7e8092da4325ceb3a2bf5aebed7ddada408c7c8a

Authored by jklontz
2 parents 3a165aa9 95c1a73a

Merge pull request #58 from biometrics/specificity

Changes to variable handling for classification/regression/clustering
app/br/br.cpp
@@ -129,8 +129,8 @@ public: @@ -129,8 +129,8 @@ public:
129 check(parc == 3, "Incorrect parameter count for 'convert'."); 129 check(parc == 3, "Incorrect parameter count for 'convert'.");
130 br_convert(parv[0], parv[1], parv[2]); 130 br_convert(parv[0], parv[1], parv[2]);
131 } else if (!strcmp(fun, "evalClassification")) { 131 } else if (!strcmp(fun, "evalClassification")) {
132 - check(parc == 2, "Incorrect parameter count for 'evalClassification'.");  
133 - br_eval_classification(parv[0], parv[1]); 132 + check(parc >= 2 && parc <= 4, "Incorrect parameter count for 'evalClassification'.");
  133 + br_eval_classification(parv[0], parv[1], parc >= 3 ? parv[2] : "", parc >= 4 ? parv[3] : "");
134 } else if (!strcmp(fun, "evalClustering")) { 134 } else if (!strcmp(fun, "evalClustering")) {
135 check(parc == 2, "Incorrect parameter count for 'evalClustering'."); 135 check(parc == 2, "Incorrect parameter count for 'evalClustering'.");
136 br_eval_clustering(parv[0], parv[1]); 136 br_eval_clustering(parv[0], parv[1]);
@@ -138,8 +138,8 @@ public: @@ -138,8 +138,8 @@ public:
138 check((parc >= 2) && (parc <= 3), "Incorrect parameter count for 'evalDetection'."); 138 check((parc >= 2) && (parc <= 3), "Incorrect parameter count for 'evalDetection'.");
139 br_eval_detection(parv[0], parv[1], parc == 3 ? parv[2] : ""); 139 br_eval_detection(parv[0], parv[1], parc == 3 ? parv[2] : "");
140 } else if (!strcmp(fun, "evalRegression")) { 140 } else if (!strcmp(fun, "evalRegression")) {
141 - check(parc == 2, "Incorrect parameter count for 'evalRegression'.");  
142 - br_eval_regression(parv[0], parv[1]); 141 + check(parc >= 2 && parc <= 4, "Incorrect parameter count for 'evalRegression'.");
  142 + br_eval_regression(parv[0], parv[1], parc >= 3 ? parv[2] : "", parc >= 4 ? parv[3] : "");
143 } else if (!strcmp(fun, "plotDetection")) { 143 } else if (!strcmp(fun, "plotDetection")) {
144 check(parc >= 2, "Incorrect parameter count for 'plotDetection'."); 144 check(parc >= 2, "Incorrect parameter count for 'plotDetection'.");
145 br_plot_detection(parc-1, parv, parv[parc-1], true); 145 br_plot_detection(parc-1, parv, parv[parc-1], true);
@@ -215,10 +215,10 @@ private: @@ -215,10 +215,10 @@ private:
215 "-combineMasks <mask> ... <mask> {mask} (And|Or)\n" 215 "-combineMasks <mask> ... <mask> {mask} (And|Or)\n"
216 "-cat <gallery> ... <gallery> {gallery}\n" 216 "-cat <gallery> ... <gallery> {gallery}\n"
217 "-convert (Format|Gallery|Output) <input_file> {output_file}\n" 217 "-convert (Format|Gallery|Output) <input_file> {output_file}\n"
218 - "-evalClassification <predicted_gallery> <truth_gallery>\n" 218 + "-evalClassification <predicted_gallery> <truth_gallery> <predicted property name> <ground truth proprty name>\n"
219 "-evalClustering <clusters> <gallery>\n" 219 "-evalClustering <clusters> <gallery>\n"
220 "-evalDetection <predicted_gallery> <truth_gallery> [{csv}]\n" 220 "-evalDetection <predicted_gallery> <truth_gallery> [{csv}]\n"
221 - "-evalRegression <predicted_gallery> <truth_gallery>\n" 221 + "-evalRegression <predicted_gallery> <truth_gallery> <predicted property name> <ground truth property name>\n"
222 "-plotDetection <file> ... <file> {destination}\n" 222 "-plotDetection <file> ... <file> {destination}\n"
223 "-plotMetadata <file> ... <file> <columns>\n" 223 "-plotMetadata <file> ... <file> <columns>\n"
224 "-getHeader <matrix>\n" 224 "-getHeader <matrix>\n"
app/examples/age_estimation.cpp
@@ -29,7 +29,7 @@ @@ -29,7 +29,7 @@
29 29
30 static void printTemplate(const br::Template &t) 30 static void printTemplate(const br::Template &t)
31 { 31 {
32 - printf("%s age: %d\n", qPrintable(t.file.fileName()), int(t.file.get<float>("Subject"))); 32 + printf("%s age: %d\n", qPrintable(t.file.fileName()), int(t.file.get<float>("Age")));
33 } 33 }
34 34
35 int main(int argc, char *argv[]) 35 int main(int argc, char *argv[])
app/examples/gender_estimation.cpp
@@ -29,7 +29,7 @@ @@ -29,7 +29,7 @@
29 29
30 static void printTemplate(const br::Template &t) 30 static void printTemplate(const br::Template &t)
31 { 31 {
32 - printf("%s gender: %s\n", qPrintable(t.file.fileName()), qPrintable(t.file.get<QString>("Subject"))); 32 + printf("%s gender: %s\n", qPrintable(t.file.fileName()), qPrintable(t.file.get<QString>("Gender")));
33 } 33 }
34 34
35 int main(int argc, char *argv[]) 35 int main(int argc, char *argv[])
openbr/core/bee.cpp
@@ -99,10 +99,10 @@ void BEE::writeSigset(const QString &amp;sigset, const br::FileList &amp;files, bool ign @@ -99,10 +99,10 @@ void BEE::writeSigset(const QString &amp;sigset, const br::FileList &amp;files, bool ign
99 QStringList metadata; 99 QStringList metadata;
100 if (!ignoreMetadata) 100 if (!ignoreMetadata)
101 foreach (const QString &key, file.localKeys()) { 101 foreach (const QString &key, file.localKeys()) {
102 - if ((key == "Index") || (key == "Subject")) continue; 102 + if ((key == "Index") || (key == "Label")) continue;
103 metadata.append(key+"=\""+QtUtils::toString(file.value(key))+"\""); 103 metadata.append(key+"=\""+QtUtils::toString(file.value(key))+"\"");
104 } 104 }
105 - lines.append("\t<biometric-signature name=\"" + file.get<QString>("Subject",file.fileName()) +"\">"); 105 + lines.append("\t<biometric-signature name=\"" + file.get<QString>("Label",file.fileName()) +"\">");
106 lines.append("\t\t<presentation file-name=\"" + file.name + "\" " + metadata.join(" ") + "/>"); 106 lines.append("\t\t<presentation file-name=\"" + file.name + "\" " + metadata.join(" ") + "/>");
107 lines.append("\t</biometric-signature>"); 107 lines.append("\t</biometric-signature>");
108 } 108 }
@@ -266,10 +266,11 @@ void BEE::makeMask(const QString &amp;targetInput, const QString &amp;queryInput, const @@ -266,10 +266,11 @@ void BEE::makeMask(const QString &amp;targetInput, const QString &amp;queryInput, const
266 266
267 cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries, int partition) 267 cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries, int partition)
268 { 268 {
269 - // Would like to use indexProperty for this, but didn't make a version of that for Filelist yet  
270 - // -cao  
271 - QList<QString> targetLabels = File::get<QString>(targets, "Subject", "-1");  
272 - QList<QString> queryLabels = File::get<QString>(queries, "Subject", "-1"); 269 + // Direct use of "Label" isn't general, also would prefer to use indexProperty, rather than
  270 + // doing string comparisons (but that isn't implemented yet for FileList) -cao
  271 + QList<QString> targetLabels = File::get<QString>(targets, "Label", "-1");
  272 + QList<QString> queryLabels = File::get<QString>(queries, "Label", "-1");
  273 +
273 QList<int> targetPartitions = targets.crossValidationPartitions(); 274 QList<int> targetPartitions = targets.crossValidationPartitions();
274 QList<int> queryPartitions = queries.crossValidationPartitions(); 275 QList<int> queryPartitions = queries.crossValidationPartitions();
275 276
openbr/core/cluster.cpp
@@ -279,8 +279,8 @@ void br::EvalClustering(const QString &amp;csv, const QString &amp;input) @@ -279,8 +279,8 @@ void br::EvalClustering(const QString &amp;csv, const QString &amp;input)
279 qDebug("Evaluating %s against %s", qPrintable(csv), qPrintable(input)); 279 qDebug("Evaluating %s against %s", qPrintable(csv), qPrintable(input));
280 280
281 // We assume clustering algorithms store assigned cluster labels as integers (since the clusters are 281 // We assume clustering algorithms store assigned cluster labels as integers (since the clusters are
282 - // not named).  
283 - QList<int> labels = File::get<int>(TemplateList::fromGallery(input), "Subject"); 282 + // not named). Direct use of ClusterID is not general -cao
  283 + QList<int> labels = File::get<int>(TemplateList::fromGallery(input), "ClusterID");
284 284
285 QHash<int, int> labelToIndex; 285 QHash<int, int> labelToIndex;
286 int nClusters = 0; 286 int nClusters = 0;
openbr/core/common.cpp
@@ -15,11 +15,15 @@ @@ -15,11 +15,15 @@
15 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 15 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
16 16
17 #include "common.h" 17 #include "common.h"
  18 +#include <QMutex>
18 19
19 using namespace std; 20 using namespace std;
20 21
21 /**** GLOBAL ****/ 22 /**** GLOBAL ****/
22 void Common::seedRNG() { 23 void Common::seedRNG() {
  24 + static QMutex seedControl;
  25 + QMutexLocker lock(&seedControl);
  26 +
23 static bool seeded = false; 27 static bool seeded = false;
24 if (!seeded) { 28 if (!seeded) {
25 srand(0); // We seed with 0 instead of time(NULL) to have reproducible randomness 29 srand(0); // We seed with 0 instead of time(NULL) to have reproducible randomness
@@ -29,8 +33,6 @@ void Common::seedRNG() { @@ -29,8 +33,6 @@ void Common::seedRNG() {
29 33
30 QList<int> Common::RandSample(int n, int max, int min, bool unique) 34 QList<int> Common::RandSample(int n, int max, int min, bool unique)
31 { 35 {
32 - seedRNG();  
33 -  
34 QList<int> samples; samples.reserve(n); 36 QList<int> samples; samples.reserve(n);
35 int range = max-min; 37 int range = max-min;
36 if (range <= 0) qFatal("Non-positive range."); 38 if (range <= 0) qFatal("Non-positive range.");
@@ -50,8 +52,6 @@ QList&lt;int&gt; Common::RandSample(int n, int max, int min, bool unique) @@ -50,8 +52,6 @@ QList&lt;int&gt; Common::RandSample(int n, int max, int min, bool unique)
50 52
51 QList<int> Common::RandSample(int n, const QSet<int> &values, bool unique) 53 QList<int> Common::RandSample(int n, const QSet<int> &values, bool unique)
52 { 54 {
53 - seedRNG();  
54 -  
55 QList<int> valueList = values.toList(); 55 QList<int> valueList = values.toList();
56 if (unique && (values.size() <= n)) return valueList; 56 if (unique && (values.size() <= n)) return valueList;
57 57
openbr/core/eval.cpp
@@ -255,9 +255,20 @@ struct Counter @@ -255,9 +255,20 @@ struct Counter
255 } 255 }
256 }; 256 };
257 257
258 -void EvalClassification(const QString &predictedInput, const QString &truthInput) 258 +void EvalClassification(const QString &predictedInput, const QString &truthInput, QString predictedProperty, QString truthProperty)
259 { 259 {
260 qDebug("Evaluating classification of %s against %s", qPrintable(predictedInput), qPrintable(truthInput)); 260 qDebug("Evaluating classification of %s against %s", qPrintable(predictedInput), qPrintable(truthInput));
  261 +
  262 + if (predictedProperty.isEmpty())
  263 + predictedProperty = "Label";
  264 + // If predictedProperty is specified, but truthProperty isn't, copy over the value from
  265 + // predicted property
  266 + else if (truthProperty.isEmpty())
  267 + truthProperty = predictedProperty;
  268 +
  269 + if (truthProperty.isEmpty())
  270 + truthProperty = "Label";
  271 +
261 TemplateList predicted(TemplateList::fromGallery(predictedInput)); 272 TemplateList predicted(TemplateList::fromGallery(predictedInput));
262 TemplateList truth(TemplateList::fromGallery(truthInput)); 273 TemplateList truth(TemplateList::fromGallery(truthInput));
263 if (predicted.size() != truth.size()) qFatal("Input size mismatch."); 274 if (predicted.size() != truth.size()) qFatal("Input size mismatch.");
@@ -267,9 +278,8 @@ void EvalClassification(const QString &amp;predictedInput, const QString &amp;truthInput @@ -267,9 +278,8 @@ void EvalClassification(const QString &amp;predictedInput, const QString &amp;truthInput
267 if (predicted[i].file.name != truth[i].file.name) 278 if (predicted[i].file.name != truth[i].file.name)
268 qFatal("Input order mismatch."); 279 qFatal("Input order mismatch.");
269 280
270 - // Typically these lists will be of length one, but this generalization allows measuring multi-class labeling accuracy.  
271 - QString predictedSubject = predicted[i].file.get<QString>("Subject");  
272 - QString trueSubject = truth[i].file.get<QString>("Subject"); 281 + QString predictedSubject = predicted[i].file.get<QString>(predictedProperty);
  282 + QString trueSubject = truth[i].file.get<QString>(truthProperty);
273 283
274 QStringList predictedSubjects(predictedSubject); 284 QStringList predictedSubjects(predictedSubject);
275 QStringList trueSubjects(trueSubject); 285 QStringList trueSubjects(trueSubject);
@@ -466,21 +476,37 @@ float EvalDetection(const QString &amp;predictedInput, const QString &amp;truthInput, co @@ -466,21 +476,37 @@ float EvalDetection(const QString &amp;predictedInput, const QString &amp;truthInput, co
466 return averageOverlap; 476 return averageOverlap;
467 } 477 }
468 478
469 -void EvalRegression(const QString &predictedInput, const QString &truthInput) 479 +void EvalRegression(const QString &predictedInput, const QString &truthInput, QString predictedProperty, QString truthProperty)
470 { 480 {
471 qDebug("Evaluating regression of %s against %s", qPrintable(predictedInput), qPrintable(truthInput)); 481 qDebug("Evaluating regression of %s against %s", qPrintable(predictedInput), qPrintable(truthInput));
  482 +
  483 + if (predictedProperty.isEmpty())
  484 + predictedProperty = "Regressor";
  485 + // If predictedProperty is specified, but truthProperty isn't, copy the value over
  486 + // rather than using the default for truthProperty
  487 + else if (truthProperty.isEmpty())
  488 + truthProperty = predictedProperty;
  489 +
  490 + if (truthProperty.isEmpty())
  491 + predictedProperty = "Regressand";
  492 +
472 const TemplateList predicted(TemplateList::fromGallery(predictedInput)); 493 const TemplateList predicted(TemplateList::fromGallery(predictedInput));
473 const TemplateList truth(TemplateList::fromGallery(truthInput)); 494 const TemplateList truth(TemplateList::fromGallery(truthInput));
474 if (predicted.size() != truth.size()) qFatal("Input size mismatch."); 495 if (predicted.size() != truth.size()) qFatal("Input size mismatch.");
475 496
476 float rmsError = 0; 497 float rmsError = 0;
  498 + float maeError = 0;
477 QStringList truthValues, predictedValues; 499 QStringList truthValues, predictedValues;
478 for (int i=0; i<predicted.size(); i++) { 500 for (int i=0; i<predicted.size(); i++) {
479 if (predicted[i].file.name != truth[i].file.name) 501 if (predicted[i].file.name != truth[i].file.name)
480 qFatal("Input order mismatch."); 502 qFatal("Input order mismatch.");
481 - rmsError += pow(predicted[i].file.get<float>("Subject")-truth[i].file.get<float>("Subject"), 2.f);  
482 - truthValues.append(QString::number(truth[i].file.get<float>("Subject")));  
483 - predictedValues.append(QString::number(predicted[i].file.get<float>("Subject"))); 503 +
  504 + float difference = predicted[i].file.get<float>(predictedProperty) - truth[i].file.get<float>(truthProperty);
  505 +
  506 + rmsError += pow(difference, 2.f);
  507 + maeError += fabsf(difference);
  508 + truthValues.append(QString::number(truth[i].file.get<float>(truthProperty)));
  509 + predictedValues.append(QString::number(predicted[i].file.get<float>(predictedProperty)));
484 } 510 }
485 511
486 QStringList rSource; 512 QStringList rSource;
@@ -500,6 +526,7 @@ void EvalRegression(const QString &amp;predictedInput, const QString &amp;truthInput) @@ -500,6 +526,7 @@ void EvalRegression(const QString &amp;predictedInput, const QString &amp;truthInput)
500 if (success) QtUtils::showFile("EvalRegression.pdf"); 526 if (success) QtUtils::showFile("EvalRegression.pdf");
501 527
502 qDebug("RMS Error = %f", sqrt(rmsError/predicted.size())); 528 qDebug("RMS Error = %f", sqrt(rmsError/predicted.size()));
  529 + qDebug("MAE = %f", maeError/predicted.size());
503 } 530 }
504 531
505 } // namespace br 532 } // namespace br
openbr/core/eval.h
@@ -26,9 +26,9 @@ namespace br @@ -26,9 +26,9 @@ namespace br
26 float Evaluate(const QString &simmat, const QString &mask = "", const QString &csv = ""); // Returns TAR @ FAR = 0.001 26 float Evaluate(const QString &simmat, const QString &mask = "", const QString &csv = ""); // Returns TAR @ FAR = 0.001
27 float Evaluate(const cv::Mat &scores, const FileList &target, const FileList &query, const QString &csv = "", int parition = 0); 27 float Evaluate(const cv::Mat &scores, const FileList &target, const FileList &query, const QString &csv = "", int parition = 0);
28 float Evaluate(const cv::Mat &scores, const cv::Mat &masks, const QString &csv = ""); 28 float Evaluate(const cv::Mat &scores, const cv::Mat &masks, const QString &csv = "");
29 - void EvalClassification(const QString &predictedInput, const QString &truthInput); 29 + void EvalClassification(const QString &predictedInput, const QString &truthInput, QString predictedProperty="", QString truthProperty="");
30 float EvalDetection(const QString &predictedInput, const QString &truthInput, const QString &csv = ""); // Return average overlap 30 float EvalDetection(const QString &predictedInput, const QString &truthInput, const QString &csv = ""); // Return average overlap
31 - void EvalRegression(const QString &predictedInput, const QString &truthInput); 31 + void EvalRegression(const QString &predictedInput, const QString &truthInput, QString predictedProperty="", QString truthProperty="");
32 } 32 }
33 33
34 #endif // __EVAL_H 34 #endif // __EVAL_H
openbr/frvt2012.cpp
@@ -132,7 +132,7 @@ int32_t SdkEstimator::estimate_age(const ONEFACE &amp;input_face, int32_t &amp;age) @@ -132,7 +132,7 @@ int32_t SdkEstimator::estimate_age(const ONEFACE &amp;input_face, int32_t &amp;age)
132 TemplateList templates; 132 TemplateList templates;
133 templates.append(templateFromONEFACE(input_face)); 133 templates.append(templateFromONEFACE(input_face));
134 templates >> *frvt2012_age_transform.data(); 134 templates >> *frvt2012_age_transform.data();
135 - age = templates.first().file.get<float>("Subject"); 135 + age = templates.first().file.get<float>("Age");
136 return templates.first().file.failed() ? 4 : 0; 136 return templates.first().file.failed() ? 4 : 0;
137 } 137 }
138 138
@@ -141,6 +141,6 @@ int32_t SdkEstimator::estimate_gender(const ONEFACE &amp;input_face, int8_t &amp;gender, @@ -141,6 +141,6 @@ int32_t SdkEstimator::estimate_gender(const ONEFACE &amp;input_face, int8_t &amp;gender,
141 TemplateList templates; 141 TemplateList templates;
142 templates.append(templateFromONEFACE(input_face)); 142 templates.append(templateFromONEFACE(input_face));
143 templates >> *frvt2012_gender_transform.data(); 143 templates >> *frvt2012_gender_transform.data();
144 - mf = gender = templates.first().file.get<QString>("Subject") == "Male" ? 0 : 1; 144 + mf = gender = templates.first().file.get<QString>("Gender") == "Male" ? 0 : 1;
145 return templates.first().file.failed() ? 4 : 0; 145 return templates.first().file.failed() ? 4 : 0;
146 } 146 }
openbr/gui/classifier.cpp
@@ -39,19 +39,23 @@ void Classifier::_classify(File file) @@ -39,19 +39,23 @@ void Classifier::_classify(File file)
39 { 39 {
40 QString key, value; 40 QString key, value;
41 foreach (const File &f, Enroll(file.flat(), File("[algorithm=" + algorithm + "]"))) { 41 foreach (const File &f, Enroll(file.flat(), File("[algorithm=" + algorithm + "]"))) {
42 - if (!f.contains("Label"))  
43 - continue;  
44 42
45 if (algorithm == "GenderClassification") { 43 if (algorithm == "GenderClassification") {
46 key = "Gender"; 44 key = "Gender";
47 - value = (f.get<QString>("Subject"));  
48 } else if (algorithm == "AgeRegression") { 45 } else if (algorithm == "AgeRegression") {
49 key = "Age"; 46 key = "Age";
50 - value = QString::number(int(f.get<float>("Subject")+0.5)) + " Years";  
51 } else { 47 } else {
52 key = algorithm; 48 key = algorithm;
53 - value = f.get<QString>("Subject");  
54 } 49 }
  50 +
  51 + if (!f.contains(key))
  52 + continue;
  53 +
  54 + if (algorithm == "AgeRegression")
  55 + value = QString::number(int(f.get<float>(key)+0.5)) + " Years";
  56 + else
  57 + value = f.get<QString>(key);
  58 +
55 break; 59 break;
56 } 60 }
57 61
openbr/openbr.cpp
@@ -72,9 +72,9 @@ float br_eval(const char *simmat, const char *mask, const char *csv) @@ -72,9 +72,9 @@ float br_eval(const char *simmat, const char *mask, const char *csv)
72 return Evaluate(simmat, mask, csv); 72 return Evaluate(simmat, mask, csv);
73 } 73 }
74 74
75 -void br_eval_classification(const char *predicted_gallery, const char *truth_gallery) 75 +void br_eval_classification(const char *predicted_gallery, const char *truth_gallery, const char *predicted_property, const char * truth_property)
76 { 76 {
77 - EvalClassification(predicted_gallery, truth_gallery); 77 + EvalClassification(predicted_gallery, truth_gallery, predicted_property, truth_property);
78 } 78 }
79 79
80 void br_eval_clustering(const char *csv, const char *gallery) 80 void br_eval_clustering(const char *csv, const char *gallery)
@@ -87,9 +87,9 @@ float br_eval_detection(const char *predicted_gallery, const char *truth_gallery @@ -87,9 +87,9 @@ float br_eval_detection(const char *predicted_gallery, const char *truth_gallery
87 return EvalDetection(predicted_gallery, truth_gallery, csv); 87 return EvalDetection(predicted_gallery, truth_gallery, csv);
88 } 88 }
89 89
90 -void br_eval_regression(const char *predicted_gallery, const char *truth_gallery) 90 +void br_eval_regression(const char *predicted_gallery, const char *truth_gallery, const char * predicted_property, const char * truth_property)
91 { 91 {
92 - EvalRegression(predicted_gallery, truth_gallery); 92 + EvalRegression(predicted_gallery, truth_gallery, predicted_property, truth_property);
93 } 93 }
94 94
95 void br_finalize() 95 void br_finalize()
openbr/openbr.h
@@ -149,7 +149,7 @@ BR_EXPORT float br_eval(const char *simmat, const char *mask, const char *csv = @@ -149,7 +149,7 @@ BR_EXPORT float br_eval(const char *simmat, const char *mask, const char *csv =
149 * \param predicted_gallery The predicted br::Gallery. 149 * \param predicted_gallery The predicted br::Gallery.
150 * \param truth_gallery The ground truth br::Gallery. 150 * \param truth_gallery The ground truth br::Gallery.
151 */ 151 */
152 -BR_EXPORT void br_eval_classification(const char *predicted_gallery, const char *truth_gallery); 152 +BR_EXPORT void br_eval_classification(const char *predicted_gallery, const char *truth_gallery, const char * predicted_property="", const char * truth_property="");
153 153
154 /*! 154 /*!
155 * \brief Evaluates and prints clustering accuracy to the terminal. 155 * \brief Evaluates and prints clustering accuracy to the terminal.
@@ -173,7 +173,7 @@ BR_EXPORT float br_eval_detection(const char *predicted_gallery, const char *tru @@ -173,7 +173,7 @@ BR_EXPORT float br_eval_detection(const char *predicted_gallery, const char *tru
173 * \param predicted_gallery The predicted br::Gallery. 173 * \param predicted_gallery The predicted br::Gallery.
174 * \param truth_gallery The ground truth br::Gallery. 174 * \param truth_gallery The ground truth br::Gallery.
175 */ 175 */
176 -BR_EXPORT void br_eval_regression(const char *predicted_gallery, const char *truth_gallery); 176 +BR_EXPORT void br_eval_regression(const char *predicted_gallery, const char *truth_gallery, const char * predicted_property="", const char * truth_property="");
177 177
178 /*! 178 /*!
179 * \brief Wraps br::Context::finalize() 179 * \brief Wraps br::Context::finalize()
openbr/openbr_plugin.cpp
@@ -412,7 +412,8 @@ TemplateList TemplateList::fromGallery(const br::File &amp;gallery) @@ -412,7 +412,8 @@ TemplateList TemplateList::fromGallery(const br::File &amp;gallery)
412 // of target images to every partition 412 // of target images to every partition
413 newTemplates[i].file.set("Partition", -1); 413 newTemplates[i].file.set("Partition", -1);
414 } else { 414 } else {
415 - const QByteArray md5 = QCryptographicHash::hash(newTemplates[i].file.get<QString>("Subject").toLatin1(), QCryptographicHash::Md5); 415 + // Direct use of "Label" is not general -cao
  416 + const QByteArray md5 = QCryptographicHash::hash(newTemplates[i].file.get<QString>("Label").toLatin1(), QCryptographicHash::Md5);
416 // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow 417 // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow
417 newTemplates[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); 418 newTemplates[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate);
418 } 419 }
@@ -890,6 +891,8 @@ void br::Context::initialize(int &amp;argc, char *argv[], QString sdkPath, bool use_ @@ -890,6 +891,8 @@ void br::Context::initialize(int &amp;argc, char *argv[], QString sdkPath, bool use_
890 891
891 qInstallMessageHandler(messageHandler); 892 qInstallMessageHandler(messageHandler);
892 893
  894 + Common::seedRNG();
  895 +
893 // Search for SDK 896 // Search for SDK
894 if (sdkPath.isEmpty()) { 897 if (sdkPath.isEmpty()) {
895 QStringList checkPaths; checkPaths << QDir::currentPath() << QCoreApplication::applicationDirPath(); 898 QStringList checkPaths; checkPaths << QDir::currentPath() << QCoreApplication::applicationDirPath();
@@ -1082,9 +1085,6 @@ Transform::Transform(bool _independent, bool _trainable) @@ -1082,9 +1085,6 @@ Transform::Transform(bool _independent, bool _trainable)
1082 { 1085 {
1083 independent = _independent; 1086 independent = _independent;
1084 trainable = _trainable; 1087 trainable = _trainable;
1085 - classes = std::numeric_limits<int>::max();  
1086 - instances = std::numeric_limits<int>::max();  
1087 - fraction = 1;  
1088 } 1088 }
1089 1089
1090 Transform *Transform::make(QString str, QObject *parent) 1090 Transform *Transform::make(QString str, QObject *parent)
@@ -1140,9 +1140,6 @@ Transform *Transform::make(QString str, QObject *parent) @@ -1140,9 +1140,6 @@ Transform *Transform::make(QString str, QObject *parent)
1140 Transform *Transform::clone() const 1140 Transform *Transform::clone() const
1141 { 1141 {
1142 Transform *clone = Factory<Transform>::make(file.flat()); 1142 Transform *clone = Factory<Transform>::make(file.flat());
1143 - clone->classes = classes;  
1144 - clone->instances = instances;  
1145 - clone->fraction = fraction;  
1146 return clone; 1143 return clone;
1147 } 1144 }
1148 1145
openbr/openbr_plugin.h
@@ -130,13 +130,6 @@ void reset_##NAME() { NAME = DEFAULT; } @@ -130,13 +130,6 @@ void reset_##NAME() { NAME = DEFAULT; }
130 * -# If the value is convertable to a floating point number then it is represented with \c float. 130 * -# If the value is convertable to a floating point number then it is represented with \c float.
131 * -# Otherwise, it is represented with \c QString. 131 * -# Otherwise, it is represented with \c QString.
132 * 132 *
133 - * The metadata keys \c Subject and \c Label have special significance in the system.  
134 - * \c Subject is a string specifying a unique identifier used to determine ground truth match/non-match.  
135 - * \c Label is a floating point value used for supervised learning.  
136 - * When the system needs labels for training, but only subjects are provided in the file metadata, the rule for generating labels is as follows.  
137 - * If the subject value can be converted to a float then do so and consider that the label.  
138 - * Otherwise, generate a unique integer ID for the string starting from zero and incrementing by one everytime another ID is needed.  
139 - *  
140 * Metadata keys fall into one of two categories: 133 * Metadata keys fall into one of two categories:
141 * - \c camelCaseKeys are inputs that specify how to process the file. 134 * - \c camelCaseKeys are inputs that specify how to process the file.
142 * - \c Capitalized_Underscored_Keys are outputs computed from processing the file. 135 * - \c Capitalized_Underscored_Keys are outputs computed from processing the file.
@@ -147,8 +140,6 @@ void reset_##NAME() { NAME = DEFAULT; } @@ -147,8 +140,6 @@ void reset_##NAME() { NAME = DEFAULT; }
147 * --- | ---- | ----------- 140 * --- | ---- | -----------
148 * separator | QString | Seperate #name into multiple files 141 * separator | QString | Seperate #name into multiple files
149 * Index | int | Index of a template in a template list 142 * Index | int | Index of a template in a template list
150 - * Subject | QString | Class name  
151 - * Label | float | Class value  
152 * Confidence | float | Classification/Regression quality 143 * Confidence | float | Classification/Regression quality
153 * FTE | bool | Failure to enroll 144 * FTE | bool | Failure to enroll
154 * FTO | bool | Failure to open 145 * FTO | bool | Failure to open
@@ -157,13 +148,15 @@ void reset_##NAME() { NAME = DEFAULT; } @@ -157,13 +148,15 @@ void reset_##NAME() { NAME = DEFAULT; }
157 * *_Width | float | Size 148 * *_Width | float | Size
158 * *_Height | float | Size 149 * *_Height | float | Size
159 * *_Radius | float | Size 150 * *_Radius | float | Size
  151 + * Label | QString | Class label
160 * Theta | float | Pose 152 * Theta | float | Pose
161 * Roll | float | Pose 153 * Roll | float | Pose
162 * Pitch | float | Pose 154 * Pitch | float | Pose
163 * Yaw | float | Pose 155 * Yaw | float | Pose
164 * Points | QList<QPointF> | List of unnamed points 156 * Points | QList<QPointF> | List of unnamed points
165 * Rects | QList<Rect> | List of unnamed rects 157 * Rects | QList<Rect> | List of unnamed rects
166 - * Age | QString | Age used for demographic filtering 158 + * Age | float | Age used for demographic filtering
  159 + * Gender | QString | Subject gender
167 * _* | * | Reserved for internal use 160 * _* | * | Reserved for internal use
168 */ 161 */
169 struct BR_EXPORT File 162 struct BR_EXPORT File
@@ -172,7 +165,7 @@ struct BR_EXPORT File @@ -172,7 +165,7 @@ struct BR_EXPORT File
172 165
173 File() {} 166 File() {}
174 File(const QString &file) { init(file); } /*!< \brief Construct a file from a string. */ 167 File(const QString &file) { init(file); } /*!< \brief Construct a file from a string. */
175 - File(const QString &file, const QVariant &subject) { init(file); set("Subject", subject); } /*!< \brief Construct a file from a string and assign a label. */ 168 + File(const QString &file, const QVariant &label) { init(file); set("Label", label); } /*!< \brief Construct a file from a string and assign a label. */
176 File(const char *file) { init(file); } /*!< \brief Construct a file from a c-style string. */ 169 File(const char *file) { init(file); } /*!< \brief Construct a file from a c-style string. */
177 inline operator QString() const { return name; } /*!< \brief Returns #name. */ 170 inline operator QString() const { return name; } /*!< \brief Returns #name. */
178 QString flat() const; /*!< \brief A stringified version of the file with metadata. */ 171 QString flat() const; /*!< \brief A stringified version of the file with metadata. */
@@ -1058,12 +1051,6 @@ class BR_EXPORT Transform : public Object @@ -1058,12 +1051,6 @@ class BR_EXPORT Transform : public Object
1058 Q_OBJECT 1051 Q_OBJECT
1059 1052
1060 public: 1053 public:
1061 - Q_PROPERTY(int classes READ get_classes WRITE set_classes RESET reset_classes STORED false)  
1062 - Q_PROPERTY(int instances READ get_instances WRITE set_instances RESET reset_instances STORED false)  
1063 - Q_PROPERTY(float fraction READ get_fraction WRITE set_fraction RESET reset_fraction STORED false)  
1064 - BR_PROPERTY(int, classes, std::numeric_limits<int>::max())  
1065 - BR_PROPERTY(int, instances, std::numeric_limits<int>::max())  
1066 - BR_PROPERTY(float, fraction, 1)  
1067 bool independent, trainable; 1054 bool independent, trainable;
1068 1055
1069 virtual ~Transform() {} 1056 virtual ~Transform() {}
openbr/plugins/algorithms.cpp
@@ -48,7 +48,7 @@ class AlgorithmsInitializer : public Initializer @@ -48,7 +48,7 @@ class AlgorithmsInitializer : public Initializer
48 // Video 48 // Video
49 Globals->abbreviations.insert("DisplayVideo", "Stream([FPSLimit(30)+Show(false,[FrameNumber])+Discard])"); 49 Globals->abbreviations.insert("DisplayVideo", "Stream([FPSLimit(30)+Show(false,[FrameNumber])+Discard])");
50 Globals->abbreviations.insert("PerFrameDetection", "Stream([SaveMat(original)+Cvt(Gray)+Cascade(FrontalFace)+ASEFEyes+RestoreMat(original)+Draw(inPlace=true),Show(false,[FrameNumber])+Discard])"); 50 Globals->abbreviations.insert("PerFrameDetection", "Stream([SaveMat(original)+Cvt(Gray)+Cascade(FrontalFace)+ASEFEyes+RestoreMat(original)+Draw(inPlace=true),Show(false,[FrameNumber])+Discard])");
51 - Globals->abbreviations.insert("AgeGenderDemo", "Stream([SaveMat(original)+Cvt(Gray)+Cascade(FrontalFace)+Expand+<FaceClassificationRegistration>+<FaceClassificationExtraction>+(<AgeRegressor>+Rename(Subject,Age)+Discard)/(<GenderClassifier>+Rename(Subject,Gender)+Discard)+RestoreMat(original)+Draw(inPlace=true)+DrawPropertiesPoint([Age,Gender],Affine_0,inPlace=true)+SaveMat(original)+Discard+Contract,RestoreMat(original)+FPSCalc+Show(false,[AvgFPS,Age,Gender])+Discard])"); 51 + Globals->abbreviations.insert("AgeGenderDemo", "Stream([SaveMat(original)+Cvt(Gray)+Cascade(FrontalFace)+Expand+<FaceClassificationRegistration>+<FaceClassificationExtraction>+<AgeRegressor>/<GenderClassifier>+Discard+RestoreMat(original)+Draw(inPlace=true)+DrawPropertiesPoint([Age,Gender],Affine_0,inPlace=true)+SaveMat(original)+Discard+Contract,RestoreMat(original)+FPSCalc+Show(false,[AvgFPS,Age,Gender])+Discard])");
52 Globals->abbreviations.insert("BoVW", "Flatten+CatRows+KMeans(500)+Hist(500)"); 52 Globals->abbreviations.insert("BoVW", "Flatten+CatRows+KMeans(500)+Hist(500)");
53 Globals->abbreviations.insert("HOF", "Stream([KeyPointDetector(SIFT),AggregateFrames(2)+OpticalFlow,ROI,HoGDescriptor])+BoVW"); 53 Globals->abbreviations.insert("HOF", "Stream([KeyPointDetector(SIFT),AggregateFrames(2)+OpticalFlow,ROI,HoGDescriptor])+BoVW");
54 Globals->abbreviations.insert("HoG", "Stream([KeyPointDetector(SIFT),ROI,HoGDescriptor])+BoVW"); 54 Globals->abbreviations.insert("HoG", "Stream([KeyPointDetector(SIFT),ROI,HoGDescriptor])+BoVW");
@@ -78,14 +78,14 @@ class AlgorithmsInitializer : public Initializer @@ -78,14 +78,14 @@ class AlgorithmsInitializer : public Initializer
78 Globals->abbreviations.insert("FaceDetection", "(Open+Cvt(Gray)+Cascade(FrontalFace))"); 78 Globals->abbreviations.insert("FaceDetection", "(Open+Cvt(Gray)+Cascade(FrontalFace))");
79 Globals->abbreviations.insert("DenseLBP", "(Blur(1.1)+Gamma(0.2)+DoG(1,2)+ContrastEq(0.1,10)+LBP(1,2)+RectRegions(8,8,6,6)+Hist(59))"); 79 Globals->abbreviations.insert("DenseLBP", "(Blur(1.1)+Gamma(0.2)+DoG(1,2)+ContrastEq(0.1,10)+LBP(1,2)+RectRegions(8,8,6,6)+Hist(59))");
80 Globals->abbreviations.insert("DenseSIFT", "(Grid(10,10)+SIFTDescriptor(12)+ByRow)"); 80 Globals->abbreviations.insert("DenseSIFT", "(Grid(10,10)+SIFTDescriptor(12)+ByRow)");
81 - Globals->abbreviations.insert("FaceRecognitionRegistration", "(ASEFEyes+Affine(88,88,0.25,0.35)+FTE(DFFS,instances=1))");  
82 - Globals->abbreviations.insert("FaceRecognitionExtraction", "(Mask+DenseSIFT/DenseLBP+PCA(0.95,instances=1)+Normalize(L2)+Cat)");  
83 - Globals->abbreviations.insert("FaceRecognitionEmbedding", "(Dup(12)+RndSubspace(0.05,1)+LDA(0.98,instances=-2)+Cat+PCA(768,instances=1))"); 81 + Globals->abbreviations.insert("FaceRecognitionRegistration", "(ASEFEyes+Affine(88,88,0.25,0.35)+DownsampleTraining(FTE(DFFS),instances=1))");
  82 + Globals->abbreviations.insert("FaceRecognitionExtraction", "(Mask+DenseSIFT/DenseLBP+DownsampleTraining(PCA(0.95),instances=1)+Normalize(L2)+Cat)");
  83 + Globals->abbreviations.insert("FaceRecognitionEmbedding", "(Dup(12)+RndSubspace(0.05,1)+DownsampleTraining(LDA(0.98),instances=-2)+Cat+DownsampleTraining(PCA(768),instances=1))");
84 Globals->abbreviations.insert("FaceRecognitionQuantization", "(Normalize(L1)+Quantize)"); 84 Globals->abbreviations.insert("FaceRecognitionQuantization", "(Normalize(L1)+Quantize)");
85 Globals->abbreviations.insert("FaceClassificationRegistration", "(ASEFEyes+Affine(56,72,0.33,0.45)+FTE(DFFS))"); 85 Globals->abbreviations.insert("FaceClassificationRegistration", "(ASEFEyes+Affine(56,72,0.33,0.45)+FTE(DFFS))");
86 - Globals->abbreviations.insert("FaceClassificationExtraction", "((Grid(7,7)+SIFTDescriptor(8)+ByRow)/DenseLBP+PCA(0.95,instances=-1)+Cat)");  
87 - Globals->abbreviations.insert("AgeRegressor", "Center(Range,instances=-1)+SVM(RBF,EPS_SVR,instances=100)");  
88 - Globals->abbreviations.insert("GenderClassifier", "Center(Range,instances=-1)+SVM(RBF,C_SVC,instances=4000)"); 86 + Globals->abbreviations.insert("FaceClassificationExtraction", "((Grid(7,7)+SIFTDescriptor(8)+ByRow)/DenseLBP+DownsampleTraining(PCA(0.95),instances=-1, inputVariable=Gender)+Cat)");
  87 + Globals->abbreviations.insert("AgeRegressor", "DownsampleTraining(Center(Range),instances=-1, inputVariable=Age)+DownsampleTraining(SVM(RBF,EPS_SVR,inputVariable=Age),instances=100, inputVariable=Age)");
  88 + Globals->abbreviations.insert("GenderClassifier", "DownsampleTraining(Center(Range),instances=-1, inputVariable=Gender)+DownsampleTraining(SVM(RBF,C_SVC,inputVariable=Gender),instances=4000, inputVariable=Gender)");
89 Globals->abbreviations.insert("UCharL1", "Unit(ByteL1)"); 89 Globals->abbreviations.insert("UCharL1", "Unit(ByteL1)");
90 } 90 }
91 }; 91 };
openbr/plugins/cluster.cpp
@@ -89,10 +89,14 @@ class KNNTransform : public Transform @@ -89,10 +89,14 @@ class KNNTransform : public Transform
89 Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance STORED false) 89 Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance STORED false)
90 Q_PROPERTY(bool weighted READ get_weighted WRITE set_weighted RESET reset_weighted STORED false) 90 Q_PROPERTY(bool weighted READ get_weighted WRITE set_weighted RESET reset_weighted STORED false)
91 Q_PROPERTY(int numSubjects READ get_numSubjects WRITE set_numSubjects RESET reset_numSubjects STORED false) 91 Q_PROPERTY(int numSubjects READ get_numSubjects WRITE set_numSubjects RESET reset_numSubjects STORED false)
  92 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
  93 + Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)
92 BR_PROPERTY(int, k, 1) 94 BR_PROPERTY(int, k, 1)
93 BR_PROPERTY(br::Distance*, distance, NULL) 95 BR_PROPERTY(br::Distance*, distance, NULL)
94 BR_PROPERTY(bool, weighted, false) 96 BR_PROPERTY(bool, weighted, false)
95 BR_PROPERTY(int, numSubjects, 1) 97 BR_PROPERTY(int, numSubjects, 1)
  98 + BR_PROPERTY(QString, inputVariable, "Label")
  99 + BR_PROPERTY(QString, outputVariable, "KNN")
96 100
97 TemplateList gallery; 101 TemplateList gallery;
98 102
@@ -111,17 +115,17 @@ class KNNTransform : public Transform @@ -111,17 +115,17 @@ class KNNTransform : public Transform
111 QHash<QString, float> votes; 115 QHash<QString, float> votes;
112 const int max = (k < 1) ? sortedScores.size() : std::min(k, sortedScores.size()); 116 const int max = (k < 1) ? sortedScores.size() : std::min(k, sortedScores.size());
113 for (int j=0; j<max; j++) 117 for (int j=0; j<max; j++)
114 - votes[gallery[sortedScores[j].second].file.get<QString>("Subject")] += (weighted ? sortedScores[j].first : 1); 118 + votes[gallery[sortedScores[j].second].file.get<QString>(inputVariable)] += (weighted ? sortedScores[j].first : 1);
115 subjects.append(votes.keys()[votes.values().indexOf(Common::Max(votes.values()))]); 119 subjects.append(votes.keys()[votes.values().indexOf(Common::Max(votes.values()))]);
116 120
117 // Remove subject from consideration 121 // Remove subject from consideration
118 if (subjects.size() < numSubjects) 122 if (subjects.size() < numSubjects)
119 for (int j=sortedScores.size()-1; j>=0; j--) 123 for (int j=sortedScores.size()-1; j>=0; j--)
120 - if (gallery[sortedScores[j].second].file.get<QString>("Subject") == subjects.last()) 124 + if (gallery[sortedScores[j].second].file.get<QString>(inputVariable) == subjects.last())
121 sortedScores.removeAt(j); 125 sortedScores.removeAt(j);
122 } 126 }
123 127
124 - dst.file.set("KNN", subjects.size() > 1 ? "[" + subjects.join(",") + "]" : subjects.first()); 128 + dst.file.set(outputVariable, subjects.size() > 1 ? "[" + subjects.join(",") + "]" : subjects.first());
125 } 129 }
126 130
127 void store(QDataStream &stream) const 131 void store(QDataStream &stream) const
openbr/plugins/eigen3.cpp
@@ -303,10 +303,12 @@ class LDATransform : public Transform @@ -303,10 +303,12 @@ class LDATransform : public Transform
303 Q_PROPERTY(bool pcaWhiten READ get_pcaWhiten WRITE set_pcaWhiten RESET reset_pcaWhiten STORED false) 303 Q_PROPERTY(bool pcaWhiten READ get_pcaWhiten WRITE set_pcaWhiten RESET reset_pcaWhiten STORED false)
304 Q_PROPERTY(int directLDA READ get_directLDA WRITE set_directLDA RESET reset_directLDA STORED false) 304 Q_PROPERTY(int directLDA READ get_directLDA WRITE set_directLDA RESET reset_directLDA STORED false)
305 Q_PROPERTY(float directDrop READ get_directDrop WRITE set_directDrop RESET reset_directDrop STORED false) 305 Q_PROPERTY(float directDrop READ get_directDrop WRITE set_directDrop RESET reset_directDrop STORED false)
  306 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
306 BR_PROPERTY(float, pcaKeep, 0.98) 307 BR_PROPERTY(float, pcaKeep, 0.98)
307 BR_PROPERTY(bool, pcaWhiten, false) 308 BR_PROPERTY(bool, pcaWhiten, false)
308 BR_PROPERTY(int, directLDA, 0) 309 BR_PROPERTY(int, directLDA, 0)
309 BR_PROPERTY(float, directDrop, 0.1) 310 BR_PROPERTY(float, directDrop, 0.1)
  311 + BR_PROPERTY(QString, inputVariable, "Label")
310 312
311 int dimsOut; 313 int dimsOut;
312 Eigen::VectorXf mean; 314 Eigen::VectorXf mean;
@@ -315,7 +317,7 @@ class LDATransform : public Transform @@ -315,7 +317,7 @@ class LDATransform : public Transform
315 void train(const TemplateList &_trainingSet) 317 void train(const TemplateList &_trainingSet)
316 { 318 {
317 // creates "Label" 319 // creates "Label"
318 - TemplateList trainingSet = TemplateList::relabel(_trainingSet, "Subject"); 320 + TemplateList trainingSet = TemplateList::relabel(_trainingSet, inputVariable);
319 321
320 int instances = trainingSet.size(); 322 int instances = trainingSet.size();
321 323
openbr/plugins/gallery.cpp
@@ -71,7 +71,7 @@ class arffGallery : public Gallery @@ -71,7 +71,7 @@ class arffGallery : public Gallery
71 } 71 }
72 72
73 arffFile.write(qPrintable(OpenCVUtils::matrixToStringList(t).join(','))); 73 arffFile.write(qPrintable(OpenCVUtils::matrixToStringList(t).join(',')));
74 - arffFile.write(qPrintable(",'" + t.file.get<QString>("Subject") + "'\n")); 74 + arffFile.write(qPrintable(",'" + t.file.get<QString>("Label") + "'\n"));
75 } 75 }
76 }; 76 };
77 77
@@ -643,11 +643,16 @@ class dbGallery : public Gallery @@ -643,11 +643,16 @@ class dbGallery : public Gallery
643 query = query.mid(1, query.size()-2); 643 query = query.mid(1, query.size()-2);
644 if (!q.exec(query)) 644 if (!q.exec(query))
645 qFatal("%s.", qPrintable(q.lastError().text())); 645 qFatal("%s.", qPrintable(q.lastError().text()));
  646 +
646 if ((q.record().count() == 0) || (q.record().count() > 3)) 647 if ((q.record().count() == 0) || (q.record().count() > 3))
647 qFatal("Query record expected one to three fields, got %d.", q.record().count()); 648 qFatal("Query record expected one to three fields, got %d.", q.record().count());
648 const bool hasMetadata = (q.record().count() >= 2); 649 const bool hasMetadata = (q.record().count() >= 2);
649 const bool hasFilter = (q.record().count() >= 3); 650 const bool hasFilter = (q.record().count() >= 3);
650 651
  652 + QString labelName = "Label";
  653 + if (q.record().count() >= 2)
  654 + labelName = q.record().fieldName(1);
  655 +
651 // subset = seed:subjectMaxSize:numSubjects:subjectMinSize or 656 // subset = seed:subjectMaxSize:numSubjects:subjectMinSize or
652 // subset = seed:{Metadata,...,Metadata}:numSubjects 657 // subset = seed:{Metadata,...,Metadata}:numSubjects
653 int seed = 0, subjectMaxSize = std::numeric_limits<int>::max(), numSubjects = std::numeric_limits<int>::max(), subjectMinSize = 0; 658 int seed = 0, subjectMaxSize = std::numeric_limits<int>::max(), numSubjects = std::numeric_limits<int>::max(), subjectMinSize = 0;
@@ -673,6 +678,7 @@ class dbGallery : public Gallery @@ -673,6 +678,7 @@ class dbGallery : public Gallery
673 QHash<QString, QList<Entry> > entries; // QHash<Label, QList<Entry> > 678 QHash<QString, QList<Entry> > entries; // QHash<Label, QList<Entry> >
674 while (q.next()) { 679 while (q.next()) {
675 if (hasFilter && (seed >= 0) && (qHash(q.value(2).toString()) % 2 != (uint)seed % 2)) continue; // Ensures training and testing filters don't overlap 680 if (hasFilter && (seed >= 0) && (qHash(q.value(2).toString()) % 2 != (uint)seed % 2)) continue; // Ensures training and testing filters don't overlap
  681 +
676 if (metadataFields.isEmpty()) 682 if (metadataFields.isEmpty())
677 entries[hasMetadata ? q.value(1).toString() : ""].append(QPair<QString,QString>(q.value(0).toString(), hasFilter ? q.value(2).toString() : "")); 683 entries[hasMetadata ? q.value(1).toString() : ""].append(QPair<QString,QString>(q.value(0).toString(), hasFilter ? q.value(2).toString() : ""));
678 else 684 else
@@ -707,8 +713,10 @@ class dbGallery : public Gallery @@ -707,8 +713,10 @@ class dbGallery : public Gallery
707 713
708 if (entryList.size() > subjectMaxSize) 714 if (entryList.size() > subjectMaxSize)
709 std::random_shuffle(entryList.begin(), entryList.end()); 715 std::random_shuffle(entryList.begin(), entryList.end());
710 - foreach (const Entry &entry, entryList.mid(0, subjectMaxSize))  
711 - templates.append(File(entry.first, label)); 716 + foreach (const Entry &entry, entryList.mid(0, subjectMaxSize)) {
  717 + templates.append(File(entry.first));
  718 + templates.last().file.set(labelName, label);
  719 + }
712 numSubjects--; 720 numSubjects--;
713 } 721 }
714 } 722 }
@@ -816,7 +824,7 @@ class statGallery : public Gallery @@ -816,7 +824,7 @@ class statGallery : public Gallery
816 824
817 void write(const Template &t) 825 void write(const Template &t)
818 { 826 {
819 - subjects.insert(t.file.get<QString>("Subject")); 827 + subjects.insert(t.file.get<QString>("Label"));
820 bytes.append(t.bytes()); 828 bytes.append(t.bytes());
821 } 829 }
822 }; 830 };
openbr/plugins/independent.cpp
@@ -9,37 +9,36 @@ using namespace cv; @@ -9,37 +9,36 @@ using namespace cv;
9 namespace br 9 namespace br
10 { 10 {
11 11
12 -static TemplateList Downsample(const TemplateList &templates, const Transform *transform) 12 +static TemplateList Downsample(const TemplateList &templates, int classes, int instances, float fraction, const QString & inputVariable)
13 { 13 {
14 // Return early when no downsampling is required 14 // Return early when no downsampling is required
15 - if ((transform->classes == std::numeric_limits<int>::max()) &&  
16 - (transform->instances == std::numeric_limits<int>::max()) &&  
17 - (transform->fraction >= 1)) 15 + if ((classes == std::numeric_limits<int>::max()) &&
  16 + (instances == std::numeric_limits<int>::max()) &&
  17 + (fraction >= 1))
18 return templates; 18 return templates;
19 19
20 - const bool atLeast = transform->instances < 0;  
21 - const int instances = abs(transform->instances); 20 + const bool atLeast = instances < 0;
  21 + instances = abs(instances);
22 22
23 - QList<QString> allLabels = File::get<QString>(templates, "Subject"); 23 + QList<QString> allLabels = File::get<QString>(templates, inputVariable);
24 QList<QString> uniqueLabels = allLabels.toSet().toList(); 24 QList<QString> uniqueLabels = allLabels.toSet().toList();
25 qSort(uniqueLabels); 25 qSort(uniqueLabels);
26 26
27 - QMap<QString,int> counts = templates.countValues<QString>("Subject", instances != std::numeric_limits<int>::max()); 27 + QMap<QString,int> counts = templates.countValues<QString>(inputVariable, instances != std::numeric_limits<int>::max());
28 28
29 - if ((instances != std::numeric_limits<int>::max()) && (transform->classes != std::numeric_limits<int>::max())) 29 + if ((instances != std::numeric_limits<int>::max()) && (classes != std::numeric_limits<int>::max()))
30 foreach (const QString & label, counts.keys()) 30 foreach (const QString & label, counts.keys())
31 if (counts[label] < instances) 31 if (counts[label] < instances)
32 counts.remove(label); 32 counts.remove(label);
33 33
34 uniqueLabels = counts.keys(); 34 uniqueLabels = counts.keys();
35 - if ((transform->classes != std::numeric_limits<int>::max()) && (uniqueLabels.size() < transform->classes))  
36 - qWarning("Downsample requested %d classes but only %d are available.", transform->classes, uniqueLabels.size()); 35 + if ((classes != std::numeric_limits<int>::max()) && (uniqueLabels.size() < classes))
  36 + qWarning("Downsample requested %d classes but only %d are available.", classes, uniqueLabels.size());
37 37
38 - Common::seedRNG();  
39 QList<QString> selectedLabels = uniqueLabels; 38 QList<QString> selectedLabels = uniqueLabels;
40 - if (transform->classes < uniqueLabels.size()) { 39 + if (classes < uniqueLabels.size()) {
41 std::random_shuffle(selectedLabels.begin(), selectedLabels.end()); 40 std::random_shuffle(selectedLabels.begin(), selectedLabels.end());
42 - selectedLabels = selectedLabels.mid(0, transform->classes); 41 + selectedLabels = selectedLabels.mid(0, classes);
43 } 42 }
44 43
45 TemplateList downsample; 44 TemplateList downsample;
@@ -56,14 +55,45 @@ static TemplateList Downsample(const TemplateList &amp;templates, const Transform *t @@ -56,14 +55,45 @@ static TemplateList Downsample(const TemplateList &amp;templates, const Transform *t
56 downsample.append(templates.value(indices[j])); 55 downsample.append(templates.value(indices[j]));
57 } 56 }
58 57
59 - if (transform->fraction < 1) { 58 + if (fraction < 1) {
60 std::random_shuffle(downsample.begin(), downsample.end()); 59 std::random_shuffle(downsample.begin(), downsample.end());
61 - downsample = downsample.mid(0, downsample.size()*transform->fraction); 60 + downsample = downsample.mid(0, downsample.size()*fraction);
62 } 61 }
63 62
64 return downsample; 63 return downsample;
65 } 64 }
66 65
  66 +class DownsampleTrainingTransform : public Transform
  67 +{
  68 + Q_OBJECT
  69 + Q_PROPERTY(br::Transform* transform READ get_transform WRITE set_transform RESET reset_transform STORED true)
  70 + Q_PROPERTY(int classes READ get_classes WRITE set_classes RESET reset_classes STORED false)
  71 + Q_PROPERTY(int instances READ get_instances WRITE set_instances RESET reset_instances STORED false)
  72 + Q_PROPERTY(float fraction READ get_fraction WRITE set_fraction RESET reset_fraction STORED false)
  73 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
  74 + BR_PROPERTY(br::Transform*, transform, NULL)
  75 + BR_PROPERTY(int, classes, std::numeric_limits<int>::max())
  76 + BR_PROPERTY(int, instances, std::numeric_limits<int>::max())
  77 + BR_PROPERTY(float, fraction, 1)
  78 + BR_PROPERTY(QString, inputVariable, "Label")
  79 +
  80 + void project(const Template & src, Template & dst) const
  81 + {
  82 + transform->project(src,dst);
  83 + }
  84 +
  85 +
  86 + void train(const TemplateList &data)
  87 + {
  88 + if (!transform || !transform->trainable)
  89 + return;
  90 +
  91 + TemplateList downsampled = Downsample(data, classes, instances, fraction, inputVariable);
  92 + transform->train(downsampled);
  93 + }
  94 +};
  95 +BR_REGISTER(Transform, DownsampleTrainingTransform)
  96 +
67 /*! 97 /*!
68 * \ingroup transforms 98 * \ingroup transforms
69 * \brief Clones the transform so that it can be applied independently. 99 * \brief Clones the transform so that it can be applied independently.
@@ -124,13 +154,10 @@ class IndependentTransform : public MetaTransform @@ -124,13 +154,10 @@ class IndependentTransform : public MetaTransform
124 while (transforms.size() < templatesList.size()) 154 while (transforms.size() < templatesList.size())
125 transforms.append(transform->clone()); 155 transforms.append(transform->clone());
126 156
127 - for (int i=0; i<templatesList.size(); i++)  
128 - templatesList[i] = Downsample(templatesList[i], transforms[i]);  
129 -  
130 QFutureSynchronizer<void> futures; 157 QFutureSynchronizer<void> futures;
131 for (int i=0; i<templatesList.size(); i++) 158 for (int i=0; i<templatesList.size(); i++)
132 - futures.addFuture(QtConcurrent::run(_train, transforms[i], &templatesList[i]));  
133 - futures.waitForFinished(); 159 + futures.addFuture(QtConcurrent::run(_train, transforms[i], &templatesList[i]));
  160 + futures.waitForFinished();
134 } 161 }
135 162
136 void project(const Template &src, Template &dst) const 163 void project(const Template &src, Template &dst) const
openbr/plugins/mask.cpp
@@ -158,6 +158,9 @@ class LargestConvexAreaTransform : public UntrainableTransform @@ -158,6 +158,9 @@ class LargestConvexAreaTransform : public UntrainableTransform
158 { 158 {
159 Q_OBJECT 159 Q_OBJECT
160 160
  161 + Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)
  162 + BR_PROPERTY(QString, outputVariable, "Label")
  163 +
161 void project(const Template &src, Template &dst) const 164 void project(const Template &src, Template &dst) const
162 { 165 {
163 std::vector< std::vector<Point> > contours; 166 std::vector< std::vector<Point> > contours;
@@ -171,7 +174,7 @@ class LargestConvexAreaTransform : public UntrainableTransform @@ -171,7 +174,7 @@ class LargestConvexAreaTransform : public UntrainableTransform
171 if (area / hullArea > 0.98) 174 if (area / hullArea > 0.98)
172 maxArea = std::max(maxArea, area); 175 maxArea = std::max(maxArea, area);
173 } 176 }
174 - dst.file.set("Label", maxArea); 177 + dst.file.set(outputVariable, maxArea);
175 } 178 }
176 }; 179 };
177 180
openbr/plugins/normalize.cpp
@@ -97,6 +97,7 @@ class CenterTransform : public Transform @@ -97,6 +97,7 @@ class CenterTransform : public Transform
97 Q_OBJECT 97 Q_OBJECT
98 Q_ENUMS(Method) 98 Q_ENUMS(Method)
99 Q_PROPERTY(Method method READ get_method WRITE set_method RESET reset_method STORED false) 99 Q_PROPERTY(Method method READ get_method WRITE set_method RESET reset_method STORED false)
  100 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
100 101
101 public: 102 public:
102 /*!< */ 103 /*!< */
@@ -107,6 +108,7 @@ public: @@ -107,6 +108,7 @@ public:
107 108
108 private: 109 private:
109 BR_PROPERTY(Method, method, Mean) 110 BR_PROPERTY(Method, method, Mean)
  111 + BR_PROPERTY(QString, inputVariable, "Label")
110 112
111 Mat a, b; // dst = (src - b) / a 113 Mat a, b; // dst = (src - b) / a
112 114
@@ -127,7 +129,7 @@ private: @@ -127,7 +129,7 @@ private:
127 Mat m; 129 Mat m;
128 OpenCVUtils::toMat(data.data()).convertTo(m, CV_64F); 130 OpenCVUtils::toMat(data.data()).convertTo(m, CV_64F);
129 131
130 - const QList<int> labels = data.indexProperty("Subject"); 132 + const QList<int> labels = data.indexProperty(inputVariable);
131 const int dims = m.cols; 133 const int dims = m.cols;
132 134
133 vector<Mat> mv, av, bv; 135 vector<Mat> mv, av, bv;
openbr/plugins/openbr_internal.h
@@ -219,10 +219,6 @@ public: @@ -219,10 +219,6 @@ public:
219 } 219 }
220 220
221 output->file = this->file; 221 output->file = this->file;
222 - output->classes = classes;  
223 - output->instances = instances;  
224 - output->fraction = fraction;  
225 -  
226 output->init(); 222 output->init();
227 223
228 return output; 224 return output;
openbr/plugins/output.cpp
@@ -146,8 +146,8 @@ class meltOutput : public MatrixOutput @@ -146,8 +146,8 @@ class meltOutput : public MatrixOutput
146 QStringList lines; 146 QStringList lines;
147 if (file.baseName() != "terminal") lines.append(QString("Query,Target,Mask,Similarity%1").arg(keys)); 147 if (file.baseName() != "terminal") lines.append(QString("Query,Target,Mask,Similarity%1").arg(keys));
148 148
149 - QList<QString> queryLabels = File::get<QString>(queryFiles, "Subject");  
150 - QList<QString> targetLabels = File::get<QString>(targetFiles, "Subject"); 149 + QList<QString> queryLabels = File::get<QString>(queryFiles, "Label");
  150 + QList<QString> targetLabels = File::get<QString>(targetFiles, "Label");
151 151
152 for (int i=0; i<queryFiles.size(); i++) { 152 for (int i=0; i<queryFiles.size(); i++) {
153 for (int j=(selfSimilar ? i+1 : 0); j<targetFiles.size(); j++) { 153 for (int j=(selfSimilar ? i+1 : 0); j<targetFiles.size(); j++) {
@@ -298,7 +298,7 @@ class txtOutput : public MatrixOutput @@ -298,7 +298,7 @@ class txtOutput : public MatrixOutput
298 if (file.isNull() || targetFiles.isEmpty() || queryFiles.isEmpty()) return; 298 if (file.isNull() || targetFiles.isEmpty() || queryFiles.isEmpty()) return;
299 QStringList lines; 299 QStringList lines;
300 foreach (const File &file, queryFiles) 300 foreach (const File &file, queryFiles)
301 - lines.append(file.name + " " + file.get<QString>("Subject")); 301 + lines.append(file.name + " " + file.get<QString>("Label"));
302 QtUtils::writeFile(file, lines); 302 QtUtils::writeFile(file, lines);
303 } 303 }
304 }; 304 };
@@ -428,7 +428,7 @@ class rankOutput : public MatrixOutput @@ -428,7 +428,7 @@ class rankOutput : public MatrixOutput
428 foreach (const Pair &pair, Common::Sort(OpenCVUtils::matrixToVector<float>(data.row(i)), true)) { 428 foreach (const Pair &pair, Common::Sort(OpenCVUtils::matrixToVector<float>(data.row(i)), true)) {
429 if (Globals->crossValidate > 0 ? (targetFiles[pair.second].get<int>("Partition",-1) == queryFiles[i].get<int>("Partition",-1)) : true) { 429 if (Globals->crossValidate > 0 ? (targetFiles[pair.second].get<int>("Partition",-1) == queryFiles[i].get<int>("Partition",-1)) : true) {
430 if (QString(targetFiles[pair.second]) != QString(queryFiles[i])) { 430 if (QString(targetFiles[pair.second]) != QString(queryFiles[i])) {
431 - if (targetFiles[pair.second].get<QString>("Subject") == queryFiles[i].get<QString>("Subject")) { 431 + if (targetFiles[pair.second].get<QString>("Label") == queryFiles[i].get<QString>("Label")) {
432 ranks.append(rank); 432 ranks.append(rank);
433 positions.append(pair.second); 433 positions.append(pair.second);
434 scores.append(pair.first); 434 scores.append(pair.first);
openbr/plugins/quality.cpp
@@ -19,17 +19,20 @@ class ImpostorUniquenessMeasureTransform : public Transform @@ -19,17 +19,20 @@ class ImpostorUniquenessMeasureTransform : public Transform
19 Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false) 19 Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false)
20 Q_PROPERTY(double mean READ get_mean WRITE set_mean RESET reset_mean) 20 Q_PROPERTY(double mean READ get_mean WRITE set_mean RESET reset_mean)
21 Q_PROPERTY(double stddev READ get_stddev WRITE set_stddev RESET reset_stddev) 21 Q_PROPERTY(double stddev READ get_stddev WRITE set_stddev RESET reset_stddev)
  22 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
22 BR_PROPERTY(br::Distance*, distance, Distance::make("Dist(L2)", this)) 23 BR_PROPERTY(br::Distance*, distance, Distance::make("Dist(L2)", this))
23 BR_PROPERTY(double, mean, 0) 24 BR_PROPERTY(double, mean, 0)
24 BR_PROPERTY(double, stddev, 1) 25 BR_PROPERTY(double, stddev, 1)
  26 + BR_PROPERTY(QString, inputVariable, "Label")
  27 +
25 TemplateList impostors; 28 TemplateList impostors;
26 29
27 float calculateIUM(const Template &probe, const TemplateList &gallery) const 30 float calculateIUM(const Template &probe, const TemplateList &gallery) const
28 { 31 {
29 - const QString probeLabel = probe.file.get<QString>("Subject"); 32 + const QString probeLabel = probe.file.get<QString>(inputVariable);
30 TemplateList subset = gallery; 33 TemplateList subset = gallery;
31 for (int j=subset.size()-1; j>=0; j--) 34 for (int j=subset.size()-1; j>=0; j--)
32 - if (subset[j].file.get<QString>("Subject") == probeLabel) 35 + if (subset[j].file.get<QString>(inputVariable) == probeLabel)
33 subset.removeAt(j); 36 subset.removeAt(j);
34 37
35 QList<float> scores = distance->compare(subset, probe); 38 QList<float> scores = distance->compare(subset, probe);
@@ -151,6 +154,7 @@ class MatchProbabilityDistance : public Distance @@ -151,6 +154,7 @@ class MatchProbabilityDistance : public Distance
151 Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false) 154 Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false)
152 Q_PROPERTY(bool gaussian READ get_gaussian WRITE set_gaussian RESET reset_gaussian STORED false) 155 Q_PROPERTY(bool gaussian READ get_gaussian WRITE set_gaussian RESET reset_gaussian STORED false)
153 Q_PROPERTY(bool crossModality READ get_crossModality WRITE set_crossModality RESET reset_crossModality STORED false) 156 Q_PROPERTY(bool crossModality READ get_crossModality WRITE set_crossModality RESET reset_crossModality STORED false)
  157 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
154 158
155 MP mp; 159 MP mp;
156 160
@@ -158,7 +162,7 @@ class MatchProbabilityDistance : public Distance @@ -158,7 +162,7 @@ class MatchProbabilityDistance : public Distance
158 { 162 {
159 distance->train(src); 163 distance->train(src);
160 164
161 - const QList<int> labels = src.indexProperty("Subject"); 165 + const QList<int> labels = src.indexProperty(inputVariable);
162 QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(src.size()), FileList(src.size()))); 166 QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(src.size()), FileList(src.size())));
163 distance->compare(src, src, matrixOutput.data()); 167 distance->compare(src, src, matrixOutput.data());
164 168
@@ -201,6 +205,7 @@ protected: @@ -201,6 +205,7 @@ protected:
201 BR_PROPERTY(br::Distance*, distance, make("Dist(L2)")) 205 BR_PROPERTY(br::Distance*, distance, make("Dist(L2)"))
202 BR_PROPERTY(bool, gaussian, true) 206 BR_PROPERTY(bool, gaussian, true)
203 BR_PROPERTY(bool, crossModality, false) 207 BR_PROPERTY(bool, crossModality, false)
  208 + BR_PROPERTY(QString, inputVariable, "Label")
204 }; 209 };
205 210
206 BR_REGISTER(Distance, MatchProbabilityDistance) 211 BR_REGISTER(Distance, MatchProbabilityDistance)
@@ -217,10 +222,12 @@ class HeatMapDistance : public Distance @@ -217,10 +222,12 @@ class HeatMapDistance : public Distance
217 Q_PROPERTY(bool gaussian READ get_gaussian WRITE set_gaussian RESET reset_gaussian STORED false) 222 Q_PROPERTY(bool gaussian READ get_gaussian WRITE set_gaussian RESET reset_gaussian STORED false)
218 Q_PROPERTY(bool crossModality READ get_crossModality WRITE set_crossModality RESET reset_crossModality STORED false) 223 Q_PROPERTY(bool crossModality READ get_crossModality WRITE set_crossModality RESET reset_crossModality STORED false)
219 Q_PROPERTY(int step READ get_step WRITE set_step RESET reset_step STORED false) 224 Q_PROPERTY(int step READ get_step WRITE set_step RESET reset_step STORED false)
  225 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
220 BR_PROPERTY(br::Distance*, distance, make("Dist(L2)")) 226 BR_PROPERTY(br::Distance*, distance, make("Dist(L2)"))
221 BR_PROPERTY(bool, gaussian, true) 227 BR_PROPERTY(bool, gaussian, true)
222 BR_PROPERTY(bool, crossModality, false) 228 BR_PROPERTY(bool, crossModality, false)
223 BR_PROPERTY(int, step, 1) 229 BR_PROPERTY(int, step, 1)
  230 + BR_PROPERTY(QString, inputVariable, "Label")
224 231
225 QList<MP> mp; 232 QList<MP> mp;
226 233
@@ -228,7 +235,7 @@ class HeatMapDistance : public Distance @@ -228,7 +235,7 @@ class HeatMapDistance : public Distance
228 { 235 {
229 distance->train(src); 236 distance->train(src);
230 237
231 - const QList<int> labels = src.indexProperty("Subject"); 238 + const QList<int> labels = src.indexProperty(inputVariable);
232 239
233 QList<TemplateList> patches; 240 QList<TemplateList> patches;
234 241
@@ -307,14 +314,16 @@ class UnitDistance : public Distance @@ -307,14 +314,16 @@ class UnitDistance : public Distance
307 Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance) 314 Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance)
308 Q_PROPERTY(float a READ get_a WRITE set_a RESET reset_a) 315 Q_PROPERTY(float a READ get_a WRITE set_a RESET reset_a)
309 Q_PROPERTY(float b READ get_b WRITE set_b RESET reset_b) 316 Q_PROPERTY(float b READ get_b WRITE set_b RESET reset_b)
  317 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
310 BR_PROPERTY(br::Distance*, distance, make("Dist(L2)")) 318 BR_PROPERTY(br::Distance*, distance, make("Dist(L2)"))
311 BR_PROPERTY(float, a, 1) 319 BR_PROPERTY(float, a, 1)
312 BR_PROPERTY(float, b, 0) 320 BR_PROPERTY(float, b, 0)
  321 + BR_PROPERTY(QString, inputVariable, "Label")
313 322
314 void train(const TemplateList &templates) 323 void train(const TemplateList &templates)
315 { 324 {
316 const TemplateList samples = templates.mid(0, 2000); 325 const TemplateList samples = templates.mid(0, 2000);
317 - const QList<int> sampleLabels = samples.indexProperty("Subject"); 326 + const QList<int> sampleLabels = samples.indexProperty(inputVariable);
318 QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(samples.size()), FileList(samples.size()))); 327 QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(samples.size()), FileList(samples.size())));
319 Distance::compare(samples, samples, matrixOutput.data()); 328 Distance::compare(samples, samples, matrixOutput.data());
320 329
openbr/plugins/quantize.cpp
@@ -120,6 +120,10 @@ BR_REGISTER(Transform, HistEqQuantizationTransform) @@ -120,6 +120,10 @@ BR_REGISTER(Transform, HistEqQuantizationTransform)
120 class BayesianQuantizationDistance : public Distance 120 class BayesianQuantizationDistance : public Distance
121 { 121 {
122 Q_OBJECT 122 Q_OBJECT
  123 +
  124 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
  125 + BR_PROPERTY(QString, inputVariable, "Label")
  126 +
123 QVector<float> loglikelihoods; 127 QVector<float> loglikelihoods;
124 128
125 static void computeLogLikelihood(const Mat &data, const QList<int> &labels, float *loglikelihood) 129 static void computeLogLikelihood(const Mat &data, const QList<int> &labels, float *loglikelihood)
@@ -150,7 +154,7 @@ class BayesianQuantizationDistance : public Distance @@ -150,7 +154,7 @@ class BayesianQuantizationDistance : public Distance
150 qFatal("Expected sigle matrix templates of type CV_8UC1!"); 154 qFatal("Expected sigle matrix templates of type CV_8UC1!");
151 155
152 const Mat data = OpenCVUtils::toMat(src.data()); 156 const Mat data = OpenCVUtils::toMat(src.data());
153 - const QList<int> templateLabels = src.indexProperty("Subject"); 157 + const QList<int> templateLabels = src.indexProperty(inputVariable);
154 loglikelihoods = QVector<float>(data.cols*256, 0); 158 loglikelihoods = QVector<float>(data.cols*256, 0);
155 159
156 QFutureSynchronizer<void> futures; 160 QFutureSynchronizer<void> futures;
@@ -343,9 +347,11 @@ class ProductQuantizationTransform : public Transform @@ -343,9 +347,11 @@ class ProductQuantizationTransform : public Transform
343 Q_PROPERTY(int n READ get_n WRITE set_n RESET reset_n STORED false) 347 Q_PROPERTY(int n READ get_n WRITE set_n RESET reset_n STORED false)
344 Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance STORED false) 348 Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance STORED false)
345 Q_PROPERTY(bool bayesian READ get_bayesian WRITE set_bayesian RESET reset_bayesian STORED false) 349 Q_PROPERTY(bool bayesian READ get_bayesian WRITE set_bayesian RESET reset_bayesian STORED false)
  350 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
346 BR_PROPERTY(int, n, 2) 351 BR_PROPERTY(int, n, 2)
347 BR_PROPERTY(br::Distance*, distance, Distance::make("L2", this)) 352 BR_PROPERTY(br::Distance*, distance, Distance::make("L2", this))
348 BR_PROPERTY(bool, bayesian, false) 353 BR_PROPERTY(bool, bayesian, false)
  354 + BR_PROPERTY(QString, inputVariable, "Label")
349 355
350 quint16 index; 356 quint16 index;
351 QList<Mat> centers; 357 QList<Mat> centers;
@@ -474,7 +480,7 @@ private: @@ -474,7 +480,7 @@ private:
474 Mat data = OpenCVUtils::toMat(src.data()); 480 Mat data = OpenCVUtils::toMat(src.data());
475 const int step = getStep(data.cols); 481 const int step = getStep(data.cols);
476 482
477 - const QList<int> labels = src.indexProperty("Subject"); 483 + const QList<int> labels = src.indexProperty(inputVariable);
478 484
479 Mat &lut = ProductQuantizationLUTs[index]; 485 Mat &lut = ProductQuantizationLUTs[index];
480 lut = Mat(getDims(data.cols), 256*(256+1)/2, CV_32FC1); 486 lut = Mat(getDims(data.cols), 256*(256+1)/2, CV_32FC1);
openbr/plugins/quantize2.cpp
@@ -19,6 +19,10 @@ namespace br @@ -19,6 +19,10 @@ namespace br
19 class BayesianQuantizationTransform : public Transform 19 class BayesianQuantizationTransform : public Transform
20 { 20 {
21 Q_OBJECT 21 Q_OBJECT
  22 +
  23 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
  24 + BR_PROPERTY(QString, inputVariable, "Label")
  25 +
22 QVector<float> thresholds; 26 QVector<float> thresholds;
23 27
24 static void computeThresholdsRecursive(const QVector<int> &cumulativeGenuines, const QVector<int> &cumulativeImpostors, 28 static void computeThresholdsRecursive(const QVector<int> &cumulativeGenuines, const QVector<int> &cumulativeImpostors,
@@ -77,7 +81,7 @@ class BayesianQuantizationTransform : public Transform @@ -77,7 +81,7 @@ class BayesianQuantizationTransform : public Transform
77 void train(const TemplateList &src) 81 void train(const TemplateList &src)
78 { 82 {
79 const Mat data = OpenCVUtils::toMat(src.data()); 83 const Mat data = OpenCVUtils::toMat(src.data());
80 - const QList<int> labels = src.indexProperty("Subject"); 84 + const QList<int> labels = src.indexProperty(inputVariable);
81 85
82 thresholds = QVector<float>(256*data.cols); 86 thresholds = QVector<float>(256*data.cols);
83 87
openbr/plugins/svm.cpp
@@ -101,6 +101,8 @@ class SVMTransform : public Transform @@ -101,6 +101,8 @@ class SVMTransform : public Transform
101 Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false) 101 Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false)
102 Q_PROPERTY(float C READ get_C WRITE set_C RESET reset_C STORED false) 102 Q_PROPERTY(float C READ get_C WRITE set_C RESET reset_C STORED false)
103 Q_PROPERTY(float gamma READ get_gamma WRITE set_gamma RESET reset_gamma STORED false) 103 Q_PROPERTY(float gamma READ get_gamma WRITE set_gamma RESET reset_gamma STORED false)
  104 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
  105 + Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)
104 106
105 public: 107 public:
106 enum Kernel { Linear = CvSVM::LINEAR, 108 enum Kernel { Linear = CvSVM::LINEAR,
@@ -119,6 +121,9 @@ private: @@ -119,6 +121,9 @@ private:
119 BR_PROPERTY(Type, type, C_SVC) 121 BR_PROPERTY(Type, type, C_SVC)
120 BR_PROPERTY(float, C, -1) 122 BR_PROPERTY(float, C, -1)
121 BR_PROPERTY(float, gamma, -1) 123 BR_PROPERTY(float, gamma, -1)
  124 + BR_PROPERTY(QString, inputVariable, "")
  125 + BR_PROPERTY(QString, outputVariable, "")
  126 +
122 127
123 SVM svm; 128 SVM svm;
124 QHash<QString, int> labelMap; 129 QHash<QString, int> labelMap;
@@ -128,14 +133,15 @@ private: @@ -128,14 +133,15 @@ private:
128 { 133 {
129 Mat data = OpenCVUtils::toMat(_data.data()); 134 Mat data = OpenCVUtils::toMat(_data.data());
130 Mat lab; 135 Mat lab;
131 - // If we are doing regression, assume subject has float values 136 + // If we are doing regression, the input variable should have float
  137 + // values
132 if (type == EPS_SVR || type == NU_SVR) { 138 if (type == EPS_SVR || type == NU_SVR) {
133 - lab = OpenCVUtils::toMat(File::get<float>(_data, "Subject")); 139 + lab = OpenCVUtils::toMat(File::get<float>(_data, inputVariable));
134 } 140 }
135 - // If we are doing classification, assume subject has discrete values, map them  
136 - // and store the mapping data 141 + // If we are doing classification, we should be dealing with discrete
  142 + // values. Map them and store the mapping data
137 else { 143 else {
138 - QList<int> dataLabels = _data.indexProperty("Subject", labelMap, reverseLookup); 144 + QList<int> dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup);
139 lab = OpenCVUtils::toMat(dataLabels); 145 lab = OpenCVUtils::toMat(dataLabels);
140 } 146 }
141 trainSVM(svm, data, lab, kernel, type, C, gamma); 147 trainSVM(svm, data, lab, kernel, type, C, gamma);
@@ -146,9 +152,9 @@ private: @@ -146,9 +152,9 @@ private:
146 dst = src; 152 dst = src;
147 float prediction = svm.predict(src.m().reshape(1, 1)); 153 float prediction = svm.predict(src.m().reshape(1, 1));
148 if (type == EPS_SVR || type == NU_SVR) 154 if (type == EPS_SVR || type == NU_SVR)
149 - dst.file.set("Subject", prediction); 155 + dst.file.set(outputVariable, prediction);
150 else 156 else
151 - dst.file.set("Subject", reverseLookup[prediction]); 157 + dst.file.set(outputVariable, reverseLookup[prediction]);
152 } 158 }
153 159
154 void store(QDataStream &stream) const 160 void store(QDataStream &stream) const
@@ -162,6 +168,24 @@ private: @@ -162,6 +168,24 @@ private:
162 loadSVM(svm, stream); 168 loadSVM(svm, stream);
163 stream >> labelMap >> reverseLookup; 169 stream >> labelMap >> reverseLookup;
164 } 170 }
  171 +
  172 + void init()
  173 + {
  174 + // Since SVM can do regression or classification, we have to check the problem type before
  175 + // specifying target variable names
  176 + if (inputVariable.isEmpty())
  177 + {
  178 + if (type == EPS_SVR || type == NU_SVR) {
  179 + inputVariable = "Regressor";
  180 + if (outputVariable.isEmpty())
  181 + outputVariable = "Regressand";
  182 + }
  183 + else
  184 + inputVariable = "Label";
  185 + }
  186 + if (outputVariable.isEmpty())
  187 + outputVariable = inputVariable;
  188 + }
165 }; 189 };
166 190
167 BR_REGISTER(Transform, SVMTransform) 191 BR_REGISTER(Transform, SVMTransform)
@@ -178,6 +202,8 @@ class SVMDistance : public Distance @@ -178,6 +202,8 @@ class SVMDistance : public Distance
178 Q_ENUMS(Type) 202 Q_ENUMS(Type)
179 Q_PROPERTY(Kernel kernel READ get_kernel WRITE set_kernel RESET reset_kernel STORED false) 203 Q_PROPERTY(Kernel kernel READ get_kernel WRITE set_kernel RESET reset_kernel STORED false)
180 Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false) 204 Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false)
  205 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
  206 +
181 207
182 public: 208 public:
183 enum Kernel { Linear = CvSVM::LINEAR, 209 enum Kernel { Linear = CvSVM::LINEAR,
@@ -194,13 +220,14 @@ public: @@ -194,13 +220,14 @@ public:
194 private: 220 private:
195 BR_PROPERTY(Kernel, kernel, Linear) 221 BR_PROPERTY(Kernel, kernel, Linear)
196 BR_PROPERTY(Type, type, EPS_SVR) 222 BR_PROPERTY(Type, type, EPS_SVR)
  223 + BR_PROPERTY(QString, inputVariable, "Label")
197 224
198 SVM svm; 225 SVM svm;
199 226
200 void train(const TemplateList &src) 227 void train(const TemplateList &src)
201 { 228 {
202 const Mat data = OpenCVUtils::toMat(src.data()); 229 const Mat data = OpenCVUtils::toMat(src.data());
203 - const QList<int> lab = src.indexProperty("Subject"); 230 + const QList<int> lab = src.indexProperty(inputVariable);
204 231
205 const int instances = data.rows * (data.rows+1) / 2; 232 const int instances = data.rows * (data.rows+1) / 2;
206 Mat deltaData(instances, data.cols, data.type()); 233 Mat deltaData(instances, data.cols, data.type());
scripts/evalAgeRegression-PCSO.sh
@@ -4,8 +4,12 @@ if [ ! -f evalAgeRegression-PCSO.sh ]; then @@ -4,8 +4,12 @@ if [ ! -f evalAgeRegression-PCSO.sh ]; then
4 exit 4 exit
5 fi 5 fi
6 6
  7 +export BR="../build/app/br/br -useGui 0"
  8 +export PCSO_DIR=/user/pripshare/Databases/FaceDatabases/PCSO/PCSO/
  9 +export ageAlg=AgeRegression
  10 +
7 # Create a file list by querying the database 11 # Create a file list by querying the database
8 -br -quiet -algorithm Identity -enroll "../data/PCSO/PCSO.db[query='SELECT File,Age,PersonID FROM PCSO WHERE Age >= 15 AND AGE <= 75', subset=1:200]" terminal.txt > Input.txt 12 +$BR -quiet -algorithm Identity -enroll "$PCSO_DIR/PCSO.db[query='SELECT File,Age,PersonID FROM PCSO WHERE Age >= 17 AND AGE <= 68', subset=1:200]" terminal.txt > Input.txt
9 13
10 # Enroll the file list and evaluate performance 14 # Enroll the file list and evaluate performance
11 -br -algorithm AgeRegression -path ../data/PCSO/img -enroll Input.txt Output.txt -evalRegression Output.txt Input.txt 15 +$BR -algorithm $ageAlg -path $PCSO_DIR/Images -enroll Input.txt Output.txt -evalRegression Output.txt Input.txt Age
scripts/evalFaceRecognition-MEDS.sh
@@ -20,11 +20,11 @@ if [ ! -e Algorithm_Dataset ]; then @@ -20,11 +20,11 @@ if [ ! -e Algorithm_Dataset ]; then
20 fi 20 fi
21 21
22 if [ ! -e MEDS.mask ]; then 22 if [ ! -e MEDS.mask ]; then
23 - br -makeMask ../data/MEDS/sigset/MEDS_frontal_target.xml ../data/MEDS/sigset/MEDS_frontal_query.xml MEDS.mask 23 + br -useGui 0 -makeMask ../data/MEDS/sigset/MEDS_frontal_target.xml ../data/MEDS/sigset/MEDS_frontal_query.xml MEDS.mask
24 fi 24 fi
25 25
26 # Run Algorithm on MEDS 26 # Run Algorithm on MEDS
27 -br -algorithm ${ALGORITHM} -path ../data/MEDS/img -compare ../data/MEDS/sigset/MEDS_frontal_target.xml ../data/MEDS/sigset/MEDS_frontal_query.xml ${ALGORITHM}_MEDS.mtx -eval ${ALGORITHM}_MEDS.mtx MEDS.mask Algorithm_Dataset/${ALGORITHM}_MEDS.csv 27 +br -useGui 0 -algorithm ${ALGORITHM} -path ../data/MEDS/img -compare ../data/MEDS/sigset/MEDS_frontal_target.xml ../data/MEDS/sigset/MEDS_frontal_query.xml ${ALGORITHM}_MEDS.mtx -eval ${ALGORITHM}_MEDS.mtx MEDS.mask Algorithm_Dataset/${ALGORITHM}_MEDS.csv
28 28
29 # Plot results 29 # Plot results
30 -br -plot Algorithm_Dataset/*_MEDS.csv MEDS 30 +br -useGui 0 -plot Algorithm_Dataset/*_MEDS.csv MEDS
scripts/evalGenderClassification-PCSO.sh
@@ -4,8 +4,13 @@ if [ ! -f evalGenderClassification-PCSO.sh ]; then @@ -4,8 +4,13 @@ if [ ! -f evalGenderClassification-PCSO.sh ]; then
4 exit 4 exit
5 fi 5 fi
6 6
  7 +export BR=../build/app/br/br
  8 +export genderAlg=GenderClassification
  9 +
  10 +export PCSO_DIR=/user/pripshare/Databases/FaceDatabases/PCSO/PCSO/
  11 +
7 # Create a file list by querying the database 12 # Create a file list by querying the database
8 -br -quiet -algorithm Identity -enroll "../data/PCSO/PCSO.db[query='SELECT File,Gender,PersonID FROM PCSO', subset=1:8000]" terminal.txt > Input.txt 13 +$BR -useGui 0 -quiet -algorithm Identity -enroll "$PCSO_DIR/PCSO.db[query='SELECT File,Gender,PersonID FROM PCSO', subset=1:8000]" terminal.txt > Input.txt
9 14
10 # Enroll the file list and evaluate performance 15 # Enroll the file list and evaluate performance
11 -br -algorithm GenderClassification -path ../data/PCSO/img -enroll Input.txt Output.txt -evalClassification Output.txt Input.txt 16 +$BR -useGui 0 -algorithm $genderAlg -path $PCSO_DIR/Images -enroll Input.txt Output.txt -evalClassification Output.txt Input.txt Gender
12 \ No newline at end of file 17 \ No newline at end of file
scripts/trainAgeRegression-PCSO.sh
@@ -6,6 +6,11 @@ fi @@ -6,6 +6,11 @@ fi
6 6
7 #rm -f ../share/openbr/models/features/FaceClassificationRegistration 7 #rm -f ../share/openbr/models/features/FaceClassificationRegistration
8 #rm -f ../share/openbr/models/features/FaceClassificationExtraction 8 #rm -f ../share/openbr/models/features/FaceClassificationExtraction
9 -rm -f ../share/openbr/models/algorithms/AgeRegression 9 +#rm -f ../share/openbr/models/algorithms/AgeRegression
10 10
11 -br -algorithm AgeRegression -path ../data/PCSO/Images -train "../data/PCSO/PCSO.db[query='SELECT File,Age,PersonID FROM PCSO WHERE Age >= 15 AND AGE <= 75', subset=0:200]" ../share/openbr/models/algorithms/AgeRegression 11 +export BR=../build/app/br/br
  12 +export ageAlg=AgeRegression
  13 +
  14 +export PCSO_DIR=/user/pripshare/Databases/FaceDatabases/PCSO/PCSO/
  15 +
  16 +$BR -useGui 0 -algorithm $ageAlg -path $PCSO_DIR/Images -train "$PCSO_DIR/PCSO.db[query='SELECT File,Age,PersonID FROM PCSO WHERE Age >= 17 AND AGE <= 68', subset=0:200]" ../share/openbr/models/algorithms/AgeRegression
scripts/trainFaceRecognition-PCSO.sh
@@ -8,6 +8,13 @@ fi @@ -8,6 +8,13 @@ fi
8 #rm -f ../share/openbr/models/features/FaceRecognitionExtraction 8 #rm -f ../share/openbr/models/features/FaceRecognitionExtraction
9 #rm -f ../share/openbr/models/features/FaceRecognitionEmbedding 9 #rm -f ../share/openbr/models/features/FaceRecognitionEmbedding
10 #rm -f ../share/openbr/models/features/FaceRecognitionQuantization 10 #rm -f ../share/openbr/models/features/FaceRecognitionQuantization
11 -rm -f ../share/openbr/models/algorithms/FaceRecognition 11 +#rm -f ../share/openbr/models/algorithms/FaceRecognition
  12 +
  13 +export BR=../build/app/br/br
  14 +
  15 +export PCSO_DIR=/user/pripshare/Databases/FaceDatabases/PCSO/PCSO/
  16 +
  17 +
  18 +
  19 +$BR -useGui 0 -algorithm FaceRecognition -path "$PCSO_DIR/Images/" -train "$PCSO_DIR/PCSO.db[query='SELECT File,PersonID as Label,PersonID FROM PCSO', subset=0:5:6000]" ../share/openbr/models/algorithms/FaceRecognition
12 20
13 -br -algorithm FaceRecognition -path ../data/PCSO/img -train "../data/PCSO/PCSO.db[query='SELECT File,'S'||PersonID,PersonID FROM PCSO', subset=0:5:6000]" ../share/openbr/models/algorithms/FaceRecognition  
scripts/trainGenderClassification-PCSO.sh
@@ -6,6 +6,11 @@ fi @@ -6,6 +6,11 @@ fi
6 6
7 #rm -f ../share/openbr/models/features/FaceClassificationRegistration 7 #rm -f ../share/openbr/models/features/FaceClassificationRegistration
8 #rm -f ../share/openbr/models/features/FaceClassificationExtraction 8 #rm -f ../share/openbr/models/features/FaceClassificationExtraction
9 -rm -f ../share/openbr/models/algorithms/GenderClassification 9 +#rm -f ../share/openbr/models/algorithms/GenderClassification
10 10
11 -br -algorithm GenderClassification -path ../data/PCSO/Images -train "../data/PCSO/PCSO.db[query='SELECT File,Gender,PersonID FROM PCSO', subset=0:8000]" ../share/openbr/models/algorithms/GenderClassification 11 +export BR=../build/app/br/br
  12 +export genderAlg=GenderClassification
  13 +
  14 +export PCSO_DIR=/user/pripshare/Databases/FaceDatabases/PCSO/PCSO/
  15 +
  16 +$BR -useGui 0 -algorithm $genderAlg -path $PCSO_DIR/Images -train "$PCSO_DIR/PCSO.db[query='SELECT File,Gender,PersonID FROM PCSO', subset=0:8000]" ../share/openbr/models/algorithms/GenderClassification
1 -Subproject commit dccddf4dd3a5239911807beeec39308f8890b1e4 1 +Subproject commit a73d51013ea05f263e88a28539393159fff2183e