Commit 72a5968a012360dda117defa4d8e0521806a31b5

Authored by Charles Otto
1 parent 16039af1

Further changes to label/subject

Remove global label/subject lookup table
Consistently use "Subject" rather than "Label", subject is assumed to be
convertable to QString. When desirable, map discrete subject values to ints.

For classifiers such as svm that require numeric labels, generate a string->int
mapping for the training data, and store it (local to the transform).

Utility functions for collecting all values of a given property (on a template
list), and mapping discrete property values to 0-based integers

Some outstanding issues include use of label/subject in matrix output
app/examples/age_estimation.cpp
... ... @@ -29,7 +29,7 @@
29 29  
30 30 static void printTemplate(const br::Template &t)
31 31 {
32   - printf("%s age: %d\n", qPrintable(t.file.fileName()), int(t.file.get<float>("Label")));
  32 + printf("%s age: %d\n", qPrintable(t.file.fileName()), int(t.file.get<float>("Subject")));
33 33 }
34 34  
35 35 int main(int argc, char *argv[])
... ...
openbr/core/bee.cpp
... ... @@ -260,26 +260,28 @@ void BEE::makeMask(const QString &amp;targetInput, const QString &amp;queryInput, const
260 260  
261 261 cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries, int partition)
262 262 {
263   - QList<float> targetLabels = targets.collectValues<float>("Label");
264   - QList<float> queryLabels = queries.collectValues<float>("Label");
  263 + // Would like to use indexProperty for this, but didn't make a version of that for Filelist yet
  264 + // -cao
  265 + QList<QString> targetLabels = targets.collectValues<QString>("Subject", "-1");
  266 + QList<QString> queryLabels = queries.collectValues<QString>("Subject", "-1");
265 267 QList<int> targetPartitions = targets.crossValidationPartitions();
266 268 QList<int> queryPartitions = queries.crossValidationPartitions();
267 269  
268 270 Mat mask(queries.size(), targets.size(), CV_8UC1);
269 271 for (int i=0; i<queries.size(); i++) {
270 272 const QString &fileA = queries[i];
271   - const int labelA = queryLabels[i];
  273 + const QString labelA = queryLabels[i];
272 274 const int partitionA = queryPartitions[i];
273 275  
274 276 for (int j=0; j<targets.size(); j++) {
275 277 const QString &fileB = targets[j];
276   - const int labelB = targetLabels[j];
  278 + const QString labelB = targetLabels[j];
277 279 const int partitionB = targetPartitions[j];
278 280  
279 281 Mask_t val;
280 282 if (fileA == fileB) val = DontCare;
281   - else if (labelA == -1) val = DontCare;
282   - else if (labelB == -1) val = DontCare;
  283 + else if (labelA == "-1") val = DontCare;
  284 + else if (labelB == "-1") val = DontCare;
283 285 else if (partitionA != partition) val = DontCare;
284 286 else if (partitionB == -1) val = NonMatch;
285 287 else if (partitionB != partition) val = DontCare;
... ...
openbr/core/classify.cpp
... ... @@ -104,9 +104,9 @@ void br::EvalRegression(const QString &amp;predictedInput, const QString &amp;truthInput
104 104 for (int i=0; i<predicted.size(); i++) {
105 105 if (predicted[i].file.name != truth[i].file.name)
106 106 qFatal("Input order mismatch.");
107   - rmsError += pow(predicted[i].file.get<float>("Label")-truth[i].file.get<float>("Label"), 2.f);
108   - truthValues.append(QString::number(truth[i].file.get<float>("Label")));
109   - predictedValues.append(QString::number(predicted[i].file.get<float>("Label")));
  107 + rmsError += pow(predicted[i].file.get<float>("Subject")-truth[i].file.get<float>("Subject"), 2.f);
  108 + truthValues.append(QString::number(truth[i].file.get<float>("Subject")));
  109 + predictedValues.append(QString::number(predicted[i].file.get<float>("Subject")));
110 110 }
111 111  
112 112 QStringList rSource;
... ...
openbr/core/cluster.cpp
... ... @@ -278,7 +278,9 @@ void br::EvalClustering(const QString &amp;csv, const QString &amp;input)
278 278 {
279 279 qDebug("Evaluating %s against %s", qPrintable(csv), qPrintable(input));
280 280  
281   - QList<float> labels = TemplateList::fromGallery(input).files().collectValues<float>("Label");
  281 + // We assume clustering algorithms store assigned cluster labels as integers (since the clusters are
  282 + // not named).
  283 + QList<int> labels = TemplateList::fromGallery(input).files().collectValues<int>("Subject");
282 284  
283 285 QHash<int, int> labelToIndex;
284 286 int nClusters = 0;
... ...
openbr/core/core.cpp
... ... @@ -78,7 +78,6 @@ struct AlgorithmCore
78 78 const bool hasComparer = !distance.isNull();
79 79 out << hasComparer;
80 80 if (hasComparer) distance->store(out);
81   - out << Globals->subjects;
82 81  
83 82 // Compress and save to file
84 83 QtUtils::writeFile(model, data, -1);
... ... @@ -98,7 +97,6 @@ struct AlgorithmCore
98 97 transform->load(in);
99 98 bool hasDistance; in >> hasDistance;
100 99 if (hasDistance) distance->load(in);
101   - in >> Globals->subjects;
102 100 }
103 101  
104 102 File getMemoryGallery(const File &file) const
... ...
openbr/core/opencvutils.cpp
... ... @@ -119,6 +119,17 @@ Mat OpenCVUtils::toMat(const QList&lt;float&gt; &amp;src, int rows)
119 119 return dst;
120 120 }
121 121  
  122 +Mat OpenCVUtils::toMat(const QList<int> &src, int rows)
  123 +{
  124 + if (rows == -1) rows = src.size();
  125 + int columns = src.isEmpty() ? 0 : src.size() / rows;
  126 + if (rows*columns != src.size()) qFatal("Invalid matrix size.");
  127 + Mat dst(rows, columns, CV_32FC1);
  128 + for (int i=0; i<src.size(); i++)
  129 + dst.at<float>(i/columns,i%columns) = src[i];
  130 + return dst;
  131 +}
  132 +
122 133 Mat OpenCVUtils::toMat(const QList<Mat> &src)
123 134 {
124 135 if (src.isEmpty()) return Mat();
... ...
openbr/core/opencvutils.h
... ... @@ -35,6 +35,8 @@ namespace OpenCVUtils
35 35  
36 36 // To image
37 37 cv::Mat toMat(const QList<float> &src, int rows = -1);
  38 + cv::Mat toMat(const QList<int> &src, int rows = -1);
  39 +
38 40 cv::Mat toMat(const QList<cv::Mat> &src); // Data organized one matrix per row
39 41 cv::Mat toMatByRow(const QList<cv::Mat> &src); // Data organized one row per row
40 42  
... ...
openbr/frvt2012.cpp
... ... @@ -132,7 +132,7 @@ int32_t SdkEstimator::estimate_age(const ONEFACE &amp;input_face, int32_t &amp;age)
132 132 TemplateList templates;
133 133 templates.append(templateFromONEFACE(input_face));
134 134 templates >> *frvt2012_age_transform.data();
135   - age = templates.first().file.get<float>("Label");
  135 + age = templates.first().file.get<float>("Subject");
136 136 return templates.first().file.failed() ? 4 : 0;
137 137 }
138 138  
... ... @@ -141,6 +141,7 @@ int32_t SdkEstimator::estimate_gender(const ONEFACE &amp;input_face, int8_t &amp;gender,
141 141 TemplateList templates;
142 142 templates.append(templateFromONEFACE(input_face));
143 143 templates >> *frvt2012_gender_transform.data();
144   - mf = gender = templates.first().file.get<float>("Label");
  144 + // TODO: lookup gender strings/expected int outputs -cao
  145 + mf = gender = templates.first().file.get<QString>("Subject") == "Male" ? 0 : 1;
145 146 return templates.first().file.failed() ? 4 : 0;
146 147 }
... ...
openbr/gui/classifier.cpp
... ... @@ -42,16 +42,15 @@ void Classifier::_classify(File file)
42 42 if (!f.contains("Label"))
43 43 continue;
44 44  
45   - // What's with the special casing -cao
46 45 if (algorithm == "GenderClassification") {
47 46 key = "Gender";
48 47 value = (f.get<QString>("Subject"));
49 48 } else if (algorithm == "AgeRegression") {
50 49 key = "Age";
51   - value = QString::number(int(f.get<float>("Label")+0.5)) + " Years";
  50 + value = QString::number(int(f.get<float>("Subject")+0.5)) + " Years";
52 51 } else {
53 52 key = algorithm;
54   - value = QString::number(f.get<float>("Label"));
  53 + value = f.get<QString>("Subject");
55 54 }
56 55 break;
57 56 }
... ...
openbr/openbr_plugin.cpp
... ... @@ -435,11 +435,13 @@ TemplateList TemplateList::fromGallery(const br::File &amp;gallery)
435 435 return templates;
436 436 }
437 437  
  438 +// indexes some property, assigns an integer id to each unique value of propName
  439 +// stores the index values in "Label" of the output template list -cao
438 440 TemplateList TemplateList::relabel(const TemplateList &tl, const QString & propName)
439 441 {
440   - const QList<int> originalLabels = tl.collectValues<int>(propName);
441   - QHash<int,int> labelTable;
442   - foreach (int label, originalLabels)
  442 + const QList<QString> originalLabels = tl.collectValues<QString>(propName);
  443 + QHash<QString,int> labelTable;
  444 + foreach (const QString & label, originalLabels)
443 445 if (!labelTable.contains(label))
444 446 labelTable.insert(label, labelTable.size());
445 447  
... ... @@ -449,6 +451,52 @@ TemplateList TemplateList::relabel(const TemplateList &amp;tl, const QString &amp; propN
449 451 return result;
450 452 }
451 453  
  454 +QList<int> TemplateList::indexProperty(const QString & propName, QHash<QString, int> * valueMap,QHash<int, QVariant> * reverseLookup) const
  455 +{
  456 + QHash<QString, int> dummyForwards;
  457 + QHash<int, QVariant> dummyBackwards;
  458 +
  459 + if (!valueMap) valueMap = &dummyForwards;
  460 + if (!reverseLookup) reverseLookup = &dummyBackwards;
  461 +
  462 + return indexProperty(propName, *valueMap, *reverseLookup);
  463 +}
  464 +
  465 +QList<int> TemplateList::indexProperty(const QString & propName, QHash<QString, int> & valueMap, QHash<int, QVariant> & reverseLookup) const
  466 +{
  467 + valueMap.clear();
  468 + reverseLookup.clear();
  469 +
  470 + const QList<QVariant> originalLabels = collectValues<QVariant>(propName);
  471 + foreach (const QVariant & label, originalLabels) {
  472 + QString labelString = label.toString();
  473 + if (!valueMap.contains(labelString)) {
  474 + reverseLookup.insert(valueMap.size(), label);
  475 + valueMap.insert(labelString, valueMap.size());
  476 + }
  477 + }
  478 +
  479 + QList<int> result;
  480 + for (int i=0; i<originalLabels.size(); i++)
  481 + result.append(valueMap[originalLabels[i].toString()]);
  482 +
  483 + return result;
  484 +}
  485 +
  486 +// uses -1 for missing values
  487 +QList<int> TemplateList::applyIndex(const QString & propName, const QHash<QString, int> & valueMap) const
  488 +{
  489 + const QList<QString> originalLabels = collectValues<QString>(propName);
  490 +
  491 + QList<int> result;
  492 + for (int i=0; i<originalLabels.size(); i++) {
  493 + if (!valueMap.contains(originalLabels[i])) result.append(-1);
  494 + else result.append(valueMap[originalLabels[i]]);
  495 + }
  496 +
  497 + return result;
  498 +}
  499 +
452 500 /* Object - public methods */
453 501 QStringList Object::parameters() const
454 502 {
... ... @@ -989,7 +1037,8 @@ QString MatrixOutput::toString(int row, int column) const
989 1037 {
990 1038 if (targetFiles[column] == "Subject") {
991 1039 const int label = data.at<float>(row,column);
992   - return Globals->subjects.key(label, QString::number(label));
  1040 + // problem -cao
  1041 + return QString::number(label);
993 1042 }
994 1043 return QString::number(data.at<float>(row,column));
995 1044 }
... ...
openbr/openbr_plugin.h
... ... @@ -306,6 +306,15 @@ struct BR_EXPORT FileList : public QList&lt;File&gt;
306 306 values.append(f.get<T>(propName));
307 307 return values;
308 308 }
  309 + template<typename T>
  310 + QList<T> collectValues(const QString & propName, T defaultValue) const
  311 + {
  312 + QList<T> values; values.reserve(size());
  313 + foreach (const File &f, *this)
  314 + values.append(f.contains(propName) ? f.get<T>(propName) : defaultValue);
  315 + return values;
  316 + }
  317 +
309 318 QList<int> crossValidationPartitions() const; /*!< \brief Returns the cross-validation partition (default=0) for each file in the list. */
310 319 int failures() const; /*!< \brief Returns the number of files with br::File::failed(). */
311 320 };
... ... @@ -393,6 +402,11 @@ struct TemplateList : public QList&lt;Template&gt;
393 402 /*!< \brief Ensure labels are in the range [0,numClasses-1]. */
394 403 BR_EXPORT static TemplateList relabel(const TemplateList & tl, const QString & propName);
395 404  
  405 + QList<int> indexProperty(const QString & propName, QHash<QString, int> * valueMap=NULL,QHash<int, QVariant> * reverseLookup = NULL) const;
  406 + QList<int> indexProperty(const QString & propName, QHash<QString, int> & valueMap, QHash<int, QVariant> & reverseLookup) const;
  407 + QList<int> applyIndex(const QString & propName, const QHash<QString, int> & valueMap) const;
  408 +
  409 +
396 410 /*!
397 411 * \brief Returns the total number of bytes in all the templates.
398 412 */
... ... @@ -658,7 +672,6 @@ public:
658 672 BR_PROPERTY(int, crossValidate, 0)
659 673  
660 674 QHash<QString,QString> abbreviations; /*!< \brief Used by br::Transform::make() to expand abbreviated algorithms into their complete definitions. */
661   - QHash<QString,int> subjects; /*!< \brief Used by classifiers to associate text class labels with unique integers IDs. */
662 675 QTime startTime; /*!< \brief Used to estimate timeRemaining(). */
663 676  
664 677 /*!
... ...
openbr/plugins/eigen3.cpp
... ... @@ -329,7 +329,8 @@ class LDATransform : public Transform
329 329  
330 330 void train(const TemplateList &_trainingSet)
331 331 {
332   - TemplateList trainingSet = TemplateList::relabel(_trainingSet, "Label");
  332 + // creates "Label" -cao
  333 + TemplateList trainingSet = TemplateList::relabel(_trainingSet, "Subject");
333 334  
334 335 int instances = trainingSet.size();
335 336  
... ... @@ -342,12 +343,13 @@ class LDATransform : public Transform
342 343  
343 344 TemplateList ldaTrainingSet;
344 345 static_cast<Transform*>(&pca)->project(trainingSet, ldaTrainingSet);
  346 + // Reindex label, is this still necessary? -cao
345 347 ldaTrainingSet = TemplateList::relabel(ldaTrainingSet, "Label");
346 348  
347 349 int dimsIn = ldaTrainingSet.first().m().rows * ldaTrainingSet.first().m().cols;
348 350  
349 351 // OpenBR ensures that class values range from 0 to numClasses-1.
350   - // Assumed label is stored as float or int? -cao
  352 + // Label exists because we created it earlier with relabel
351 353 QList<int> classes = trainingSet.collectValues<int>("Label");
352 354 QMap<int, int> classCounts = trainingSet.countValues<int>("Label");
353 355 const int numClasses = classCounts.size();
... ...
openbr/plugins/gallery.cpp
... ... @@ -65,7 +65,7 @@ class arffGallery : public Gallery
65 65 const int dimensions = t.m().rows * t.m().cols;
66 66 for (int i=0; i<dimensions; i++)
67 67 arffFile.write(qPrintable("@ATTRIBUTE v" + QString::number(i) + " REAL\n"));
68   - arffFile.write(qPrintable("@ATTRIBUTE class {" + QStringList(Globals->subjects.keys()).join(',') + "}\n"));
  68 + arffFile.write(qPrintable("@ATTRIBUTE class string\n"));
69 69  
70 70 arffFile.write("\n@DATA\n");
71 71 }
... ... @@ -518,9 +518,8 @@ class csvGallery : public Gallery
518 518 static QString getCSVElement(const QString &key, const QVariant &value, bool header)
519 519 {
520 520 if ((key == "Label") && !header) {
521   - QString stringLabel = Globals->subjects.key(value.value<int>());
522   - if (stringLabel.isEmpty()) return value.value<QString>();
523   - else return stringLabel;
  521 + // problem -cao
  522 + return value.value<QString>();
524 523 } else if (value.canConvert<QString>()) {
525 524 if (header) return key;
526 525 else return value.value<QString>();
... ...
openbr/plugins/independent.cpp
... ... @@ -20,21 +20,23 @@ static TemplateList Downsample(const TemplateList &amp;templates, const Transform *t
20 20 const bool atLeast = transform->instances < 0;
21 21 const int instances = abs(transform->instances);
22 22  
23   - QList<int> allLabels = templates.collectValues<int>("Label");
24   - QList<int> uniqueLabels = allLabels.toSet().toList();
  23 + QList<QString> allLabels = templates.collectValues<QString>("Subject");
  24 + QList<QString> uniqueLabels = allLabels.toSet().toList();
25 25 qSort(uniqueLabels);
26 26  
27   - QMap<int,int> counts = templates.countValues<int>("Label", instances != std::numeric_limits<int>::max());
  27 + QMap<QString,int> counts = templates.countValues<QString>("Subject", instances != std::numeric_limits<int>::max());
  28 +
28 29 if ((instances != std::numeric_limits<int>::max()) && (transform->classes != std::numeric_limits<int>::max()))
29   - foreach (int label, counts.keys())
  30 + foreach (const QString & label, counts.keys())
30 31 if (counts[label] < instances)
31 32 counts.remove(label);
  33 +
32 34 uniqueLabels = counts.keys();
33 35 if ((transform->classes != std::numeric_limits<int>::max()) && (uniqueLabels.size() < transform->classes))
34 36 qWarning("Downsample requested %d classes but only %d are available.", transform->classes, uniqueLabels.size());
35 37  
36 38 Common::seedRNG();
37   - QList<int> selectedLabels = uniqueLabels;
  39 + QList<QString> selectedLabels = uniqueLabels;
38 40 if (transform->classes < uniqueLabels.size()) {
39 41 std::random_shuffle(selectedLabels.begin(), selectedLabels.end());
40 42 selectedLabels = selectedLabels.mid(0, transform->classes);
... ... @@ -42,7 +44,7 @@ static TemplateList Downsample(const TemplateList &amp;templates, const Transform *t
42 44  
43 45 TemplateList downsample;
44 46 for (int i=0; i<selectedLabels.size(); i++) {
45   - const int selectedLabel = selectedLabels[i];
  47 + const QString selectedLabel = selectedLabels[i];
46 48 QList<int> indices;
47 49 for (int j=0; j<allLabels.size(); j++)
48 50 if ((allLabels[j] == selectedLabel) && (!templates.value(j).file.get<bool>("FTE", false)))
... ...
openbr/plugins/meta.cpp
... ... @@ -453,7 +453,8 @@ private:
453 453 const QString &file = src.file;
454 454 if (cache.contains(file)) {
455 455 dst = cache[file];
456   - dst.file.set("Label", src.file.value("Label"));
  456 + // don't get this -cao
  457 +// dst.file.set("Label", src.file.value("Label"));
457 458 } else {
458 459 transform->project(src, dst);
459 460 cacheLock.lock();
... ...
openbr/plugins/normalize.cpp
... ... @@ -126,7 +126,8 @@ private:
126 126 {
127 127 Mat m;
128 128 OpenCVUtils::toMat(data.data()).convertTo(m, CV_64F);
129   - const QList<int> labels = data.collectValues<int>("Label");
  129 +
  130 + const QList<int> labels = data.indexProperty("Subject");
130 131 const int dims = m.cols;
131 132  
132 133 vector<Mat> mv, av, bv;
... ...
openbr/plugins/output.cpp
... ... @@ -153,8 +153,10 @@ class meltOutput : public MatrixOutput
153 153  
154 154 QStringList lines;
155 155 if (file.baseName() != "terminal") lines.append(QString("Query,Target,Mask,Similarity%1").arg(keys));
156   - QList<float> queryLabels = queryFiles.collectValues<float>("Label");
157   - QList<float> targetLabels = targetFiles.collectValues<float>("Label");
  156 +
  157 + QList<QString> queryLabels = queryFiles.collectValues<QString>("Subject");
  158 + QList<QString> targetLabels = targetFiles.collectValues<QString>("Subject");
  159 +
158 160 for (int i=0; i<queryFiles.size(); i++) {
159 161 for (int j=(selfSimilar ? i+1 : 0); j<targetFiles.size(); j++) {
160 162 const bool genuine = queryLabels[i] == targetLabels[j];
... ...
openbr/plugins/pixel.cpp
... ... @@ -28,6 +28,7 @@ namespace br
28 28 */
29 29 class PerPixelClassifierTransform : public MetaTransform
30 30 {
  31 + // problematic -cao
31 32 Q_OBJECT
32 33 Q_PROPERTY(br::Transform* transform READ get_transform WRITE set_transform RESET reset_transform)
33 34 Q_PROPERTY(int pixels READ get_pixels WRITE set_pixels RESET reset_pixels STORED false)
... ... @@ -273,6 +274,7 @@ BR_REGISTER(Transform, ToBinaryVectorTransform)
273 274 * \author E. Taborsky \cite mmtaborsky
274 275 */
275 276  
  277 +// What does this do? -cao
276 278 class ToMetadataTransform : public UntrainableMetaTransform
277 279 {
278 280 Q_OBJECT
... ...
openbr/plugins/quality.cpp
... ... @@ -26,10 +26,10 @@ class ImpostorUniquenessMeasureTransform : public Transform
26 26  
27 27 float calculateIUM(const Template &probe, const TemplateList &gallery) const
28 28 {
29   - const int probeLabel = probe.file.get<float>("Label");
  29 + const QString probeLabel = probe.file.get<QString>("Subject");
30 30 TemplateList subset = gallery;
31 31 for (int j=subset.size()-1; j>=0; j--)
32   - if (subset[j].file.get<int>("Label") == probeLabel)
  32 + if (subset[j].file.get<QString>("Subject") == probeLabel)
33 33 subset.removeAt(j);
34 34  
35 35 QList<float> scores = distance->compare(subset, probe);
... ... @@ -159,7 +159,7 @@ class MatchProbabilityDistance : public Distance
159 159 {
160 160 distance->train(src);
161 161  
162   - const QList<int> labels = src.collectValues<int>("Label");
  162 + const QList<int> labels = src.indexProperty("Subject");
163 163 QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(src.size()), FileList(src.size())));
164 164 distance->compare(src, src, matrixOutput.data());
165 165  
... ... @@ -219,7 +219,7 @@ class UnitDistance : public Distance
219 219 void train(const TemplateList &templates)
220 220 {
221 221 const TemplateList samples = templates.mid(0, 2000);
222   - const QList<float> sampleLabels = samples.collectValues<float>("Label");
  222 + const QList<int> sampleLabels = samples.indexProperty("Subject");
223 223 QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(samples.size()), FileList(samples.size())));
224 224 Distance::compare(samples, samples, matrixOutput.data());
225 225  
... ...
openbr/plugins/quantize.cpp
... ... @@ -150,7 +150,7 @@ class BayesianQuantizationDistance : public Distance
150 150 qFatal("Expected sigle matrix templates of type CV_8UC1!");
151 151  
152 152 const Mat data = OpenCVUtils::toMat(src.data());
153   - const QList<int> templateLabels = src.collectValues<int>("Label");
  153 + const QList<int> templateLabels = src.indexProperty("Subject");
154 154 loglikelihoods = QVector<float>(data.cols*256, 0);
155 155  
156 156 QFutureSynchronizer<void> futures;
... ... @@ -473,7 +473,8 @@ private:
473 473 {
474 474 Mat data = OpenCVUtils::toMat(src.data());
475 475 const int step = getStep(data.cols);
476   - const QList<int> labels = src.collectValues<int>("Label");
  476 +
  477 + const QList<int> labels = src.indexProperty("Subject");
477 478  
478 479 Mat &lut = ProductQuantizationLUTs[index];
479 480 lut = Mat(getDims(data.cols), 256*(256+1)/2, CV_32FC1);
... ...
openbr/plugins/quantize2.cpp
... ... @@ -77,7 +77,7 @@ class BayesianQuantizationTransform : public Transform
77 77 void train(const TemplateList &src)
78 78 {
79 79 const Mat data = OpenCVUtils::toMat(src.data());
80   - const QList<int> labels = src.collectValues<int>("Label");
  80 + const QList<int> labels = src.indexProperty("Subject");
81 81  
82 82 thresholds = QVector<float>(256*data.cols);
83 83  
... ...
openbr/plugins/svm.cpp
... ... @@ -121,28 +121,46 @@ private:
121 121 BR_PROPERTY(float, gamma, -1)
122 122  
123 123 SVM svm;
  124 + QHash<QString, int> labelMap;
  125 + QHash<int, QVariant> reverseLookup;
124 126  
125 127 void train(const TemplateList &_data)
126 128 {
127 129 Mat data = OpenCVUtils::toMat(_data.data());
128   - Mat lab = OpenCVUtils::toMat(_data.collectValues<float>("Label"));
  130 + Mat lab;
  131 + // If we are doing regression, assume subject has float values
  132 + if (type == EPS_SVR || type == NU_SVR) {
  133 + lab = OpenCVUtils::toMat(_data.collectValues<float>("Subject"));
  134 + }
  135 + // If we are doing classification, assume subject has discrete values, map them
  136 + // and store the mapping data
  137 + else {
  138 + QList<int> dataLabels = _data.indexProperty("Subject", labelMap, reverseLookup);
  139 + lab = OpenCVUtils::toMat(dataLabels);
  140 + }
129 141 trainSVM(svm, data, lab, kernel, type, C, gamma);
130 142 }
131 143  
132 144 void project(const Template &src, Template &dst) const
133 145 {
134 146 dst = src;
135   - dst.file.set("Label", svm.predict(src.m().reshape(1, 1)));
  147 + float prediction = svm.predict(src.m().reshape(1, 1));
  148 + if (type == EPS_SVR || type == NU_SVR)
  149 + dst.file.set("Subject", prediction);
  150 + else
  151 + dst.file.set("Subject", reverseLookup[prediction]);
136 152 }
137 153  
138 154 void store(QDataStream &stream) const
139 155 {
140 156 storeSVM(svm, stream);
  157 + stream << labelMap << reverseLookup;
141 158 }
142 159  
143 160 void load(QDataStream &stream)
144 161 {
145 162 loadSVM(svm, stream);
  163 + stream >> labelMap >> reverseLookup;
146 164 }
147 165 };
148 166  
... ...