Commit 72a5968a012360dda117defa4d8e0521806a31b5
1 parent
16039af1
Further changes to label/subject
Remove global label/subject lookup table Consistently use "Subject" rather than "Label", subject is assumed to be convertable to QString. When desirable, map discrete subject values to ints. For classifiers such as svm that require numeric labels, generate a string->int mapping for the training data, and store it (local to the transform). Utility functions for collecting all values of a given property (on a template list), and mapping discrete property values to 0-based integers Some outstanding issues include use of label/subject in matrix output
Showing
22 changed files
with
153 additions
and
48 deletions
app/examples/age_estimation.cpp
| ... | ... | @@ -29,7 +29,7 @@ |
| 29 | 29 | |
| 30 | 30 | static void printTemplate(const br::Template &t) |
| 31 | 31 | { |
| 32 | - printf("%s age: %d\n", qPrintable(t.file.fileName()), int(t.file.get<float>("Label"))); | |
| 32 | + printf("%s age: %d\n", qPrintable(t.file.fileName()), int(t.file.get<float>("Subject"))); | |
| 33 | 33 | } |
| 34 | 34 | |
| 35 | 35 | int main(int argc, char *argv[]) | ... | ... |
openbr/core/bee.cpp
| ... | ... | @@ -260,26 +260,28 @@ void BEE::makeMask(const QString &targetInput, const QString &queryInput, const |
| 260 | 260 | |
| 261 | 261 | cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries, int partition) |
| 262 | 262 | { |
| 263 | - QList<float> targetLabels = targets.collectValues<float>("Label"); | |
| 264 | - QList<float> queryLabels = queries.collectValues<float>("Label"); | |
| 263 | + // Would like to use indexProperty for this, but didn't make a version of that for Filelist yet | |
| 264 | + // -cao | |
| 265 | + QList<QString> targetLabels = targets.collectValues<QString>("Subject", "-1"); | |
| 266 | + QList<QString> queryLabels = queries.collectValues<QString>("Subject", "-1"); | |
| 265 | 267 | QList<int> targetPartitions = targets.crossValidationPartitions(); |
| 266 | 268 | QList<int> queryPartitions = queries.crossValidationPartitions(); |
| 267 | 269 | |
| 268 | 270 | Mat mask(queries.size(), targets.size(), CV_8UC1); |
| 269 | 271 | for (int i=0; i<queries.size(); i++) { |
| 270 | 272 | const QString &fileA = queries[i]; |
| 271 | - const int labelA = queryLabels[i]; | |
| 273 | + const QString labelA = queryLabels[i]; | |
| 272 | 274 | const int partitionA = queryPartitions[i]; |
| 273 | 275 | |
| 274 | 276 | for (int j=0; j<targets.size(); j++) { |
| 275 | 277 | const QString &fileB = targets[j]; |
| 276 | - const int labelB = targetLabels[j]; | |
| 278 | + const QString labelB = targetLabels[j]; | |
| 277 | 279 | const int partitionB = targetPartitions[j]; |
| 278 | 280 | |
| 279 | 281 | Mask_t val; |
| 280 | 282 | if (fileA == fileB) val = DontCare; |
| 281 | - else if (labelA == -1) val = DontCare; | |
| 282 | - else if (labelB == -1) val = DontCare; | |
| 283 | + else if (labelA == "-1") val = DontCare; | |
| 284 | + else if (labelB == "-1") val = DontCare; | |
| 283 | 285 | else if (partitionA != partition) val = DontCare; |
| 284 | 286 | else if (partitionB == -1) val = NonMatch; |
| 285 | 287 | else if (partitionB != partition) val = DontCare; | ... | ... |
openbr/core/classify.cpp
| ... | ... | @@ -104,9 +104,9 @@ void br::EvalRegression(const QString &predictedInput, const QString &truthInput |
| 104 | 104 | for (int i=0; i<predicted.size(); i++) { |
| 105 | 105 | if (predicted[i].file.name != truth[i].file.name) |
| 106 | 106 | qFatal("Input order mismatch."); |
| 107 | - rmsError += pow(predicted[i].file.get<float>("Label")-truth[i].file.get<float>("Label"), 2.f); | |
| 108 | - truthValues.append(QString::number(truth[i].file.get<float>("Label"))); | |
| 109 | - predictedValues.append(QString::number(predicted[i].file.get<float>("Label"))); | |
| 107 | + rmsError += pow(predicted[i].file.get<float>("Subject")-truth[i].file.get<float>("Subject"), 2.f); | |
| 108 | + truthValues.append(QString::number(truth[i].file.get<float>("Subject"))); | |
| 109 | + predictedValues.append(QString::number(predicted[i].file.get<float>("Subject"))); | |
| 110 | 110 | } |
| 111 | 111 | |
| 112 | 112 | QStringList rSource; | ... | ... |
openbr/core/cluster.cpp
| ... | ... | @@ -278,7 +278,9 @@ void br::EvalClustering(const QString &csv, const QString &input) |
| 278 | 278 | { |
| 279 | 279 | qDebug("Evaluating %s against %s", qPrintable(csv), qPrintable(input)); |
| 280 | 280 | |
| 281 | - QList<float> labels = TemplateList::fromGallery(input).files().collectValues<float>("Label"); | |
| 281 | + // We assume clustering algorithms store assigned cluster labels as integers (since the clusters are | |
| 282 | + // not named). | |
| 283 | + QList<int> labels = TemplateList::fromGallery(input).files().collectValues<int>("Subject"); | |
| 282 | 284 | |
| 283 | 285 | QHash<int, int> labelToIndex; |
| 284 | 286 | int nClusters = 0; | ... | ... |
openbr/core/core.cpp
| ... | ... | @@ -78,7 +78,6 @@ struct AlgorithmCore |
| 78 | 78 | const bool hasComparer = !distance.isNull(); |
| 79 | 79 | out << hasComparer; |
| 80 | 80 | if (hasComparer) distance->store(out); |
| 81 | - out << Globals->subjects; | |
| 82 | 81 | |
| 83 | 82 | // Compress and save to file |
| 84 | 83 | QtUtils::writeFile(model, data, -1); |
| ... | ... | @@ -98,7 +97,6 @@ struct AlgorithmCore |
| 98 | 97 | transform->load(in); |
| 99 | 98 | bool hasDistance; in >> hasDistance; |
| 100 | 99 | if (hasDistance) distance->load(in); |
| 101 | - in >> Globals->subjects; | |
| 102 | 100 | } |
| 103 | 101 | |
| 104 | 102 | File getMemoryGallery(const File &file) const | ... | ... |
openbr/core/opencvutils.cpp
| ... | ... | @@ -119,6 +119,17 @@ Mat OpenCVUtils::toMat(const QList<float> &src, int rows) |
| 119 | 119 | return dst; |
| 120 | 120 | } |
| 121 | 121 | |
| 122 | +Mat OpenCVUtils::toMat(const QList<int> &src, int rows) | |
| 123 | +{ | |
| 124 | + if (rows == -1) rows = src.size(); | |
| 125 | + int columns = src.isEmpty() ? 0 : src.size() / rows; | |
| 126 | + if (rows*columns != src.size()) qFatal("Invalid matrix size."); | |
| 127 | + Mat dst(rows, columns, CV_32FC1); | |
| 128 | + for (int i=0; i<src.size(); i++) | |
| 129 | + dst.at<float>(i/columns,i%columns) = src[i]; | |
| 130 | + return dst; | |
| 131 | +} | |
| 132 | + | |
| 122 | 133 | Mat OpenCVUtils::toMat(const QList<Mat> &src) |
| 123 | 134 | { |
| 124 | 135 | if (src.isEmpty()) return Mat(); | ... | ... |
openbr/core/opencvutils.h
| ... | ... | @@ -35,6 +35,8 @@ namespace OpenCVUtils |
| 35 | 35 | |
| 36 | 36 | // To image |
| 37 | 37 | cv::Mat toMat(const QList<float> &src, int rows = -1); |
| 38 | + cv::Mat toMat(const QList<int> &src, int rows = -1); | |
| 39 | + | |
| 38 | 40 | cv::Mat toMat(const QList<cv::Mat> &src); // Data organized one matrix per row |
| 39 | 41 | cv::Mat toMatByRow(const QList<cv::Mat> &src); // Data organized one row per row |
| 40 | 42 | ... | ... |
openbr/frvt2012.cpp
| ... | ... | @@ -132,7 +132,7 @@ int32_t SdkEstimator::estimate_age(const ONEFACE &input_face, int32_t &age) |
| 132 | 132 | TemplateList templates; |
| 133 | 133 | templates.append(templateFromONEFACE(input_face)); |
| 134 | 134 | templates >> *frvt2012_age_transform.data(); |
| 135 | - age = templates.first().file.get<float>("Label"); | |
| 135 | + age = templates.first().file.get<float>("Subject"); | |
| 136 | 136 | return templates.first().file.failed() ? 4 : 0; |
| 137 | 137 | } |
| 138 | 138 | |
| ... | ... | @@ -141,6 +141,7 @@ int32_t SdkEstimator::estimate_gender(const ONEFACE &input_face, int8_t &gender, |
| 141 | 141 | TemplateList templates; |
| 142 | 142 | templates.append(templateFromONEFACE(input_face)); |
| 143 | 143 | templates >> *frvt2012_gender_transform.data(); |
| 144 | - mf = gender = templates.first().file.get<float>("Label"); | |
| 144 | + // TODO: lookup gender strings/expected int outputs -cao | |
| 145 | + mf = gender = templates.first().file.get<QString>("Subject") == "Male" ? 0 : 1; | |
| 145 | 146 | return templates.first().file.failed() ? 4 : 0; |
| 146 | 147 | } | ... | ... |
openbr/gui/classifier.cpp
| ... | ... | @@ -42,16 +42,15 @@ void Classifier::_classify(File file) |
| 42 | 42 | if (!f.contains("Label")) |
| 43 | 43 | continue; |
| 44 | 44 | |
| 45 | - // What's with the special casing -cao | |
| 46 | 45 | if (algorithm == "GenderClassification") { |
| 47 | 46 | key = "Gender"; |
| 48 | 47 | value = (f.get<QString>("Subject")); |
| 49 | 48 | } else if (algorithm == "AgeRegression") { |
| 50 | 49 | key = "Age"; |
| 51 | - value = QString::number(int(f.get<float>("Label")+0.5)) + " Years"; | |
| 50 | + value = QString::number(int(f.get<float>("Subject")+0.5)) + " Years"; | |
| 52 | 51 | } else { |
| 53 | 52 | key = algorithm; |
| 54 | - value = QString::number(f.get<float>("Label")); | |
| 53 | + value = f.get<QString>("Subject"); | |
| 55 | 54 | } |
| 56 | 55 | break; |
| 57 | 56 | } | ... | ... |
openbr/openbr_plugin.cpp
| ... | ... | @@ -435,11 +435,13 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) |
| 435 | 435 | return templates; |
| 436 | 436 | } |
| 437 | 437 | |
| 438 | +// indexes some property, assigns an integer id to each unique value of propName | |
| 439 | +// stores the index values in "Label" of the output template list -cao | |
| 438 | 440 | TemplateList TemplateList::relabel(const TemplateList &tl, const QString & propName) |
| 439 | 441 | { |
| 440 | - const QList<int> originalLabels = tl.collectValues<int>(propName); | |
| 441 | - QHash<int,int> labelTable; | |
| 442 | - foreach (int label, originalLabels) | |
| 442 | + const QList<QString> originalLabels = tl.collectValues<QString>(propName); | |
| 443 | + QHash<QString,int> labelTable; | |
| 444 | + foreach (const QString & label, originalLabels) | |
| 443 | 445 | if (!labelTable.contains(label)) |
| 444 | 446 | labelTable.insert(label, labelTable.size()); |
| 445 | 447 | |
| ... | ... | @@ -449,6 +451,52 @@ TemplateList TemplateList::relabel(const TemplateList &tl, const QString & propN |
| 449 | 451 | return result; |
| 450 | 452 | } |
| 451 | 453 | |
| 454 | +QList<int> TemplateList::indexProperty(const QString & propName, QHash<QString, int> * valueMap,QHash<int, QVariant> * reverseLookup) const | |
| 455 | +{ | |
| 456 | + QHash<QString, int> dummyForwards; | |
| 457 | + QHash<int, QVariant> dummyBackwards; | |
| 458 | + | |
| 459 | + if (!valueMap) valueMap = &dummyForwards; | |
| 460 | + if (!reverseLookup) reverseLookup = &dummyBackwards; | |
| 461 | + | |
| 462 | + return indexProperty(propName, *valueMap, *reverseLookup); | |
| 463 | +} | |
| 464 | + | |
| 465 | +QList<int> TemplateList::indexProperty(const QString & propName, QHash<QString, int> & valueMap, QHash<int, QVariant> & reverseLookup) const | |
| 466 | +{ | |
| 467 | + valueMap.clear(); | |
| 468 | + reverseLookup.clear(); | |
| 469 | + | |
| 470 | + const QList<QVariant> originalLabels = collectValues<QVariant>(propName); | |
| 471 | + foreach (const QVariant & label, originalLabels) { | |
| 472 | + QString labelString = label.toString(); | |
| 473 | + if (!valueMap.contains(labelString)) { | |
| 474 | + reverseLookup.insert(valueMap.size(), label); | |
| 475 | + valueMap.insert(labelString, valueMap.size()); | |
| 476 | + } | |
| 477 | + } | |
| 478 | + | |
| 479 | + QList<int> result; | |
| 480 | + for (int i=0; i<originalLabels.size(); i++) | |
| 481 | + result.append(valueMap[originalLabels[i].toString()]); | |
| 482 | + | |
| 483 | + return result; | |
| 484 | +} | |
| 485 | + | |
| 486 | +// uses -1 for missing values | |
| 487 | +QList<int> TemplateList::applyIndex(const QString & propName, const QHash<QString, int> & valueMap) const | |
| 488 | +{ | |
| 489 | + const QList<QString> originalLabels = collectValues<QString>(propName); | |
| 490 | + | |
| 491 | + QList<int> result; | |
| 492 | + for (int i=0; i<originalLabels.size(); i++) { | |
| 493 | + if (!valueMap.contains(originalLabels[i])) result.append(-1); | |
| 494 | + else result.append(valueMap[originalLabels[i]]); | |
| 495 | + } | |
| 496 | + | |
| 497 | + return result; | |
| 498 | +} | |
| 499 | + | |
| 452 | 500 | /* Object - public methods */ |
| 453 | 501 | QStringList Object::parameters() const |
| 454 | 502 | { |
| ... | ... | @@ -989,7 +1037,8 @@ QString MatrixOutput::toString(int row, int column) const |
| 989 | 1037 | { |
| 990 | 1038 | if (targetFiles[column] == "Subject") { |
| 991 | 1039 | const int label = data.at<float>(row,column); |
| 992 | - return Globals->subjects.key(label, QString::number(label)); | |
| 1040 | + // problem -cao | |
| 1041 | + return QString::number(label); | |
| 993 | 1042 | } |
| 994 | 1043 | return QString::number(data.at<float>(row,column)); |
| 995 | 1044 | } | ... | ... |
openbr/openbr_plugin.h
| ... | ... | @@ -306,6 +306,15 @@ struct BR_EXPORT FileList : public QList<File> |
| 306 | 306 | values.append(f.get<T>(propName)); |
| 307 | 307 | return values; |
| 308 | 308 | } |
| 309 | + template<typename T> | |
| 310 | + QList<T> collectValues(const QString & propName, T defaultValue) const | |
| 311 | + { | |
| 312 | + QList<T> values; values.reserve(size()); | |
| 313 | + foreach (const File &f, *this) | |
| 314 | + values.append(f.contains(propName) ? f.get<T>(propName) : defaultValue); | |
| 315 | + return values; | |
| 316 | + } | |
| 317 | + | |
| 309 | 318 | QList<int> crossValidationPartitions() const; /*!< \brief Returns the cross-validation partition (default=0) for each file in the list. */ |
| 310 | 319 | int failures() const; /*!< \brief Returns the number of files with br::File::failed(). */ |
| 311 | 320 | }; |
| ... | ... | @@ -393,6 +402,11 @@ struct TemplateList : public QList<Template> |
| 393 | 402 | /*!< \brief Ensure labels are in the range [0,numClasses-1]. */ |
| 394 | 403 | BR_EXPORT static TemplateList relabel(const TemplateList & tl, const QString & propName); |
| 395 | 404 | |
| 405 | + QList<int> indexProperty(const QString & propName, QHash<QString, int> * valueMap=NULL,QHash<int, QVariant> * reverseLookup = NULL) const; | |
| 406 | + QList<int> indexProperty(const QString & propName, QHash<QString, int> & valueMap, QHash<int, QVariant> & reverseLookup) const; | |
| 407 | + QList<int> applyIndex(const QString & propName, const QHash<QString, int> & valueMap) const; | |
| 408 | + | |
| 409 | + | |
| 396 | 410 | /*! |
| 397 | 411 | * \brief Returns the total number of bytes in all the templates. |
| 398 | 412 | */ |
| ... | ... | @@ -658,7 +672,6 @@ public: |
| 658 | 672 | BR_PROPERTY(int, crossValidate, 0) |
| 659 | 673 | |
| 660 | 674 | QHash<QString,QString> abbreviations; /*!< \brief Used by br::Transform::make() to expand abbreviated algorithms into their complete definitions. */ |
| 661 | - QHash<QString,int> subjects; /*!< \brief Used by classifiers to associate text class labels with unique integers IDs. */ | |
| 662 | 675 | QTime startTime; /*!< \brief Used to estimate timeRemaining(). */ |
| 663 | 676 | |
| 664 | 677 | /*! | ... | ... |
openbr/plugins/eigen3.cpp
| ... | ... | @@ -329,7 +329,8 @@ class LDATransform : public Transform |
| 329 | 329 | |
| 330 | 330 | void train(const TemplateList &_trainingSet) |
| 331 | 331 | { |
| 332 | - TemplateList trainingSet = TemplateList::relabel(_trainingSet, "Label"); | |
| 332 | + // creates "Label" -cao | |
| 333 | + TemplateList trainingSet = TemplateList::relabel(_trainingSet, "Subject"); | |
| 333 | 334 | |
| 334 | 335 | int instances = trainingSet.size(); |
| 335 | 336 | |
| ... | ... | @@ -342,12 +343,13 @@ class LDATransform : public Transform |
| 342 | 343 | |
| 343 | 344 | TemplateList ldaTrainingSet; |
| 344 | 345 | static_cast<Transform*>(&pca)->project(trainingSet, ldaTrainingSet); |
| 346 | + // Reindex label, is this still necessary? -cao | |
| 345 | 347 | ldaTrainingSet = TemplateList::relabel(ldaTrainingSet, "Label"); |
| 346 | 348 | |
| 347 | 349 | int dimsIn = ldaTrainingSet.first().m().rows * ldaTrainingSet.first().m().cols; |
| 348 | 350 | |
| 349 | 351 | // OpenBR ensures that class values range from 0 to numClasses-1. |
| 350 | - // Assumed label is stored as float or int? -cao | |
| 352 | + // Label exists because we created it earlier with relabel | |
| 351 | 353 | QList<int> classes = trainingSet.collectValues<int>("Label"); |
| 352 | 354 | QMap<int, int> classCounts = trainingSet.countValues<int>("Label"); |
| 353 | 355 | const int numClasses = classCounts.size(); | ... | ... |
openbr/plugins/gallery.cpp
| ... | ... | @@ -65,7 +65,7 @@ class arffGallery : public Gallery |
| 65 | 65 | const int dimensions = t.m().rows * t.m().cols; |
| 66 | 66 | for (int i=0; i<dimensions; i++) |
| 67 | 67 | arffFile.write(qPrintable("@ATTRIBUTE v" + QString::number(i) + " REAL\n")); |
| 68 | - arffFile.write(qPrintable("@ATTRIBUTE class {" + QStringList(Globals->subjects.keys()).join(',') + "}\n")); | |
| 68 | + arffFile.write(qPrintable("@ATTRIBUTE class string\n")); | |
| 69 | 69 | |
| 70 | 70 | arffFile.write("\n@DATA\n"); |
| 71 | 71 | } |
| ... | ... | @@ -518,9 +518,8 @@ class csvGallery : public Gallery |
| 518 | 518 | static QString getCSVElement(const QString &key, const QVariant &value, bool header) |
| 519 | 519 | { |
| 520 | 520 | if ((key == "Label") && !header) { |
| 521 | - QString stringLabel = Globals->subjects.key(value.value<int>()); | |
| 522 | - if (stringLabel.isEmpty()) return value.value<QString>(); | |
| 523 | - else return stringLabel; | |
| 521 | + // problem -cao | |
| 522 | + return value.value<QString>(); | |
| 524 | 523 | } else if (value.canConvert<QString>()) { |
| 525 | 524 | if (header) return key; |
| 526 | 525 | else return value.value<QString>(); | ... | ... |
openbr/plugins/independent.cpp
| ... | ... | @@ -20,21 +20,23 @@ static TemplateList Downsample(const TemplateList &templates, const Transform *t |
| 20 | 20 | const bool atLeast = transform->instances < 0; |
| 21 | 21 | const int instances = abs(transform->instances); |
| 22 | 22 | |
| 23 | - QList<int> allLabels = templates.collectValues<int>("Label"); | |
| 24 | - QList<int> uniqueLabels = allLabels.toSet().toList(); | |
| 23 | + QList<QString> allLabels = templates.collectValues<QString>("Subject"); | |
| 24 | + QList<QString> uniqueLabels = allLabels.toSet().toList(); | |
| 25 | 25 | qSort(uniqueLabels); |
| 26 | 26 | |
| 27 | - QMap<int,int> counts = templates.countValues<int>("Label", instances != std::numeric_limits<int>::max()); | |
| 27 | + QMap<QString,int> counts = templates.countValues<QString>("Subject", instances != std::numeric_limits<int>::max()); | |
| 28 | + | |
| 28 | 29 | if ((instances != std::numeric_limits<int>::max()) && (transform->classes != std::numeric_limits<int>::max())) |
| 29 | - foreach (int label, counts.keys()) | |
| 30 | + foreach (const QString & label, counts.keys()) | |
| 30 | 31 | if (counts[label] < instances) |
| 31 | 32 | counts.remove(label); |
| 33 | + | |
| 32 | 34 | uniqueLabels = counts.keys(); |
| 33 | 35 | if ((transform->classes != std::numeric_limits<int>::max()) && (uniqueLabels.size() < transform->classes)) |
| 34 | 36 | qWarning("Downsample requested %d classes but only %d are available.", transform->classes, uniqueLabels.size()); |
| 35 | 37 | |
| 36 | 38 | Common::seedRNG(); |
| 37 | - QList<int> selectedLabels = uniqueLabels; | |
| 39 | + QList<QString> selectedLabels = uniqueLabels; | |
| 38 | 40 | if (transform->classes < uniqueLabels.size()) { |
| 39 | 41 | std::random_shuffle(selectedLabels.begin(), selectedLabels.end()); |
| 40 | 42 | selectedLabels = selectedLabels.mid(0, transform->classes); |
| ... | ... | @@ -42,7 +44,7 @@ static TemplateList Downsample(const TemplateList &templates, const Transform *t |
| 42 | 44 | |
| 43 | 45 | TemplateList downsample; |
| 44 | 46 | for (int i=0; i<selectedLabels.size(); i++) { |
| 45 | - const int selectedLabel = selectedLabels[i]; | |
| 47 | + const QString selectedLabel = selectedLabels[i]; | |
| 46 | 48 | QList<int> indices; |
| 47 | 49 | for (int j=0; j<allLabels.size(); j++) |
| 48 | 50 | if ((allLabels[j] == selectedLabel) && (!templates.value(j).file.get<bool>("FTE", false))) | ... | ... |
openbr/plugins/meta.cpp
| ... | ... | @@ -453,7 +453,8 @@ private: |
| 453 | 453 | const QString &file = src.file; |
| 454 | 454 | if (cache.contains(file)) { |
| 455 | 455 | dst = cache[file]; |
| 456 | - dst.file.set("Label", src.file.value("Label")); | |
| 456 | + // don't get this -cao | |
| 457 | +// dst.file.set("Label", src.file.value("Label")); | |
| 457 | 458 | } else { |
| 458 | 459 | transform->project(src, dst); |
| 459 | 460 | cacheLock.lock(); | ... | ... |
openbr/plugins/normalize.cpp
| ... | ... | @@ -126,7 +126,8 @@ private: |
| 126 | 126 | { |
| 127 | 127 | Mat m; |
| 128 | 128 | OpenCVUtils::toMat(data.data()).convertTo(m, CV_64F); |
| 129 | - const QList<int> labels = data.collectValues<int>("Label"); | |
| 129 | + | |
| 130 | + const QList<int> labels = data.indexProperty("Subject"); | |
| 130 | 131 | const int dims = m.cols; |
| 131 | 132 | |
| 132 | 133 | vector<Mat> mv, av, bv; | ... | ... |
openbr/plugins/output.cpp
| ... | ... | @@ -153,8 +153,10 @@ class meltOutput : public MatrixOutput |
| 153 | 153 | |
| 154 | 154 | QStringList lines; |
| 155 | 155 | if (file.baseName() != "terminal") lines.append(QString("Query,Target,Mask,Similarity%1").arg(keys)); |
| 156 | - QList<float> queryLabels = queryFiles.collectValues<float>("Label"); | |
| 157 | - QList<float> targetLabels = targetFiles.collectValues<float>("Label"); | |
| 156 | + | |
| 157 | + QList<QString> queryLabels = queryFiles.collectValues<QString>("Subject"); | |
| 158 | + QList<QString> targetLabels = targetFiles.collectValues<QString>("Subject"); | |
| 159 | + | |
| 158 | 160 | for (int i=0; i<queryFiles.size(); i++) { |
| 159 | 161 | for (int j=(selfSimilar ? i+1 : 0); j<targetFiles.size(); j++) { |
| 160 | 162 | const bool genuine = queryLabels[i] == targetLabels[j]; | ... | ... |
openbr/plugins/pixel.cpp
| ... | ... | @@ -28,6 +28,7 @@ namespace br |
| 28 | 28 | */ |
| 29 | 29 | class PerPixelClassifierTransform : public MetaTransform |
| 30 | 30 | { |
| 31 | + // problematic -cao | |
| 31 | 32 | Q_OBJECT |
| 32 | 33 | Q_PROPERTY(br::Transform* transform READ get_transform WRITE set_transform RESET reset_transform) |
| 33 | 34 | Q_PROPERTY(int pixels READ get_pixels WRITE set_pixels RESET reset_pixels STORED false) |
| ... | ... | @@ -273,6 +274,7 @@ BR_REGISTER(Transform, ToBinaryVectorTransform) |
| 273 | 274 | * \author E. Taborsky \cite mmtaborsky |
| 274 | 275 | */ |
| 275 | 276 | |
| 277 | +// What does this do? -cao | |
| 276 | 278 | class ToMetadataTransform : public UntrainableMetaTransform |
| 277 | 279 | { |
| 278 | 280 | Q_OBJECT | ... | ... |
openbr/plugins/quality.cpp
| ... | ... | @@ -26,10 +26,10 @@ class ImpostorUniquenessMeasureTransform : public Transform |
| 26 | 26 | |
| 27 | 27 | float calculateIUM(const Template &probe, const TemplateList &gallery) const |
| 28 | 28 | { |
| 29 | - const int probeLabel = probe.file.get<float>("Label"); | |
| 29 | + const QString probeLabel = probe.file.get<QString>("Subject"); | |
| 30 | 30 | TemplateList subset = gallery; |
| 31 | 31 | for (int j=subset.size()-1; j>=0; j--) |
| 32 | - if (subset[j].file.get<int>("Label") == probeLabel) | |
| 32 | + if (subset[j].file.get<QString>("Subject") == probeLabel) | |
| 33 | 33 | subset.removeAt(j); |
| 34 | 34 | |
| 35 | 35 | QList<float> scores = distance->compare(subset, probe); |
| ... | ... | @@ -159,7 +159,7 @@ class MatchProbabilityDistance : public Distance |
| 159 | 159 | { |
| 160 | 160 | distance->train(src); |
| 161 | 161 | |
| 162 | - const QList<int> labels = src.collectValues<int>("Label"); | |
| 162 | + const QList<int> labels = src.indexProperty("Subject"); | |
| 163 | 163 | QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(src.size()), FileList(src.size()))); |
| 164 | 164 | distance->compare(src, src, matrixOutput.data()); |
| 165 | 165 | |
| ... | ... | @@ -219,7 +219,7 @@ class UnitDistance : public Distance |
| 219 | 219 | void train(const TemplateList &templates) |
| 220 | 220 | { |
| 221 | 221 | const TemplateList samples = templates.mid(0, 2000); |
| 222 | - const QList<float> sampleLabels = samples.collectValues<float>("Label"); | |
| 222 | + const QList<int> sampleLabels = samples.indexProperty("Subject"); | |
| 223 | 223 | QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(samples.size()), FileList(samples.size()))); |
| 224 | 224 | Distance::compare(samples, samples, matrixOutput.data()); |
| 225 | 225 | ... | ... |
openbr/plugins/quantize.cpp
| ... | ... | @@ -150,7 +150,7 @@ class BayesianQuantizationDistance : public Distance |
| 150 | 150 | qFatal("Expected sigle matrix templates of type CV_8UC1!"); |
| 151 | 151 | |
| 152 | 152 | const Mat data = OpenCVUtils::toMat(src.data()); |
| 153 | - const QList<int> templateLabels = src.collectValues<int>("Label"); | |
| 153 | + const QList<int> templateLabels = src.indexProperty("Subject"); | |
| 154 | 154 | loglikelihoods = QVector<float>(data.cols*256, 0); |
| 155 | 155 | |
| 156 | 156 | QFutureSynchronizer<void> futures; |
| ... | ... | @@ -473,7 +473,8 @@ private: |
| 473 | 473 | { |
| 474 | 474 | Mat data = OpenCVUtils::toMat(src.data()); |
| 475 | 475 | const int step = getStep(data.cols); |
| 476 | - const QList<int> labels = src.collectValues<int>("Label"); | |
| 476 | + | |
| 477 | + const QList<int> labels = src.indexProperty("Subject"); | |
| 477 | 478 | |
| 478 | 479 | Mat &lut = ProductQuantizationLUTs[index]; |
| 479 | 480 | lut = Mat(getDims(data.cols), 256*(256+1)/2, CV_32FC1); | ... | ... |
openbr/plugins/quantize2.cpp
| ... | ... | @@ -77,7 +77,7 @@ class BayesianQuantizationTransform : public Transform |
| 77 | 77 | void train(const TemplateList &src) |
| 78 | 78 | { |
| 79 | 79 | const Mat data = OpenCVUtils::toMat(src.data()); |
| 80 | - const QList<int> labels = src.collectValues<int>("Label"); | |
| 80 | + const QList<int> labels = src.indexProperty("Subject"); | |
| 81 | 81 | |
| 82 | 82 | thresholds = QVector<float>(256*data.cols); |
| 83 | 83 | ... | ... |
openbr/plugins/svm.cpp
| ... | ... | @@ -121,28 +121,46 @@ private: |
| 121 | 121 | BR_PROPERTY(float, gamma, -1) |
| 122 | 122 | |
| 123 | 123 | SVM svm; |
| 124 | + QHash<QString, int> labelMap; | |
| 125 | + QHash<int, QVariant> reverseLookup; | |
| 124 | 126 | |
| 125 | 127 | void train(const TemplateList &_data) |
| 126 | 128 | { |
| 127 | 129 | Mat data = OpenCVUtils::toMat(_data.data()); |
| 128 | - Mat lab = OpenCVUtils::toMat(_data.collectValues<float>("Label")); | |
| 130 | + Mat lab; | |
| 131 | + // If we are doing regression, assume subject has float values | |
| 132 | + if (type == EPS_SVR || type == NU_SVR) { | |
| 133 | + lab = OpenCVUtils::toMat(_data.collectValues<float>("Subject")); | |
| 134 | + } | |
| 135 | + // If we are doing classification, assume subject has discrete values, map them | |
| 136 | + // and store the mapping data | |
| 137 | + else { | |
| 138 | + QList<int> dataLabels = _data.indexProperty("Subject", labelMap, reverseLookup); | |
| 139 | + lab = OpenCVUtils::toMat(dataLabels); | |
| 140 | + } | |
| 129 | 141 | trainSVM(svm, data, lab, kernel, type, C, gamma); |
| 130 | 142 | } |
| 131 | 143 | |
| 132 | 144 | void project(const Template &src, Template &dst) const |
| 133 | 145 | { |
| 134 | 146 | dst = src; |
| 135 | - dst.file.set("Label", svm.predict(src.m().reshape(1, 1))); | |
| 147 | + float prediction = svm.predict(src.m().reshape(1, 1)); | |
| 148 | + if (type == EPS_SVR || type == NU_SVR) | |
| 149 | + dst.file.set("Subject", prediction); | |
| 150 | + else | |
| 151 | + dst.file.set("Subject", reverseLookup[prediction]); | |
| 136 | 152 | } |
| 137 | 153 | |
| 138 | 154 | void store(QDataStream &stream) const |
| 139 | 155 | { |
| 140 | 156 | storeSVM(svm, stream); |
| 157 | + stream << labelMap << reverseLookup; | |
| 141 | 158 | } |
| 142 | 159 | |
| 143 | 160 | void load(QDataStream &stream) |
| 144 | 161 | { |
| 145 | 162 | loadSVM(svm, stream); |
| 163 | + stream >> labelMap >> reverseLookup; | |
| 146 | 164 | } |
| 147 | 165 | }; |
| 148 | 166 | ... | ... |