Commit 72a5968a012360dda117defa4d8e0521806a31b5
1 parent
16039af1
Further changes to label/subject
Remove global label/subject lookup table Consistently use "Subject" rather than "Label", subject is assumed to be convertable to QString. When desirable, map discrete subject values to ints. For classifiers such as svm that require numeric labels, generate a string->int mapping for the training data, and store it (local to the transform). Utility functions for collecting all values of a given property (on a template list), and mapping discrete property values to 0-based integers Some outstanding issues include use of label/subject in matrix output
Showing
22 changed files
with
153 additions
and
48 deletions
app/examples/age_estimation.cpp
| @@ -29,7 +29,7 @@ | @@ -29,7 +29,7 @@ | ||
| 29 | 29 | ||
| 30 | static void printTemplate(const br::Template &t) | 30 | static void printTemplate(const br::Template &t) |
| 31 | { | 31 | { |
| 32 | - printf("%s age: %d\n", qPrintable(t.file.fileName()), int(t.file.get<float>("Label"))); | 32 | + printf("%s age: %d\n", qPrintable(t.file.fileName()), int(t.file.get<float>("Subject"))); |
| 33 | } | 33 | } |
| 34 | 34 | ||
| 35 | int main(int argc, char *argv[]) | 35 | int main(int argc, char *argv[]) |
openbr/core/bee.cpp
| @@ -260,26 +260,28 @@ void BEE::makeMask(const QString &targetInput, const QString &queryInput, const | @@ -260,26 +260,28 @@ void BEE::makeMask(const QString &targetInput, const QString &queryInput, const | ||
| 260 | 260 | ||
| 261 | cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries, int partition) | 261 | cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries, int partition) |
| 262 | { | 262 | { |
| 263 | - QList<float> targetLabels = targets.collectValues<float>("Label"); | ||
| 264 | - QList<float> queryLabels = queries.collectValues<float>("Label"); | 263 | + // Would like to use indexProperty for this, but didn't make a version of that for Filelist yet |
| 264 | + // -cao | ||
| 265 | + QList<QString> targetLabels = targets.collectValues<QString>("Subject", "-1"); | ||
| 266 | + QList<QString> queryLabels = queries.collectValues<QString>("Subject", "-1"); | ||
| 265 | QList<int> targetPartitions = targets.crossValidationPartitions(); | 267 | QList<int> targetPartitions = targets.crossValidationPartitions(); |
| 266 | QList<int> queryPartitions = queries.crossValidationPartitions(); | 268 | QList<int> queryPartitions = queries.crossValidationPartitions(); |
| 267 | 269 | ||
| 268 | Mat mask(queries.size(), targets.size(), CV_8UC1); | 270 | Mat mask(queries.size(), targets.size(), CV_8UC1); |
| 269 | for (int i=0; i<queries.size(); i++) { | 271 | for (int i=0; i<queries.size(); i++) { |
| 270 | const QString &fileA = queries[i]; | 272 | const QString &fileA = queries[i]; |
| 271 | - const int labelA = queryLabels[i]; | 273 | + const QString labelA = queryLabels[i]; |
| 272 | const int partitionA = queryPartitions[i]; | 274 | const int partitionA = queryPartitions[i]; |
| 273 | 275 | ||
| 274 | for (int j=0; j<targets.size(); j++) { | 276 | for (int j=0; j<targets.size(); j++) { |
| 275 | const QString &fileB = targets[j]; | 277 | const QString &fileB = targets[j]; |
| 276 | - const int labelB = targetLabels[j]; | 278 | + const QString labelB = targetLabels[j]; |
| 277 | const int partitionB = targetPartitions[j]; | 279 | const int partitionB = targetPartitions[j]; |
| 278 | 280 | ||
| 279 | Mask_t val; | 281 | Mask_t val; |
| 280 | if (fileA == fileB) val = DontCare; | 282 | if (fileA == fileB) val = DontCare; |
| 281 | - else if (labelA == -1) val = DontCare; | ||
| 282 | - else if (labelB == -1) val = DontCare; | 283 | + else if (labelA == "-1") val = DontCare; |
| 284 | + else if (labelB == "-1") val = DontCare; | ||
| 283 | else if (partitionA != partition) val = DontCare; | 285 | else if (partitionA != partition) val = DontCare; |
| 284 | else if (partitionB == -1) val = NonMatch; | 286 | else if (partitionB == -1) val = NonMatch; |
| 285 | else if (partitionB != partition) val = DontCare; | 287 | else if (partitionB != partition) val = DontCare; |
openbr/core/classify.cpp
| @@ -104,9 +104,9 @@ void br::EvalRegression(const QString &predictedInput, const QString &truthInput | @@ -104,9 +104,9 @@ void br::EvalRegression(const QString &predictedInput, const QString &truthInput | ||
| 104 | for (int i=0; i<predicted.size(); i++) { | 104 | for (int i=0; i<predicted.size(); i++) { |
| 105 | if (predicted[i].file.name != truth[i].file.name) | 105 | if (predicted[i].file.name != truth[i].file.name) |
| 106 | qFatal("Input order mismatch."); | 106 | qFatal("Input order mismatch."); |
| 107 | - rmsError += pow(predicted[i].file.get<float>("Label")-truth[i].file.get<float>("Label"), 2.f); | ||
| 108 | - truthValues.append(QString::number(truth[i].file.get<float>("Label"))); | ||
| 109 | - predictedValues.append(QString::number(predicted[i].file.get<float>("Label"))); | 107 | + rmsError += pow(predicted[i].file.get<float>("Subject")-truth[i].file.get<float>("Subject"), 2.f); |
| 108 | + truthValues.append(QString::number(truth[i].file.get<float>("Subject"))); | ||
| 109 | + predictedValues.append(QString::number(predicted[i].file.get<float>("Subject"))); | ||
| 110 | } | 110 | } |
| 111 | 111 | ||
| 112 | QStringList rSource; | 112 | QStringList rSource; |
openbr/core/cluster.cpp
| @@ -278,7 +278,9 @@ void br::EvalClustering(const QString &csv, const QString &input) | @@ -278,7 +278,9 @@ void br::EvalClustering(const QString &csv, const QString &input) | ||
| 278 | { | 278 | { |
| 279 | qDebug("Evaluating %s against %s", qPrintable(csv), qPrintable(input)); | 279 | qDebug("Evaluating %s against %s", qPrintable(csv), qPrintable(input)); |
| 280 | 280 | ||
| 281 | - QList<float> labels = TemplateList::fromGallery(input).files().collectValues<float>("Label"); | 281 | + // We assume clustering algorithms store assigned cluster labels as integers (since the clusters are |
| 282 | + // not named). | ||
| 283 | + QList<int> labels = TemplateList::fromGallery(input).files().collectValues<int>("Subject"); | ||
| 282 | 284 | ||
| 283 | QHash<int, int> labelToIndex; | 285 | QHash<int, int> labelToIndex; |
| 284 | int nClusters = 0; | 286 | int nClusters = 0; |
openbr/core/core.cpp
| @@ -78,7 +78,6 @@ struct AlgorithmCore | @@ -78,7 +78,6 @@ struct AlgorithmCore | ||
| 78 | const bool hasComparer = !distance.isNull(); | 78 | const bool hasComparer = !distance.isNull(); |
| 79 | out << hasComparer; | 79 | out << hasComparer; |
| 80 | if (hasComparer) distance->store(out); | 80 | if (hasComparer) distance->store(out); |
| 81 | - out << Globals->subjects; | ||
| 82 | 81 | ||
| 83 | // Compress and save to file | 82 | // Compress and save to file |
| 84 | QtUtils::writeFile(model, data, -1); | 83 | QtUtils::writeFile(model, data, -1); |
| @@ -98,7 +97,6 @@ struct AlgorithmCore | @@ -98,7 +97,6 @@ struct AlgorithmCore | ||
| 98 | transform->load(in); | 97 | transform->load(in); |
| 99 | bool hasDistance; in >> hasDistance; | 98 | bool hasDistance; in >> hasDistance; |
| 100 | if (hasDistance) distance->load(in); | 99 | if (hasDistance) distance->load(in); |
| 101 | - in >> Globals->subjects; | ||
| 102 | } | 100 | } |
| 103 | 101 | ||
| 104 | File getMemoryGallery(const File &file) const | 102 | File getMemoryGallery(const File &file) const |
openbr/core/opencvutils.cpp
| @@ -119,6 +119,17 @@ Mat OpenCVUtils::toMat(const QList<float> &src, int rows) | @@ -119,6 +119,17 @@ Mat OpenCVUtils::toMat(const QList<float> &src, int rows) | ||
| 119 | return dst; | 119 | return dst; |
| 120 | } | 120 | } |
| 121 | 121 | ||
| 122 | +Mat OpenCVUtils::toMat(const QList<int> &src, int rows) | ||
| 123 | +{ | ||
| 124 | + if (rows == -1) rows = src.size(); | ||
| 125 | + int columns = src.isEmpty() ? 0 : src.size() / rows; | ||
| 126 | + if (rows*columns != src.size()) qFatal("Invalid matrix size."); | ||
| 127 | + Mat dst(rows, columns, CV_32FC1); | ||
| 128 | + for (int i=0; i<src.size(); i++) | ||
| 129 | + dst.at<float>(i/columns,i%columns) = src[i]; | ||
| 130 | + return dst; | ||
| 131 | +} | ||
| 132 | + | ||
| 122 | Mat OpenCVUtils::toMat(const QList<Mat> &src) | 133 | Mat OpenCVUtils::toMat(const QList<Mat> &src) |
| 123 | { | 134 | { |
| 124 | if (src.isEmpty()) return Mat(); | 135 | if (src.isEmpty()) return Mat(); |
openbr/core/opencvutils.h
| @@ -35,6 +35,8 @@ namespace OpenCVUtils | @@ -35,6 +35,8 @@ namespace OpenCVUtils | ||
| 35 | 35 | ||
| 36 | // To image | 36 | // To image |
| 37 | cv::Mat toMat(const QList<float> &src, int rows = -1); | 37 | cv::Mat toMat(const QList<float> &src, int rows = -1); |
| 38 | + cv::Mat toMat(const QList<int> &src, int rows = -1); | ||
| 39 | + | ||
| 38 | cv::Mat toMat(const QList<cv::Mat> &src); // Data organized one matrix per row | 40 | cv::Mat toMat(const QList<cv::Mat> &src); // Data organized one matrix per row |
| 39 | cv::Mat toMatByRow(const QList<cv::Mat> &src); // Data organized one row per row | 41 | cv::Mat toMatByRow(const QList<cv::Mat> &src); // Data organized one row per row |
| 40 | 42 |
openbr/frvt2012.cpp
| @@ -132,7 +132,7 @@ int32_t SdkEstimator::estimate_age(const ONEFACE &input_face, int32_t &age) | @@ -132,7 +132,7 @@ int32_t SdkEstimator::estimate_age(const ONEFACE &input_face, int32_t &age) | ||
| 132 | TemplateList templates; | 132 | TemplateList templates; |
| 133 | templates.append(templateFromONEFACE(input_face)); | 133 | templates.append(templateFromONEFACE(input_face)); |
| 134 | templates >> *frvt2012_age_transform.data(); | 134 | templates >> *frvt2012_age_transform.data(); |
| 135 | - age = templates.first().file.get<float>("Label"); | 135 | + age = templates.first().file.get<float>("Subject"); |
| 136 | return templates.first().file.failed() ? 4 : 0; | 136 | return templates.first().file.failed() ? 4 : 0; |
| 137 | } | 137 | } |
| 138 | 138 | ||
| @@ -141,6 +141,7 @@ int32_t SdkEstimator::estimate_gender(const ONEFACE &input_face, int8_t &gender, | @@ -141,6 +141,7 @@ int32_t SdkEstimator::estimate_gender(const ONEFACE &input_face, int8_t &gender, | ||
| 141 | TemplateList templates; | 141 | TemplateList templates; |
| 142 | templates.append(templateFromONEFACE(input_face)); | 142 | templates.append(templateFromONEFACE(input_face)); |
| 143 | templates >> *frvt2012_gender_transform.data(); | 143 | templates >> *frvt2012_gender_transform.data(); |
| 144 | - mf = gender = templates.first().file.get<float>("Label"); | 144 | + // TODO: lookup gender strings/expected int outputs -cao |
| 145 | + mf = gender = templates.first().file.get<QString>("Subject") == "Male" ? 0 : 1; | ||
| 145 | return templates.first().file.failed() ? 4 : 0; | 146 | return templates.first().file.failed() ? 4 : 0; |
| 146 | } | 147 | } |
openbr/gui/classifier.cpp
| @@ -42,16 +42,15 @@ void Classifier::_classify(File file) | @@ -42,16 +42,15 @@ void Classifier::_classify(File file) | ||
| 42 | if (!f.contains("Label")) | 42 | if (!f.contains("Label")) |
| 43 | continue; | 43 | continue; |
| 44 | 44 | ||
| 45 | - // What's with the special casing -cao | ||
| 46 | if (algorithm == "GenderClassification") { | 45 | if (algorithm == "GenderClassification") { |
| 47 | key = "Gender"; | 46 | key = "Gender"; |
| 48 | value = (f.get<QString>("Subject")); | 47 | value = (f.get<QString>("Subject")); |
| 49 | } else if (algorithm == "AgeRegression") { | 48 | } else if (algorithm == "AgeRegression") { |
| 50 | key = "Age"; | 49 | key = "Age"; |
| 51 | - value = QString::number(int(f.get<float>("Label")+0.5)) + " Years"; | 50 | + value = QString::number(int(f.get<float>("Subject")+0.5)) + " Years"; |
| 52 | } else { | 51 | } else { |
| 53 | key = algorithm; | 52 | key = algorithm; |
| 54 | - value = QString::number(f.get<float>("Label")); | 53 | + value = f.get<QString>("Subject"); |
| 55 | } | 54 | } |
| 56 | break; | 55 | break; |
| 57 | } | 56 | } |
openbr/openbr_plugin.cpp
| @@ -435,11 +435,13 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) | @@ -435,11 +435,13 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) | ||
| 435 | return templates; | 435 | return templates; |
| 436 | } | 436 | } |
| 437 | 437 | ||
| 438 | +// indexes some property, assigns an integer id to each unique value of propName | ||
| 439 | +// stores the index values in "Label" of the output template list -cao | ||
| 438 | TemplateList TemplateList::relabel(const TemplateList &tl, const QString & propName) | 440 | TemplateList TemplateList::relabel(const TemplateList &tl, const QString & propName) |
| 439 | { | 441 | { |
| 440 | - const QList<int> originalLabels = tl.collectValues<int>(propName); | ||
| 441 | - QHash<int,int> labelTable; | ||
| 442 | - foreach (int label, originalLabels) | 442 | + const QList<QString> originalLabels = tl.collectValues<QString>(propName); |
| 443 | + QHash<QString,int> labelTable; | ||
| 444 | + foreach (const QString & label, originalLabels) | ||
| 443 | if (!labelTable.contains(label)) | 445 | if (!labelTable.contains(label)) |
| 444 | labelTable.insert(label, labelTable.size()); | 446 | labelTable.insert(label, labelTable.size()); |
| 445 | 447 | ||
| @@ -449,6 +451,52 @@ TemplateList TemplateList::relabel(const TemplateList &tl, const QString & propN | @@ -449,6 +451,52 @@ TemplateList TemplateList::relabel(const TemplateList &tl, const QString & propN | ||
| 449 | return result; | 451 | return result; |
| 450 | } | 452 | } |
| 451 | 453 | ||
| 454 | +QList<int> TemplateList::indexProperty(const QString & propName, QHash<QString, int> * valueMap,QHash<int, QVariant> * reverseLookup) const | ||
| 455 | +{ | ||
| 456 | + QHash<QString, int> dummyForwards; | ||
| 457 | + QHash<int, QVariant> dummyBackwards; | ||
| 458 | + | ||
| 459 | + if (!valueMap) valueMap = &dummyForwards; | ||
| 460 | + if (!reverseLookup) reverseLookup = &dummyBackwards; | ||
| 461 | + | ||
| 462 | + return indexProperty(propName, *valueMap, *reverseLookup); | ||
| 463 | +} | ||
| 464 | + | ||
| 465 | +QList<int> TemplateList::indexProperty(const QString & propName, QHash<QString, int> & valueMap, QHash<int, QVariant> & reverseLookup) const | ||
| 466 | +{ | ||
| 467 | + valueMap.clear(); | ||
| 468 | + reverseLookup.clear(); | ||
| 469 | + | ||
| 470 | + const QList<QVariant> originalLabels = collectValues<QVariant>(propName); | ||
| 471 | + foreach (const QVariant & label, originalLabels) { | ||
| 472 | + QString labelString = label.toString(); | ||
| 473 | + if (!valueMap.contains(labelString)) { | ||
| 474 | + reverseLookup.insert(valueMap.size(), label); | ||
| 475 | + valueMap.insert(labelString, valueMap.size()); | ||
| 476 | + } | ||
| 477 | + } | ||
| 478 | + | ||
| 479 | + QList<int> result; | ||
| 480 | + for (int i=0; i<originalLabels.size(); i++) | ||
| 481 | + result.append(valueMap[originalLabels[i].toString()]); | ||
| 482 | + | ||
| 483 | + return result; | ||
| 484 | +} | ||
| 485 | + | ||
| 486 | +// uses -1 for missing values | ||
| 487 | +QList<int> TemplateList::applyIndex(const QString & propName, const QHash<QString, int> & valueMap) const | ||
| 488 | +{ | ||
| 489 | + const QList<QString> originalLabels = collectValues<QString>(propName); | ||
| 490 | + | ||
| 491 | + QList<int> result; | ||
| 492 | + for (int i=0; i<originalLabels.size(); i++) { | ||
| 493 | + if (!valueMap.contains(originalLabels[i])) result.append(-1); | ||
| 494 | + else result.append(valueMap[originalLabels[i]]); | ||
| 495 | + } | ||
| 496 | + | ||
| 497 | + return result; | ||
| 498 | +} | ||
| 499 | + | ||
| 452 | /* Object - public methods */ | 500 | /* Object - public methods */ |
| 453 | QStringList Object::parameters() const | 501 | QStringList Object::parameters() const |
| 454 | { | 502 | { |
| @@ -989,7 +1037,8 @@ QString MatrixOutput::toString(int row, int column) const | @@ -989,7 +1037,8 @@ QString MatrixOutput::toString(int row, int column) const | ||
| 989 | { | 1037 | { |
| 990 | if (targetFiles[column] == "Subject") { | 1038 | if (targetFiles[column] == "Subject") { |
| 991 | const int label = data.at<float>(row,column); | 1039 | const int label = data.at<float>(row,column); |
| 992 | - return Globals->subjects.key(label, QString::number(label)); | 1040 | + // problem -cao |
| 1041 | + return QString::number(label); | ||
| 993 | } | 1042 | } |
| 994 | return QString::number(data.at<float>(row,column)); | 1043 | return QString::number(data.at<float>(row,column)); |
| 995 | } | 1044 | } |
openbr/openbr_plugin.h
| @@ -306,6 +306,15 @@ struct BR_EXPORT FileList : public QList<File> | @@ -306,6 +306,15 @@ struct BR_EXPORT FileList : public QList<File> | ||
| 306 | values.append(f.get<T>(propName)); | 306 | values.append(f.get<T>(propName)); |
| 307 | return values; | 307 | return values; |
| 308 | } | 308 | } |
| 309 | + template<typename T> | ||
| 310 | + QList<T> collectValues(const QString & propName, T defaultValue) const | ||
| 311 | + { | ||
| 312 | + QList<T> values; values.reserve(size()); | ||
| 313 | + foreach (const File &f, *this) | ||
| 314 | + values.append(f.contains(propName) ? f.get<T>(propName) : defaultValue); | ||
| 315 | + return values; | ||
| 316 | + } | ||
| 317 | + | ||
| 309 | QList<int> crossValidationPartitions() const; /*!< \brief Returns the cross-validation partition (default=0) for each file in the list. */ | 318 | QList<int> crossValidationPartitions() const; /*!< \brief Returns the cross-validation partition (default=0) for each file in the list. */ |
| 310 | int failures() const; /*!< \brief Returns the number of files with br::File::failed(). */ | 319 | int failures() const; /*!< \brief Returns the number of files with br::File::failed(). */ |
| 311 | }; | 320 | }; |
| @@ -393,6 +402,11 @@ struct TemplateList : public QList<Template> | @@ -393,6 +402,11 @@ struct TemplateList : public QList<Template> | ||
| 393 | /*!< \brief Ensure labels are in the range [0,numClasses-1]. */ | 402 | /*!< \brief Ensure labels are in the range [0,numClasses-1]. */ |
| 394 | BR_EXPORT static TemplateList relabel(const TemplateList & tl, const QString & propName); | 403 | BR_EXPORT static TemplateList relabel(const TemplateList & tl, const QString & propName); |
| 395 | 404 | ||
| 405 | + QList<int> indexProperty(const QString & propName, QHash<QString, int> * valueMap=NULL,QHash<int, QVariant> * reverseLookup = NULL) const; | ||
| 406 | + QList<int> indexProperty(const QString & propName, QHash<QString, int> & valueMap, QHash<int, QVariant> & reverseLookup) const; | ||
| 407 | + QList<int> applyIndex(const QString & propName, const QHash<QString, int> & valueMap) const; | ||
| 408 | + | ||
| 409 | + | ||
| 396 | /*! | 410 | /*! |
| 397 | * \brief Returns the total number of bytes in all the templates. | 411 | * \brief Returns the total number of bytes in all the templates. |
| 398 | */ | 412 | */ |
| @@ -658,7 +672,6 @@ public: | @@ -658,7 +672,6 @@ public: | ||
| 658 | BR_PROPERTY(int, crossValidate, 0) | 672 | BR_PROPERTY(int, crossValidate, 0) |
| 659 | 673 | ||
| 660 | QHash<QString,QString> abbreviations; /*!< \brief Used by br::Transform::make() to expand abbreviated algorithms into their complete definitions. */ | 674 | QHash<QString,QString> abbreviations; /*!< \brief Used by br::Transform::make() to expand abbreviated algorithms into their complete definitions. */ |
| 661 | - QHash<QString,int> subjects; /*!< \brief Used by classifiers to associate text class labels with unique integers IDs. */ | ||
| 662 | QTime startTime; /*!< \brief Used to estimate timeRemaining(). */ | 675 | QTime startTime; /*!< \brief Used to estimate timeRemaining(). */ |
| 663 | 676 | ||
| 664 | /*! | 677 | /*! |
openbr/plugins/eigen3.cpp
| @@ -329,7 +329,8 @@ class LDATransform : public Transform | @@ -329,7 +329,8 @@ class LDATransform : public Transform | ||
| 329 | 329 | ||
| 330 | void train(const TemplateList &_trainingSet) | 330 | void train(const TemplateList &_trainingSet) |
| 331 | { | 331 | { |
| 332 | - TemplateList trainingSet = TemplateList::relabel(_trainingSet, "Label"); | 332 | + // creates "Label" -cao |
| 333 | + TemplateList trainingSet = TemplateList::relabel(_trainingSet, "Subject"); | ||
| 333 | 334 | ||
| 334 | int instances = trainingSet.size(); | 335 | int instances = trainingSet.size(); |
| 335 | 336 | ||
| @@ -342,12 +343,13 @@ class LDATransform : public Transform | @@ -342,12 +343,13 @@ class LDATransform : public Transform | ||
| 342 | 343 | ||
| 343 | TemplateList ldaTrainingSet; | 344 | TemplateList ldaTrainingSet; |
| 344 | static_cast<Transform*>(&pca)->project(trainingSet, ldaTrainingSet); | 345 | static_cast<Transform*>(&pca)->project(trainingSet, ldaTrainingSet); |
| 346 | + // Reindex label, is this still necessary? -cao | ||
| 345 | ldaTrainingSet = TemplateList::relabel(ldaTrainingSet, "Label"); | 347 | ldaTrainingSet = TemplateList::relabel(ldaTrainingSet, "Label"); |
| 346 | 348 | ||
| 347 | int dimsIn = ldaTrainingSet.first().m().rows * ldaTrainingSet.first().m().cols; | 349 | int dimsIn = ldaTrainingSet.first().m().rows * ldaTrainingSet.first().m().cols; |
| 348 | 350 | ||
| 349 | // OpenBR ensures that class values range from 0 to numClasses-1. | 351 | // OpenBR ensures that class values range from 0 to numClasses-1. |
| 350 | - // Assumed label is stored as float or int? -cao | 352 | + // Label exists because we created it earlier with relabel |
| 351 | QList<int> classes = trainingSet.collectValues<int>("Label"); | 353 | QList<int> classes = trainingSet.collectValues<int>("Label"); |
| 352 | QMap<int, int> classCounts = trainingSet.countValues<int>("Label"); | 354 | QMap<int, int> classCounts = trainingSet.countValues<int>("Label"); |
| 353 | const int numClasses = classCounts.size(); | 355 | const int numClasses = classCounts.size(); |
openbr/plugins/gallery.cpp
| @@ -65,7 +65,7 @@ class arffGallery : public Gallery | @@ -65,7 +65,7 @@ class arffGallery : public Gallery | ||
| 65 | const int dimensions = t.m().rows * t.m().cols; | 65 | const int dimensions = t.m().rows * t.m().cols; |
| 66 | for (int i=0; i<dimensions; i++) | 66 | for (int i=0; i<dimensions; i++) |
| 67 | arffFile.write(qPrintable("@ATTRIBUTE v" + QString::number(i) + " REAL\n")); | 67 | arffFile.write(qPrintable("@ATTRIBUTE v" + QString::number(i) + " REAL\n")); |
| 68 | - arffFile.write(qPrintable("@ATTRIBUTE class {" + QStringList(Globals->subjects.keys()).join(',') + "}\n")); | 68 | + arffFile.write(qPrintable("@ATTRIBUTE class string\n")); |
| 69 | 69 | ||
| 70 | arffFile.write("\n@DATA\n"); | 70 | arffFile.write("\n@DATA\n"); |
| 71 | } | 71 | } |
| @@ -518,9 +518,8 @@ class csvGallery : public Gallery | @@ -518,9 +518,8 @@ class csvGallery : public Gallery | ||
| 518 | static QString getCSVElement(const QString &key, const QVariant &value, bool header) | 518 | static QString getCSVElement(const QString &key, const QVariant &value, bool header) |
| 519 | { | 519 | { |
| 520 | if ((key == "Label") && !header) { | 520 | if ((key == "Label") && !header) { |
| 521 | - QString stringLabel = Globals->subjects.key(value.value<int>()); | ||
| 522 | - if (stringLabel.isEmpty()) return value.value<QString>(); | ||
| 523 | - else return stringLabel; | 521 | + // problem -cao |
| 522 | + return value.value<QString>(); | ||
| 524 | } else if (value.canConvert<QString>()) { | 523 | } else if (value.canConvert<QString>()) { |
| 525 | if (header) return key; | 524 | if (header) return key; |
| 526 | else return value.value<QString>(); | 525 | else return value.value<QString>(); |
openbr/plugins/independent.cpp
| @@ -20,21 +20,23 @@ static TemplateList Downsample(const TemplateList &templates, const Transform *t | @@ -20,21 +20,23 @@ static TemplateList Downsample(const TemplateList &templates, const Transform *t | ||
| 20 | const bool atLeast = transform->instances < 0; | 20 | const bool atLeast = transform->instances < 0; |
| 21 | const int instances = abs(transform->instances); | 21 | const int instances = abs(transform->instances); |
| 22 | 22 | ||
| 23 | - QList<int> allLabels = templates.collectValues<int>("Label"); | ||
| 24 | - QList<int> uniqueLabels = allLabels.toSet().toList(); | 23 | + QList<QString> allLabels = templates.collectValues<QString>("Subject"); |
| 24 | + QList<QString> uniqueLabels = allLabels.toSet().toList(); | ||
| 25 | qSort(uniqueLabels); | 25 | qSort(uniqueLabels); |
| 26 | 26 | ||
| 27 | - QMap<int,int> counts = templates.countValues<int>("Label", instances != std::numeric_limits<int>::max()); | 27 | + QMap<QString,int> counts = templates.countValues<QString>("Subject", instances != std::numeric_limits<int>::max()); |
| 28 | + | ||
| 28 | if ((instances != std::numeric_limits<int>::max()) && (transform->classes != std::numeric_limits<int>::max())) | 29 | if ((instances != std::numeric_limits<int>::max()) && (transform->classes != std::numeric_limits<int>::max())) |
| 29 | - foreach (int label, counts.keys()) | 30 | + foreach (const QString & label, counts.keys()) |
| 30 | if (counts[label] < instances) | 31 | if (counts[label] < instances) |
| 31 | counts.remove(label); | 32 | counts.remove(label); |
| 33 | + | ||
| 32 | uniqueLabels = counts.keys(); | 34 | uniqueLabels = counts.keys(); |
| 33 | if ((transform->classes != std::numeric_limits<int>::max()) && (uniqueLabels.size() < transform->classes)) | 35 | if ((transform->classes != std::numeric_limits<int>::max()) && (uniqueLabels.size() < transform->classes)) |
| 34 | qWarning("Downsample requested %d classes but only %d are available.", transform->classes, uniqueLabels.size()); | 36 | qWarning("Downsample requested %d classes but only %d are available.", transform->classes, uniqueLabels.size()); |
| 35 | 37 | ||
| 36 | Common::seedRNG(); | 38 | Common::seedRNG(); |
| 37 | - QList<int> selectedLabels = uniqueLabels; | 39 | + QList<QString> selectedLabels = uniqueLabels; |
| 38 | if (transform->classes < uniqueLabels.size()) { | 40 | if (transform->classes < uniqueLabels.size()) { |
| 39 | std::random_shuffle(selectedLabels.begin(), selectedLabels.end()); | 41 | std::random_shuffle(selectedLabels.begin(), selectedLabels.end()); |
| 40 | selectedLabels = selectedLabels.mid(0, transform->classes); | 42 | selectedLabels = selectedLabels.mid(0, transform->classes); |
| @@ -42,7 +44,7 @@ static TemplateList Downsample(const TemplateList &templates, const Transform *t | @@ -42,7 +44,7 @@ static TemplateList Downsample(const TemplateList &templates, const Transform *t | ||
| 42 | 44 | ||
| 43 | TemplateList downsample; | 45 | TemplateList downsample; |
| 44 | for (int i=0; i<selectedLabels.size(); i++) { | 46 | for (int i=0; i<selectedLabels.size(); i++) { |
| 45 | - const int selectedLabel = selectedLabels[i]; | 47 | + const QString selectedLabel = selectedLabels[i]; |
| 46 | QList<int> indices; | 48 | QList<int> indices; |
| 47 | for (int j=0; j<allLabels.size(); j++) | 49 | for (int j=0; j<allLabels.size(); j++) |
| 48 | if ((allLabels[j] == selectedLabel) && (!templates.value(j).file.get<bool>("FTE", false))) | 50 | if ((allLabels[j] == selectedLabel) && (!templates.value(j).file.get<bool>("FTE", false))) |
openbr/plugins/meta.cpp
| @@ -453,7 +453,8 @@ private: | @@ -453,7 +453,8 @@ private: | ||
| 453 | const QString &file = src.file; | 453 | const QString &file = src.file; |
| 454 | if (cache.contains(file)) { | 454 | if (cache.contains(file)) { |
| 455 | dst = cache[file]; | 455 | dst = cache[file]; |
| 456 | - dst.file.set("Label", src.file.value("Label")); | 456 | + // don't get this -cao |
| 457 | +// dst.file.set("Label", src.file.value("Label")); | ||
| 457 | } else { | 458 | } else { |
| 458 | transform->project(src, dst); | 459 | transform->project(src, dst); |
| 459 | cacheLock.lock(); | 460 | cacheLock.lock(); |
openbr/plugins/normalize.cpp
| @@ -126,7 +126,8 @@ private: | @@ -126,7 +126,8 @@ private: | ||
| 126 | { | 126 | { |
| 127 | Mat m; | 127 | Mat m; |
| 128 | OpenCVUtils::toMat(data.data()).convertTo(m, CV_64F); | 128 | OpenCVUtils::toMat(data.data()).convertTo(m, CV_64F); |
| 129 | - const QList<int> labels = data.collectValues<int>("Label"); | 129 | + |
| 130 | + const QList<int> labels = data.indexProperty("Subject"); | ||
| 130 | const int dims = m.cols; | 131 | const int dims = m.cols; |
| 131 | 132 | ||
| 132 | vector<Mat> mv, av, bv; | 133 | vector<Mat> mv, av, bv; |
openbr/plugins/output.cpp
| @@ -153,8 +153,10 @@ class meltOutput : public MatrixOutput | @@ -153,8 +153,10 @@ class meltOutput : public MatrixOutput | ||
| 153 | 153 | ||
| 154 | QStringList lines; | 154 | QStringList lines; |
| 155 | if (file.baseName() != "terminal") lines.append(QString("Query,Target,Mask,Similarity%1").arg(keys)); | 155 | if (file.baseName() != "terminal") lines.append(QString("Query,Target,Mask,Similarity%1").arg(keys)); |
| 156 | - QList<float> queryLabels = queryFiles.collectValues<float>("Label"); | ||
| 157 | - QList<float> targetLabels = targetFiles.collectValues<float>("Label"); | 156 | + |
| 157 | + QList<QString> queryLabels = queryFiles.collectValues<QString>("Subject"); | ||
| 158 | + QList<QString> targetLabels = targetFiles.collectValues<QString>("Subject"); | ||
| 159 | + | ||
| 158 | for (int i=0; i<queryFiles.size(); i++) { | 160 | for (int i=0; i<queryFiles.size(); i++) { |
| 159 | for (int j=(selfSimilar ? i+1 : 0); j<targetFiles.size(); j++) { | 161 | for (int j=(selfSimilar ? i+1 : 0); j<targetFiles.size(); j++) { |
| 160 | const bool genuine = queryLabels[i] == targetLabels[j]; | 162 | const bool genuine = queryLabels[i] == targetLabels[j]; |
openbr/plugins/pixel.cpp
| @@ -28,6 +28,7 @@ namespace br | @@ -28,6 +28,7 @@ namespace br | ||
| 28 | */ | 28 | */ |
| 29 | class PerPixelClassifierTransform : public MetaTransform | 29 | class PerPixelClassifierTransform : public MetaTransform |
| 30 | { | 30 | { |
| 31 | + // problematic -cao | ||
| 31 | Q_OBJECT | 32 | Q_OBJECT |
| 32 | Q_PROPERTY(br::Transform* transform READ get_transform WRITE set_transform RESET reset_transform) | 33 | Q_PROPERTY(br::Transform* transform READ get_transform WRITE set_transform RESET reset_transform) |
| 33 | Q_PROPERTY(int pixels READ get_pixels WRITE set_pixels RESET reset_pixels STORED false) | 34 | Q_PROPERTY(int pixels READ get_pixels WRITE set_pixels RESET reset_pixels STORED false) |
| @@ -273,6 +274,7 @@ BR_REGISTER(Transform, ToBinaryVectorTransform) | @@ -273,6 +274,7 @@ BR_REGISTER(Transform, ToBinaryVectorTransform) | ||
| 273 | * \author E. Taborsky \cite mmtaborsky | 274 | * \author E. Taborsky \cite mmtaborsky |
| 274 | */ | 275 | */ |
| 275 | 276 | ||
| 277 | +// What does this do? -cao | ||
| 276 | class ToMetadataTransform : public UntrainableMetaTransform | 278 | class ToMetadataTransform : public UntrainableMetaTransform |
| 277 | { | 279 | { |
| 278 | Q_OBJECT | 280 | Q_OBJECT |
openbr/plugins/quality.cpp
| @@ -26,10 +26,10 @@ class ImpostorUniquenessMeasureTransform : public Transform | @@ -26,10 +26,10 @@ class ImpostorUniquenessMeasureTransform : public Transform | ||
| 26 | 26 | ||
| 27 | float calculateIUM(const Template &probe, const TemplateList &gallery) const | 27 | float calculateIUM(const Template &probe, const TemplateList &gallery) const |
| 28 | { | 28 | { |
| 29 | - const int probeLabel = probe.file.get<float>("Label"); | 29 | + const QString probeLabel = probe.file.get<QString>("Subject"); |
| 30 | TemplateList subset = gallery; | 30 | TemplateList subset = gallery; |
| 31 | for (int j=subset.size()-1; j>=0; j--) | 31 | for (int j=subset.size()-1; j>=0; j--) |
| 32 | - if (subset[j].file.get<int>("Label") == probeLabel) | 32 | + if (subset[j].file.get<QString>("Subject") == probeLabel) |
| 33 | subset.removeAt(j); | 33 | subset.removeAt(j); |
| 34 | 34 | ||
| 35 | QList<float> scores = distance->compare(subset, probe); | 35 | QList<float> scores = distance->compare(subset, probe); |
| @@ -159,7 +159,7 @@ class MatchProbabilityDistance : public Distance | @@ -159,7 +159,7 @@ class MatchProbabilityDistance : public Distance | ||
| 159 | { | 159 | { |
| 160 | distance->train(src); | 160 | distance->train(src); |
| 161 | 161 | ||
| 162 | - const QList<int> labels = src.collectValues<int>("Label"); | 162 | + const QList<int> labels = src.indexProperty("Subject"); |
| 163 | QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(src.size()), FileList(src.size()))); | 163 | QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(src.size()), FileList(src.size()))); |
| 164 | distance->compare(src, src, matrixOutput.data()); | 164 | distance->compare(src, src, matrixOutput.data()); |
| 165 | 165 | ||
| @@ -219,7 +219,7 @@ class UnitDistance : public Distance | @@ -219,7 +219,7 @@ class UnitDistance : public Distance | ||
| 219 | void train(const TemplateList &templates) | 219 | void train(const TemplateList &templates) |
| 220 | { | 220 | { |
| 221 | const TemplateList samples = templates.mid(0, 2000); | 221 | const TemplateList samples = templates.mid(0, 2000); |
| 222 | - const QList<float> sampleLabels = samples.collectValues<float>("Label"); | 222 | + const QList<int> sampleLabels = samples.indexProperty("Subject"); |
| 223 | QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(samples.size()), FileList(samples.size()))); | 223 | QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(samples.size()), FileList(samples.size()))); |
| 224 | Distance::compare(samples, samples, matrixOutput.data()); | 224 | Distance::compare(samples, samples, matrixOutput.data()); |
| 225 | 225 |
openbr/plugins/quantize.cpp
| @@ -150,7 +150,7 @@ class BayesianQuantizationDistance : public Distance | @@ -150,7 +150,7 @@ class BayesianQuantizationDistance : public Distance | ||
| 150 | qFatal("Expected sigle matrix templates of type CV_8UC1!"); | 150 | qFatal("Expected sigle matrix templates of type CV_8UC1!"); |
| 151 | 151 | ||
| 152 | const Mat data = OpenCVUtils::toMat(src.data()); | 152 | const Mat data = OpenCVUtils::toMat(src.data()); |
| 153 | - const QList<int> templateLabels = src.collectValues<int>("Label"); | 153 | + const QList<int> templateLabels = src.indexProperty("Subject"); |
| 154 | loglikelihoods = QVector<float>(data.cols*256, 0); | 154 | loglikelihoods = QVector<float>(data.cols*256, 0); |
| 155 | 155 | ||
| 156 | QFutureSynchronizer<void> futures; | 156 | QFutureSynchronizer<void> futures; |
| @@ -473,7 +473,8 @@ private: | @@ -473,7 +473,8 @@ private: | ||
| 473 | { | 473 | { |
| 474 | Mat data = OpenCVUtils::toMat(src.data()); | 474 | Mat data = OpenCVUtils::toMat(src.data()); |
| 475 | const int step = getStep(data.cols); | 475 | const int step = getStep(data.cols); |
| 476 | - const QList<int> labels = src.collectValues<int>("Label"); | 476 | + |
| 477 | + const QList<int> labels = src.indexProperty("Subject"); | ||
| 477 | 478 | ||
| 478 | Mat &lut = ProductQuantizationLUTs[index]; | 479 | Mat &lut = ProductQuantizationLUTs[index]; |
| 479 | lut = Mat(getDims(data.cols), 256*(256+1)/2, CV_32FC1); | 480 | lut = Mat(getDims(data.cols), 256*(256+1)/2, CV_32FC1); |
openbr/plugins/quantize2.cpp
| @@ -77,7 +77,7 @@ class BayesianQuantizationTransform : public Transform | @@ -77,7 +77,7 @@ class BayesianQuantizationTransform : public Transform | ||
| 77 | void train(const TemplateList &src) | 77 | void train(const TemplateList &src) |
| 78 | { | 78 | { |
| 79 | const Mat data = OpenCVUtils::toMat(src.data()); | 79 | const Mat data = OpenCVUtils::toMat(src.data()); |
| 80 | - const QList<int> labels = src.collectValues<int>("Label"); | 80 | + const QList<int> labels = src.indexProperty("Subject"); |
| 81 | 81 | ||
| 82 | thresholds = QVector<float>(256*data.cols); | 82 | thresholds = QVector<float>(256*data.cols); |
| 83 | 83 |
openbr/plugins/svm.cpp
| @@ -121,28 +121,46 @@ private: | @@ -121,28 +121,46 @@ private: | ||
| 121 | BR_PROPERTY(float, gamma, -1) | 121 | BR_PROPERTY(float, gamma, -1) |
| 122 | 122 | ||
| 123 | SVM svm; | 123 | SVM svm; |
| 124 | + QHash<QString, int> labelMap; | ||
| 125 | + QHash<int, QVariant> reverseLookup; | ||
| 124 | 126 | ||
| 125 | void train(const TemplateList &_data) | 127 | void train(const TemplateList &_data) |
| 126 | { | 128 | { |
| 127 | Mat data = OpenCVUtils::toMat(_data.data()); | 129 | Mat data = OpenCVUtils::toMat(_data.data()); |
| 128 | - Mat lab = OpenCVUtils::toMat(_data.collectValues<float>("Label")); | 130 | + Mat lab; |
| 131 | + // If we are doing regression, assume subject has float values | ||
| 132 | + if (type == EPS_SVR || type == NU_SVR) { | ||
| 133 | + lab = OpenCVUtils::toMat(_data.collectValues<float>("Subject")); | ||
| 134 | + } | ||
| 135 | + // If we are doing classification, assume subject has discrete values, map them | ||
| 136 | + // and store the mapping data | ||
| 137 | + else { | ||
| 138 | + QList<int> dataLabels = _data.indexProperty("Subject", labelMap, reverseLookup); | ||
| 139 | + lab = OpenCVUtils::toMat(dataLabels); | ||
| 140 | + } | ||
| 129 | trainSVM(svm, data, lab, kernel, type, C, gamma); | 141 | trainSVM(svm, data, lab, kernel, type, C, gamma); |
| 130 | } | 142 | } |
| 131 | 143 | ||
| 132 | void project(const Template &src, Template &dst) const | 144 | void project(const Template &src, Template &dst) const |
| 133 | { | 145 | { |
| 134 | dst = src; | 146 | dst = src; |
| 135 | - dst.file.set("Label", svm.predict(src.m().reshape(1, 1))); | 147 | + float prediction = svm.predict(src.m().reshape(1, 1)); |
| 148 | + if (type == EPS_SVR || type == NU_SVR) | ||
| 149 | + dst.file.set("Subject", prediction); | ||
| 150 | + else | ||
| 151 | + dst.file.set("Subject", reverseLookup[prediction]); | ||
| 136 | } | 152 | } |
| 137 | 153 | ||
| 138 | void store(QDataStream &stream) const | 154 | void store(QDataStream &stream) const |
| 139 | { | 155 | { |
| 140 | storeSVM(svm, stream); | 156 | storeSVM(svm, stream); |
| 157 | + stream << labelMap << reverseLookup; | ||
| 141 | } | 158 | } |
| 142 | 159 | ||
| 143 | void load(QDataStream &stream) | 160 | void load(QDataStream &stream) |
| 144 | { | 161 | { |
| 145 | loadSVM(svm, stream); | 162 | loadSVM(svm, stream); |
| 163 | + stream >> labelMap >> reverseLookup; | ||
| 146 | } | 164 | } |
| 147 | }; | 165 | }; |
| 148 | 166 |