Commit 72a5968a012360dda117defa4d8e0521806a31b5

Authored by Charles Otto
1 parent 16039af1

Further changes to label/subject

Remove global label/subject lookup table
Consistently use "Subject" rather than "Label", subject is assumed to be
convertable to QString. When desirable, map discrete subject values to ints.

For classifiers such as svm that require numeric labels, generate a string->int
mapping for the training data, and store it (local to the transform).

Utility functions for collecting all values of a given property (on a template
list), and mapping discrete property values to 0-based integers

Some outstanding issues include use of label/subject in matrix output
app/examples/age_estimation.cpp
@@ -29,7 +29,7 @@ @@ -29,7 +29,7 @@
29 29
30 static void printTemplate(const br::Template &t) 30 static void printTemplate(const br::Template &t)
31 { 31 {
32 - printf("%s age: %d\n", qPrintable(t.file.fileName()), int(t.file.get<float>("Label"))); 32 + printf("%s age: %d\n", qPrintable(t.file.fileName()), int(t.file.get<float>("Subject")));
33 } 33 }
34 34
35 int main(int argc, char *argv[]) 35 int main(int argc, char *argv[])
openbr/core/bee.cpp
@@ -260,26 +260,28 @@ void BEE::makeMask(const QString &amp;targetInput, const QString &amp;queryInput, const @@ -260,26 +260,28 @@ void BEE::makeMask(const QString &amp;targetInput, const QString &amp;queryInput, const
260 260
261 cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries, int partition) 261 cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries, int partition)
262 { 262 {
263 - QList<float> targetLabels = targets.collectValues<float>("Label");  
264 - QList<float> queryLabels = queries.collectValues<float>("Label"); 263 + // Would like to use indexProperty for this, but didn't make a version of that for Filelist yet
  264 + // -cao
  265 + QList<QString> targetLabels = targets.collectValues<QString>("Subject", "-1");
  266 + QList<QString> queryLabels = queries.collectValues<QString>("Subject", "-1");
265 QList<int> targetPartitions = targets.crossValidationPartitions(); 267 QList<int> targetPartitions = targets.crossValidationPartitions();
266 QList<int> queryPartitions = queries.crossValidationPartitions(); 268 QList<int> queryPartitions = queries.crossValidationPartitions();
267 269
268 Mat mask(queries.size(), targets.size(), CV_8UC1); 270 Mat mask(queries.size(), targets.size(), CV_8UC1);
269 for (int i=0; i<queries.size(); i++) { 271 for (int i=0; i<queries.size(); i++) {
270 const QString &fileA = queries[i]; 272 const QString &fileA = queries[i];
271 - const int labelA = queryLabels[i]; 273 + const QString labelA = queryLabels[i];
272 const int partitionA = queryPartitions[i]; 274 const int partitionA = queryPartitions[i];
273 275
274 for (int j=0; j<targets.size(); j++) { 276 for (int j=0; j<targets.size(); j++) {
275 const QString &fileB = targets[j]; 277 const QString &fileB = targets[j];
276 - const int labelB = targetLabels[j]; 278 + const QString labelB = targetLabels[j];
277 const int partitionB = targetPartitions[j]; 279 const int partitionB = targetPartitions[j];
278 280
279 Mask_t val; 281 Mask_t val;
280 if (fileA == fileB) val = DontCare; 282 if (fileA == fileB) val = DontCare;
281 - else if (labelA == -1) val = DontCare;  
282 - else if (labelB == -1) val = DontCare; 283 + else if (labelA == "-1") val = DontCare;
  284 + else if (labelB == "-1") val = DontCare;
283 else if (partitionA != partition) val = DontCare; 285 else if (partitionA != partition) val = DontCare;
284 else if (partitionB == -1) val = NonMatch; 286 else if (partitionB == -1) val = NonMatch;
285 else if (partitionB != partition) val = DontCare; 287 else if (partitionB != partition) val = DontCare;
openbr/core/classify.cpp
@@ -104,9 +104,9 @@ void br::EvalRegression(const QString &amp;predictedInput, const QString &amp;truthInput @@ -104,9 +104,9 @@ void br::EvalRegression(const QString &amp;predictedInput, const QString &amp;truthInput
104 for (int i=0; i<predicted.size(); i++) { 104 for (int i=0; i<predicted.size(); i++) {
105 if (predicted[i].file.name != truth[i].file.name) 105 if (predicted[i].file.name != truth[i].file.name)
106 qFatal("Input order mismatch."); 106 qFatal("Input order mismatch.");
107 - rmsError += pow(predicted[i].file.get<float>("Label")-truth[i].file.get<float>("Label"), 2.f);  
108 - truthValues.append(QString::number(truth[i].file.get<float>("Label")));  
109 - predictedValues.append(QString::number(predicted[i].file.get<float>("Label"))); 107 + rmsError += pow(predicted[i].file.get<float>("Subject")-truth[i].file.get<float>("Subject"), 2.f);
  108 + truthValues.append(QString::number(truth[i].file.get<float>("Subject")));
  109 + predictedValues.append(QString::number(predicted[i].file.get<float>("Subject")));
110 } 110 }
111 111
112 QStringList rSource; 112 QStringList rSource;
openbr/core/cluster.cpp
@@ -278,7 +278,9 @@ void br::EvalClustering(const QString &amp;csv, const QString &amp;input) @@ -278,7 +278,9 @@ void br::EvalClustering(const QString &amp;csv, const QString &amp;input)
278 { 278 {
279 qDebug("Evaluating %s against %s", qPrintable(csv), qPrintable(input)); 279 qDebug("Evaluating %s against %s", qPrintable(csv), qPrintable(input));
280 280
281 - QList<float> labels = TemplateList::fromGallery(input).files().collectValues<float>("Label"); 281 + // We assume clustering algorithms store assigned cluster labels as integers (since the clusters are
  282 + // not named).
  283 + QList<int> labels = TemplateList::fromGallery(input).files().collectValues<int>("Subject");
282 284
283 QHash<int, int> labelToIndex; 285 QHash<int, int> labelToIndex;
284 int nClusters = 0; 286 int nClusters = 0;
openbr/core/core.cpp
@@ -78,7 +78,6 @@ struct AlgorithmCore @@ -78,7 +78,6 @@ struct AlgorithmCore
78 const bool hasComparer = !distance.isNull(); 78 const bool hasComparer = !distance.isNull();
79 out << hasComparer; 79 out << hasComparer;
80 if (hasComparer) distance->store(out); 80 if (hasComparer) distance->store(out);
81 - out << Globals->subjects;  
82 81
83 // Compress and save to file 82 // Compress and save to file
84 QtUtils::writeFile(model, data, -1); 83 QtUtils::writeFile(model, data, -1);
@@ -98,7 +97,6 @@ struct AlgorithmCore @@ -98,7 +97,6 @@ struct AlgorithmCore
98 transform->load(in); 97 transform->load(in);
99 bool hasDistance; in >> hasDistance; 98 bool hasDistance; in >> hasDistance;
100 if (hasDistance) distance->load(in); 99 if (hasDistance) distance->load(in);
101 - in >> Globals->subjects;  
102 } 100 }
103 101
104 File getMemoryGallery(const File &file) const 102 File getMemoryGallery(const File &file) const
openbr/core/opencvutils.cpp
@@ -119,6 +119,17 @@ Mat OpenCVUtils::toMat(const QList&lt;float&gt; &amp;src, int rows) @@ -119,6 +119,17 @@ Mat OpenCVUtils::toMat(const QList&lt;float&gt; &amp;src, int rows)
119 return dst; 119 return dst;
120 } 120 }
121 121
  122 +Mat OpenCVUtils::toMat(const QList<int> &src, int rows)
  123 +{
  124 + if (rows == -1) rows = src.size();
  125 + int columns = src.isEmpty() ? 0 : src.size() / rows;
  126 + if (rows*columns != src.size()) qFatal("Invalid matrix size.");
  127 + Mat dst(rows, columns, CV_32FC1);
  128 + for (int i=0; i<src.size(); i++)
  129 + dst.at<float>(i/columns,i%columns) = src[i];
  130 + return dst;
  131 +}
  132 +
122 Mat OpenCVUtils::toMat(const QList<Mat> &src) 133 Mat OpenCVUtils::toMat(const QList<Mat> &src)
123 { 134 {
124 if (src.isEmpty()) return Mat(); 135 if (src.isEmpty()) return Mat();
openbr/core/opencvutils.h
@@ -35,6 +35,8 @@ namespace OpenCVUtils @@ -35,6 +35,8 @@ namespace OpenCVUtils
35 35
36 // To image 36 // To image
37 cv::Mat toMat(const QList<float> &src, int rows = -1); 37 cv::Mat toMat(const QList<float> &src, int rows = -1);
  38 + cv::Mat toMat(const QList<int> &src, int rows = -1);
  39 +
38 cv::Mat toMat(const QList<cv::Mat> &src); // Data organized one matrix per row 40 cv::Mat toMat(const QList<cv::Mat> &src); // Data organized one matrix per row
39 cv::Mat toMatByRow(const QList<cv::Mat> &src); // Data organized one row per row 41 cv::Mat toMatByRow(const QList<cv::Mat> &src); // Data organized one row per row
40 42
openbr/frvt2012.cpp
@@ -132,7 +132,7 @@ int32_t SdkEstimator::estimate_age(const ONEFACE &amp;input_face, int32_t &amp;age) @@ -132,7 +132,7 @@ int32_t SdkEstimator::estimate_age(const ONEFACE &amp;input_face, int32_t &amp;age)
132 TemplateList templates; 132 TemplateList templates;
133 templates.append(templateFromONEFACE(input_face)); 133 templates.append(templateFromONEFACE(input_face));
134 templates >> *frvt2012_age_transform.data(); 134 templates >> *frvt2012_age_transform.data();
135 - age = templates.first().file.get<float>("Label"); 135 + age = templates.first().file.get<float>("Subject");
136 return templates.first().file.failed() ? 4 : 0; 136 return templates.first().file.failed() ? 4 : 0;
137 } 137 }
138 138
@@ -141,6 +141,7 @@ int32_t SdkEstimator::estimate_gender(const ONEFACE &amp;input_face, int8_t &amp;gender, @@ -141,6 +141,7 @@ int32_t SdkEstimator::estimate_gender(const ONEFACE &amp;input_face, int8_t &amp;gender,
141 TemplateList templates; 141 TemplateList templates;
142 templates.append(templateFromONEFACE(input_face)); 142 templates.append(templateFromONEFACE(input_face));
143 templates >> *frvt2012_gender_transform.data(); 143 templates >> *frvt2012_gender_transform.data();
144 - mf = gender = templates.first().file.get<float>("Label"); 144 + // TODO: lookup gender strings/expected int outputs -cao
  145 + mf = gender = templates.first().file.get<QString>("Subject") == "Male" ? 0 : 1;
145 return templates.first().file.failed() ? 4 : 0; 146 return templates.first().file.failed() ? 4 : 0;
146 } 147 }
openbr/gui/classifier.cpp
@@ -42,16 +42,15 @@ void Classifier::_classify(File file) @@ -42,16 +42,15 @@ void Classifier::_classify(File file)
42 if (!f.contains("Label")) 42 if (!f.contains("Label"))
43 continue; 43 continue;
44 44
45 - // What's with the special casing -cao  
46 if (algorithm == "GenderClassification") { 45 if (algorithm == "GenderClassification") {
47 key = "Gender"; 46 key = "Gender";
48 value = (f.get<QString>("Subject")); 47 value = (f.get<QString>("Subject"));
49 } else if (algorithm == "AgeRegression") { 48 } else if (algorithm == "AgeRegression") {
50 key = "Age"; 49 key = "Age";
51 - value = QString::number(int(f.get<float>("Label")+0.5)) + " Years"; 50 + value = QString::number(int(f.get<float>("Subject")+0.5)) + " Years";
52 } else { 51 } else {
53 key = algorithm; 52 key = algorithm;
54 - value = QString::number(f.get<float>("Label")); 53 + value = f.get<QString>("Subject");
55 } 54 }
56 break; 55 break;
57 } 56 }
openbr/openbr_plugin.cpp
@@ -435,11 +435,13 @@ TemplateList TemplateList::fromGallery(const br::File &amp;gallery) @@ -435,11 +435,13 @@ TemplateList TemplateList::fromGallery(const br::File &amp;gallery)
435 return templates; 435 return templates;
436 } 436 }
437 437
  438 +// indexes some property, assigns an integer id to each unique value of propName
  439 +// stores the index values in "Label" of the output template list -cao
438 TemplateList TemplateList::relabel(const TemplateList &tl, const QString & propName) 440 TemplateList TemplateList::relabel(const TemplateList &tl, const QString & propName)
439 { 441 {
440 - const QList<int> originalLabels = tl.collectValues<int>(propName);  
441 - QHash<int,int> labelTable;  
442 - foreach (int label, originalLabels) 442 + const QList<QString> originalLabels = tl.collectValues<QString>(propName);
  443 + QHash<QString,int> labelTable;
  444 + foreach (const QString & label, originalLabels)
443 if (!labelTable.contains(label)) 445 if (!labelTable.contains(label))
444 labelTable.insert(label, labelTable.size()); 446 labelTable.insert(label, labelTable.size());
445 447
@@ -449,6 +451,52 @@ TemplateList TemplateList::relabel(const TemplateList &amp;tl, const QString &amp; propN @@ -449,6 +451,52 @@ TemplateList TemplateList::relabel(const TemplateList &amp;tl, const QString &amp; propN
449 return result; 451 return result;
450 } 452 }
451 453
  454 +QList<int> TemplateList::indexProperty(const QString & propName, QHash<QString, int> * valueMap,QHash<int, QVariant> * reverseLookup) const
  455 +{
  456 + QHash<QString, int> dummyForwards;
  457 + QHash<int, QVariant> dummyBackwards;
  458 +
  459 + if (!valueMap) valueMap = &dummyForwards;
  460 + if (!reverseLookup) reverseLookup = &dummyBackwards;
  461 +
  462 + return indexProperty(propName, *valueMap, *reverseLookup);
  463 +}
  464 +
  465 +QList<int> TemplateList::indexProperty(const QString & propName, QHash<QString, int> & valueMap, QHash<int, QVariant> & reverseLookup) const
  466 +{
  467 + valueMap.clear();
  468 + reverseLookup.clear();
  469 +
  470 + const QList<QVariant> originalLabels = collectValues<QVariant>(propName);
  471 + foreach (const QVariant & label, originalLabels) {
  472 + QString labelString = label.toString();
  473 + if (!valueMap.contains(labelString)) {
  474 + reverseLookup.insert(valueMap.size(), label);
  475 + valueMap.insert(labelString, valueMap.size());
  476 + }
  477 + }
  478 +
  479 + QList<int> result;
  480 + for (int i=0; i<originalLabels.size(); i++)
  481 + result.append(valueMap[originalLabels[i].toString()]);
  482 +
  483 + return result;
  484 +}
  485 +
  486 +// uses -1 for missing values
  487 +QList<int> TemplateList::applyIndex(const QString & propName, const QHash<QString, int> & valueMap) const
  488 +{
  489 + const QList<QString> originalLabels = collectValues<QString>(propName);
  490 +
  491 + QList<int> result;
  492 + for (int i=0; i<originalLabels.size(); i++) {
  493 + if (!valueMap.contains(originalLabels[i])) result.append(-1);
  494 + else result.append(valueMap[originalLabels[i]]);
  495 + }
  496 +
  497 + return result;
  498 +}
  499 +
452 /* Object - public methods */ 500 /* Object - public methods */
453 QStringList Object::parameters() const 501 QStringList Object::parameters() const
454 { 502 {
@@ -989,7 +1037,8 @@ QString MatrixOutput::toString(int row, int column) const @@ -989,7 +1037,8 @@ QString MatrixOutput::toString(int row, int column) const
989 { 1037 {
990 if (targetFiles[column] == "Subject") { 1038 if (targetFiles[column] == "Subject") {
991 const int label = data.at<float>(row,column); 1039 const int label = data.at<float>(row,column);
992 - return Globals->subjects.key(label, QString::number(label)); 1040 + // problem -cao
  1041 + return QString::number(label);
993 } 1042 }
994 return QString::number(data.at<float>(row,column)); 1043 return QString::number(data.at<float>(row,column));
995 } 1044 }
openbr/openbr_plugin.h
@@ -306,6 +306,15 @@ struct BR_EXPORT FileList : public QList&lt;File&gt; @@ -306,6 +306,15 @@ struct BR_EXPORT FileList : public QList&lt;File&gt;
306 values.append(f.get<T>(propName)); 306 values.append(f.get<T>(propName));
307 return values; 307 return values;
308 } 308 }
  309 + template<typename T>
  310 + QList<T> collectValues(const QString & propName, T defaultValue) const
  311 + {
  312 + QList<T> values; values.reserve(size());
  313 + foreach (const File &f, *this)
  314 + values.append(f.contains(propName) ? f.get<T>(propName) : defaultValue);
  315 + return values;
  316 + }
  317 +
309 QList<int> crossValidationPartitions() const; /*!< \brief Returns the cross-validation partition (default=0) for each file in the list. */ 318 QList<int> crossValidationPartitions() const; /*!< \brief Returns the cross-validation partition (default=0) for each file in the list. */
310 int failures() const; /*!< \brief Returns the number of files with br::File::failed(). */ 319 int failures() const; /*!< \brief Returns the number of files with br::File::failed(). */
311 }; 320 };
@@ -393,6 +402,11 @@ struct TemplateList : public QList&lt;Template&gt; @@ -393,6 +402,11 @@ struct TemplateList : public QList&lt;Template&gt;
393 /*!< \brief Ensure labels are in the range [0,numClasses-1]. */ 402 /*!< \brief Ensure labels are in the range [0,numClasses-1]. */
394 BR_EXPORT static TemplateList relabel(const TemplateList & tl, const QString & propName); 403 BR_EXPORT static TemplateList relabel(const TemplateList & tl, const QString & propName);
395 404
  405 + QList<int> indexProperty(const QString & propName, QHash<QString, int> * valueMap=NULL,QHash<int, QVariant> * reverseLookup = NULL) const;
  406 + QList<int> indexProperty(const QString & propName, QHash<QString, int> & valueMap, QHash<int, QVariant> & reverseLookup) const;
  407 + QList<int> applyIndex(const QString & propName, const QHash<QString, int> & valueMap) const;
  408 +
  409 +
396 /*! 410 /*!
397 * \brief Returns the total number of bytes in all the templates. 411 * \brief Returns the total number of bytes in all the templates.
398 */ 412 */
@@ -658,7 +672,6 @@ public: @@ -658,7 +672,6 @@ public:
658 BR_PROPERTY(int, crossValidate, 0) 672 BR_PROPERTY(int, crossValidate, 0)
659 673
660 QHash<QString,QString> abbreviations; /*!< \brief Used by br::Transform::make() to expand abbreviated algorithms into their complete definitions. */ 674 QHash<QString,QString> abbreviations; /*!< \brief Used by br::Transform::make() to expand abbreviated algorithms into their complete definitions. */
661 - QHash<QString,int> subjects; /*!< \brief Used by classifiers to associate text class labels with unique integers IDs. */  
662 QTime startTime; /*!< \brief Used to estimate timeRemaining(). */ 675 QTime startTime; /*!< \brief Used to estimate timeRemaining(). */
663 676
664 /*! 677 /*!
openbr/plugins/eigen3.cpp
@@ -329,7 +329,8 @@ class LDATransform : public Transform @@ -329,7 +329,8 @@ class LDATransform : public Transform
329 329
330 void train(const TemplateList &_trainingSet) 330 void train(const TemplateList &_trainingSet)
331 { 331 {
332 - TemplateList trainingSet = TemplateList::relabel(_trainingSet, "Label"); 332 + // creates "Label" -cao
  333 + TemplateList trainingSet = TemplateList::relabel(_trainingSet, "Subject");
333 334
334 int instances = trainingSet.size(); 335 int instances = trainingSet.size();
335 336
@@ -342,12 +343,13 @@ class LDATransform : public Transform @@ -342,12 +343,13 @@ class LDATransform : public Transform
342 343
343 TemplateList ldaTrainingSet; 344 TemplateList ldaTrainingSet;
344 static_cast<Transform*>(&pca)->project(trainingSet, ldaTrainingSet); 345 static_cast<Transform*>(&pca)->project(trainingSet, ldaTrainingSet);
  346 + // Reindex label, is this still necessary? -cao
345 ldaTrainingSet = TemplateList::relabel(ldaTrainingSet, "Label"); 347 ldaTrainingSet = TemplateList::relabel(ldaTrainingSet, "Label");
346 348
347 int dimsIn = ldaTrainingSet.first().m().rows * ldaTrainingSet.first().m().cols; 349 int dimsIn = ldaTrainingSet.first().m().rows * ldaTrainingSet.first().m().cols;
348 350
349 // OpenBR ensures that class values range from 0 to numClasses-1. 351 // OpenBR ensures that class values range from 0 to numClasses-1.
350 - // Assumed label is stored as float or int? -cao 352 + // Label exists because we created it earlier with relabel
351 QList<int> classes = trainingSet.collectValues<int>("Label"); 353 QList<int> classes = trainingSet.collectValues<int>("Label");
352 QMap<int, int> classCounts = trainingSet.countValues<int>("Label"); 354 QMap<int, int> classCounts = trainingSet.countValues<int>("Label");
353 const int numClasses = classCounts.size(); 355 const int numClasses = classCounts.size();
openbr/plugins/gallery.cpp
@@ -65,7 +65,7 @@ class arffGallery : public Gallery @@ -65,7 +65,7 @@ class arffGallery : public Gallery
65 const int dimensions = t.m().rows * t.m().cols; 65 const int dimensions = t.m().rows * t.m().cols;
66 for (int i=0; i<dimensions; i++) 66 for (int i=0; i<dimensions; i++)
67 arffFile.write(qPrintable("@ATTRIBUTE v" + QString::number(i) + " REAL\n")); 67 arffFile.write(qPrintable("@ATTRIBUTE v" + QString::number(i) + " REAL\n"));
68 - arffFile.write(qPrintable("@ATTRIBUTE class {" + QStringList(Globals->subjects.keys()).join(',') + "}\n")); 68 + arffFile.write(qPrintable("@ATTRIBUTE class string\n"));
69 69
70 arffFile.write("\n@DATA\n"); 70 arffFile.write("\n@DATA\n");
71 } 71 }
@@ -518,9 +518,8 @@ class csvGallery : public Gallery @@ -518,9 +518,8 @@ class csvGallery : public Gallery
518 static QString getCSVElement(const QString &key, const QVariant &value, bool header) 518 static QString getCSVElement(const QString &key, const QVariant &value, bool header)
519 { 519 {
520 if ((key == "Label") && !header) { 520 if ((key == "Label") && !header) {
521 - QString stringLabel = Globals->subjects.key(value.value<int>());  
522 - if (stringLabel.isEmpty()) return value.value<QString>();  
523 - else return stringLabel; 521 + // problem -cao
  522 + return value.value<QString>();
524 } else if (value.canConvert<QString>()) { 523 } else if (value.canConvert<QString>()) {
525 if (header) return key; 524 if (header) return key;
526 else return value.value<QString>(); 525 else return value.value<QString>();
openbr/plugins/independent.cpp
@@ -20,21 +20,23 @@ static TemplateList Downsample(const TemplateList &amp;templates, const Transform *t @@ -20,21 +20,23 @@ static TemplateList Downsample(const TemplateList &amp;templates, const Transform *t
20 const bool atLeast = transform->instances < 0; 20 const bool atLeast = transform->instances < 0;
21 const int instances = abs(transform->instances); 21 const int instances = abs(transform->instances);
22 22
23 - QList<int> allLabels = templates.collectValues<int>("Label");  
24 - QList<int> uniqueLabels = allLabels.toSet().toList(); 23 + QList<QString> allLabels = templates.collectValues<QString>("Subject");
  24 + QList<QString> uniqueLabels = allLabels.toSet().toList();
25 qSort(uniqueLabels); 25 qSort(uniqueLabels);
26 26
27 - QMap<int,int> counts = templates.countValues<int>("Label", instances != std::numeric_limits<int>::max()); 27 + QMap<QString,int> counts = templates.countValues<QString>("Subject", instances != std::numeric_limits<int>::max());
  28 +
28 if ((instances != std::numeric_limits<int>::max()) && (transform->classes != std::numeric_limits<int>::max())) 29 if ((instances != std::numeric_limits<int>::max()) && (transform->classes != std::numeric_limits<int>::max()))
29 - foreach (int label, counts.keys()) 30 + foreach (const QString & label, counts.keys())
30 if (counts[label] < instances) 31 if (counts[label] < instances)
31 counts.remove(label); 32 counts.remove(label);
  33 +
32 uniqueLabels = counts.keys(); 34 uniqueLabels = counts.keys();
33 if ((transform->classes != std::numeric_limits<int>::max()) && (uniqueLabels.size() < transform->classes)) 35 if ((transform->classes != std::numeric_limits<int>::max()) && (uniqueLabels.size() < transform->classes))
34 qWarning("Downsample requested %d classes but only %d are available.", transform->classes, uniqueLabels.size()); 36 qWarning("Downsample requested %d classes but only %d are available.", transform->classes, uniqueLabels.size());
35 37
36 Common::seedRNG(); 38 Common::seedRNG();
37 - QList<int> selectedLabels = uniqueLabels; 39 + QList<QString> selectedLabels = uniqueLabels;
38 if (transform->classes < uniqueLabels.size()) { 40 if (transform->classes < uniqueLabels.size()) {
39 std::random_shuffle(selectedLabels.begin(), selectedLabels.end()); 41 std::random_shuffle(selectedLabels.begin(), selectedLabels.end());
40 selectedLabels = selectedLabels.mid(0, transform->classes); 42 selectedLabels = selectedLabels.mid(0, transform->classes);
@@ -42,7 +44,7 @@ static TemplateList Downsample(const TemplateList &amp;templates, const Transform *t @@ -42,7 +44,7 @@ static TemplateList Downsample(const TemplateList &amp;templates, const Transform *t
42 44
43 TemplateList downsample; 45 TemplateList downsample;
44 for (int i=0; i<selectedLabels.size(); i++) { 46 for (int i=0; i<selectedLabels.size(); i++) {
45 - const int selectedLabel = selectedLabels[i]; 47 + const QString selectedLabel = selectedLabels[i];
46 QList<int> indices; 48 QList<int> indices;
47 for (int j=0; j<allLabels.size(); j++) 49 for (int j=0; j<allLabels.size(); j++)
48 if ((allLabels[j] == selectedLabel) && (!templates.value(j).file.get<bool>("FTE", false))) 50 if ((allLabels[j] == selectedLabel) && (!templates.value(j).file.get<bool>("FTE", false)))
openbr/plugins/meta.cpp
@@ -453,7 +453,8 @@ private: @@ -453,7 +453,8 @@ private:
453 const QString &file = src.file; 453 const QString &file = src.file;
454 if (cache.contains(file)) { 454 if (cache.contains(file)) {
455 dst = cache[file]; 455 dst = cache[file];
456 - dst.file.set("Label", src.file.value("Label")); 456 + // don't get this -cao
  457 +// dst.file.set("Label", src.file.value("Label"));
457 } else { 458 } else {
458 transform->project(src, dst); 459 transform->project(src, dst);
459 cacheLock.lock(); 460 cacheLock.lock();
openbr/plugins/normalize.cpp
@@ -126,7 +126,8 @@ private: @@ -126,7 +126,8 @@ private:
126 { 126 {
127 Mat m; 127 Mat m;
128 OpenCVUtils::toMat(data.data()).convertTo(m, CV_64F); 128 OpenCVUtils::toMat(data.data()).convertTo(m, CV_64F);
129 - const QList<int> labels = data.collectValues<int>("Label"); 129 +
  130 + const QList<int> labels = data.indexProperty("Subject");
130 const int dims = m.cols; 131 const int dims = m.cols;
131 132
132 vector<Mat> mv, av, bv; 133 vector<Mat> mv, av, bv;
openbr/plugins/output.cpp
@@ -153,8 +153,10 @@ class meltOutput : public MatrixOutput @@ -153,8 +153,10 @@ class meltOutput : public MatrixOutput
153 153
154 QStringList lines; 154 QStringList lines;
155 if (file.baseName() != "terminal") lines.append(QString("Query,Target,Mask,Similarity%1").arg(keys)); 155 if (file.baseName() != "terminal") lines.append(QString("Query,Target,Mask,Similarity%1").arg(keys));
156 - QList<float> queryLabels = queryFiles.collectValues<float>("Label");  
157 - QList<float> targetLabels = targetFiles.collectValues<float>("Label"); 156 +
  157 + QList<QString> queryLabels = queryFiles.collectValues<QString>("Subject");
  158 + QList<QString> targetLabels = targetFiles.collectValues<QString>("Subject");
  159 +
158 for (int i=0; i<queryFiles.size(); i++) { 160 for (int i=0; i<queryFiles.size(); i++) {
159 for (int j=(selfSimilar ? i+1 : 0); j<targetFiles.size(); j++) { 161 for (int j=(selfSimilar ? i+1 : 0); j<targetFiles.size(); j++) {
160 const bool genuine = queryLabels[i] == targetLabels[j]; 162 const bool genuine = queryLabels[i] == targetLabels[j];
openbr/plugins/pixel.cpp
@@ -28,6 +28,7 @@ namespace br @@ -28,6 +28,7 @@ namespace br
28 */ 28 */
29 class PerPixelClassifierTransform : public MetaTransform 29 class PerPixelClassifierTransform : public MetaTransform
30 { 30 {
  31 + // problematic -cao
31 Q_OBJECT 32 Q_OBJECT
32 Q_PROPERTY(br::Transform* transform READ get_transform WRITE set_transform RESET reset_transform) 33 Q_PROPERTY(br::Transform* transform READ get_transform WRITE set_transform RESET reset_transform)
33 Q_PROPERTY(int pixels READ get_pixels WRITE set_pixels RESET reset_pixels STORED false) 34 Q_PROPERTY(int pixels READ get_pixels WRITE set_pixels RESET reset_pixels STORED false)
@@ -273,6 +274,7 @@ BR_REGISTER(Transform, ToBinaryVectorTransform) @@ -273,6 +274,7 @@ BR_REGISTER(Transform, ToBinaryVectorTransform)
273 * \author E. Taborsky \cite mmtaborsky 274 * \author E. Taborsky \cite mmtaborsky
274 */ 275 */
275 276
  277 +// What does this do? -cao
276 class ToMetadataTransform : public UntrainableMetaTransform 278 class ToMetadataTransform : public UntrainableMetaTransform
277 { 279 {
278 Q_OBJECT 280 Q_OBJECT
openbr/plugins/quality.cpp
@@ -26,10 +26,10 @@ class ImpostorUniquenessMeasureTransform : public Transform @@ -26,10 +26,10 @@ class ImpostorUniquenessMeasureTransform : public Transform
26 26
27 float calculateIUM(const Template &probe, const TemplateList &gallery) const 27 float calculateIUM(const Template &probe, const TemplateList &gallery) const
28 { 28 {
29 - const int probeLabel = probe.file.get<float>("Label"); 29 + const QString probeLabel = probe.file.get<QString>("Subject");
30 TemplateList subset = gallery; 30 TemplateList subset = gallery;
31 for (int j=subset.size()-1; j>=0; j--) 31 for (int j=subset.size()-1; j>=0; j--)
32 - if (subset[j].file.get<int>("Label") == probeLabel) 32 + if (subset[j].file.get<QString>("Subject") == probeLabel)
33 subset.removeAt(j); 33 subset.removeAt(j);
34 34
35 QList<float> scores = distance->compare(subset, probe); 35 QList<float> scores = distance->compare(subset, probe);
@@ -159,7 +159,7 @@ class MatchProbabilityDistance : public Distance @@ -159,7 +159,7 @@ class MatchProbabilityDistance : public Distance
159 { 159 {
160 distance->train(src); 160 distance->train(src);
161 161
162 - const QList<int> labels = src.collectValues<int>("Label"); 162 + const QList<int> labels = src.indexProperty("Subject");
163 QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(src.size()), FileList(src.size()))); 163 QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(src.size()), FileList(src.size())));
164 distance->compare(src, src, matrixOutput.data()); 164 distance->compare(src, src, matrixOutput.data());
165 165
@@ -219,7 +219,7 @@ class UnitDistance : public Distance @@ -219,7 +219,7 @@ class UnitDistance : public Distance
219 void train(const TemplateList &templates) 219 void train(const TemplateList &templates)
220 { 220 {
221 const TemplateList samples = templates.mid(0, 2000); 221 const TemplateList samples = templates.mid(0, 2000);
222 - const QList<float> sampleLabels = samples.collectValues<float>("Label"); 222 + const QList<int> sampleLabels = samples.indexProperty("Subject");
223 QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(samples.size()), FileList(samples.size()))); 223 QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(samples.size()), FileList(samples.size())));
224 Distance::compare(samples, samples, matrixOutput.data()); 224 Distance::compare(samples, samples, matrixOutput.data());
225 225
openbr/plugins/quantize.cpp
@@ -150,7 +150,7 @@ class BayesianQuantizationDistance : public Distance @@ -150,7 +150,7 @@ class BayesianQuantizationDistance : public Distance
150 qFatal("Expected sigle matrix templates of type CV_8UC1!"); 150 qFatal("Expected sigle matrix templates of type CV_8UC1!");
151 151
152 const Mat data = OpenCVUtils::toMat(src.data()); 152 const Mat data = OpenCVUtils::toMat(src.data());
153 - const QList<int> templateLabels = src.collectValues<int>("Label"); 153 + const QList<int> templateLabels = src.indexProperty("Subject");
154 loglikelihoods = QVector<float>(data.cols*256, 0); 154 loglikelihoods = QVector<float>(data.cols*256, 0);
155 155
156 QFutureSynchronizer<void> futures; 156 QFutureSynchronizer<void> futures;
@@ -473,7 +473,8 @@ private: @@ -473,7 +473,8 @@ private:
473 { 473 {
474 Mat data = OpenCVUtils::toMat(src.data()); 474 Mat data = OpenCVUtils::toMat(src.data());
475 const int step = getStep(data.cols); 475 const int step = getStep(data.cols);
476 - const QList<int> labels = src.collectValues<int>("Label"); 476 +
  477 + const QList<int> labels = src.indexProperty("Subject");
477 478
478 Mat &lut = ProductQuantizationLUTs[index]; 479 Mat &lut = ProductQuantizationLUTs[index];
479 lut = Mat(getDims(data.cols), 256*(256+1)/2, CV_32FC1); 480 lut = Mat(getDims(data.cols), 256*(256+1)/2, CV_32FC1);
openbr/plugins/quantize2.cpp
@@ -77,7 +77,7 @@ class BayesianQuantizationTransform : public Transform @@ -77,7 +77,7 @@ class BayesianQuantizationTransform : public Transform
77 void train(const TemplateList &src) 77 void train(const TemplateList &src)
78 { 78 {
79 const Mat data = OpenCVUtils::toMat(src.data()); 79 const Mat data = OpenCVUtils::toMat(src.data());
80 - const QList<int> labels = src.collectValues<int>("Label"); 80 + const QList<int> labels = src.indexProperty("Subject");
81 81
82 thresholds = QVector<float>(256*data.cols); 82 thresholds = QVector<float>(256*data.cols);
83 83
openbr/plugins/svm.cpp
@@ -121,28 +121,46 @@ private: @@ -121,28 +121,46 @@ private:
121 BR_PROPERTY(float, gamma, -1) 121 BR_PROPERTY(float, gamma, -1)
122 122
123 SVM svm; 123 SVM svm;
  124 + QHash<QString, int> labelMap;
  125 + QHash<int, QVariant> reverseLookup;
124 126
125 void train(const TemplateList &_data) 127 void train(const TemplateList &_data)
126 { 128 {
127 Mat data = OpenCVUtils::toMat(_data.data()); 129 Mat data = OpenCVUtils::toMat(_data.data());
128 - Mat lab = OpenCVUtils::toMat(_data.collectValues<float>("Label")); 130 + Mat lab;
  131 + // If we are doing regression, assume subject has float values
  132 + if (type == EPS_SVR || type == NU_SVR) {
  133 + lab = OpenCVUtils::toMat(_data.collectValues<float>("Subject"));
  134 + }
  135 + // If we are doing classification, assume subject has discrete values, map them
  136 + // and store the mapping data
  137 + else {
  138 + QList<int> dataLabels = _data.indexProperty("Subject", labelMap, reverseLookup);
  139 + lab = OpenCVUtils::toMat(dataLabels);
  140 + }
129 trainSVM(svm, data, lab, kernel, type, C, gamma); 141 trainSVM(svm, data, lab, kernel, type, C, gamma);
130 } 142 }
131 143
132 void project(const Template &src, Template &dst) const 144 void project(const Template &src, Template &dst) const
133 { 145 {
134 dst = src; 146 dst = src;
135 - dst.file.set("Label", svm.predict(src.m().reshape(1, 1))); 147 + float prediction = svm.predict(src.m().reshape(1, 1));
  148 + if (type == EPS_SVR || type == NU_SVR)
  149 + dst.file.set("Subject", prediction);
  150 + else
  151 + dst.file.set("Subject", reverseLookup[prediction]);
136 } 152 }
137 153
138 void store(QDataStream &stream) const 154 void store(QDataStream &stream) const
139 { 155 {
140 storeSVM(svm, stream); 156 storeSVM(svm, stream);
  157 + stream << labelMap << reverseLookup;
141 } 158 }
142 159
143 void load(QDataStream &stream) 160 void load(QDataStream &stream)
144 { 161 {
145 loadSVM(svm, stream); 162 loadSVM(svm, stream);
  163 + stream >> labelMap >> reverseLookup;
146 } 164 }
147 }; 165 };
148 166