Commit 5aec6fb362267eccc48eece9753d2848777c0ac4
Merge pull request #369 from biometrics/bootStrap
Random test/train split in TemplateList::Partition
Showing
7 changed files
with
39 additions
and
16 deletions
app/br/br.cpp
| @@ -69,6 +69,7 @@ public: | @@ -69,6 +69,7 @@ public: | ||
| 69 | 69 | ||
| 70 | bool daemon = false; | 70 | bool daemon = false; |
| 71 | const char *daemon_pipe = NULL; | 71 | const char *daemon_pipe = NULL; |
| 72 | + bool isInt = false; | ||
| 72 | while (daemon || (argc > 0)) { | 73 | while (daemon || (argc > 0)) { |
| 73 | const char *fun; | 74 | const char *fun; |
| 74 | int parc; | 75 | int parc; |
| @@ -78,7 +79,12 @@ public: | @@ -78,7 +79,12 @@ public: | ||
| 78 | 79 | ||
| 79 | fun = argv[0]; | 80 | fun = argv[0]; |
| 80 | if (fun[0] == '-') fun++; | 81 | if (fun[0] == '-') fun++; |
| 81 | - parc = 0; while ((parc+1 < argc) && (argv[parc+1][0] != '-')) parc++; | 82 | + parc = 0; |
| 83 | + QString(argv[parc+1]).toInt(&isInt); | ||
| 84 | + while ((parc+1 < argc) && ((argv[parc+1][0] != '-') || isInt)) { | ||
| 85 | + parc++; | ||
| 86 | + QString(argv[parc+1]).toInt(&isInt); | ||
| 87 | + } | ||
| 82 | parv = (const char **)&argv[1]; | 88 | parv = (const char **)&argv[1]; |
| 83 | argc = argc - (parc+1); | 89 | argc = argc - (parc+1); |
| 84 | argv = &argv[parc+1]; | 90 | argv = &argv[parc+1]; |
openbr/core/bee.cpp
| @@ -230,7 +230,7 @@ void makeMask(const QString &targetInput, const QString &queryInput, const QStri | @@ -230,7 +230,7 @@ void makeMask(const QString &targetInput, const QString &queryInput, const QStri | ||
| 230 | const FileList targets = TemplateList::fromGallery(targetInput).files(); | 230 | const FileList targets = TemplateList::fromGallery(targetInput).files(); |
| 231 | const FileList queries = (queryInput == ".") ? targets : TemplateList::fromGallery(queryInput).files(); | 231 | const FileList queries = (queryInput == ".") ? targets : TemplateList::fromGallery(queryInput).files(); |
| 232 | const int partitions = targets.first().get<int>("crossValidate"); | 232 | const int partitions = targets.first().get<int>("crossValidate"); |
| 233 | - if (partitions == 0) { | 233 | + if (partitions <= 0) { |
| 234 | writeMatrix(makeMask(targets, queries), mask, targetInput, queryInput); | 234 | writeMatrix(makeMask(targets, queries), mask, targetInput, queryInput); |
| 235 | } else { | 235 | } else { |
| 236 | if (!mask.contains("%1")) qFatal("Mask file name missing partition number place marker (%%1)"); | 236 | if (!mask.contains("%1")) qFatal("Mask file name missing partition number place marker (%%1)"); |
| @@ -246,7 +246,7 @@ void makePairwiseMask(const QString &targetInput, const QString &queryInput, con | @@ -246,7 +246,7 @@ void makePairwiseMask(const QString &targetInput, const QString &queryInput, con | ||
| 246 | const FileList targets = TemplateList::fromGallery(targetInput).files(); | 246 | const FileList targets = TemplateList::fromGallery(targetInput).files(); |
| 247 | const FileList queries = (queryInput == ".") ? targets : TemplateList::fromGallery(queryInput).files(); | 247 | const FileList queries = (queryInput == ".") ? targets : TemplateList::fromGallery(queryInput).files(); |
| 248 | const int partitions = targets.first().get<int>("crossValidate"); | 248 | const int partitions = targets.first().get<int>("crossValidate"); |
| 249 | - if (partitions == 0) { | 249 | + if (partitions <= 0) { |
| 250 | writeMatrix(makePairwiseMask(targets, queries), mask, targetInput, queryInput); | 250 | writeMatrix(makePairwiseMask(targets, queries), mask, targetInput, queryInput); |
| 251 | } else { | 251 | } else { |
| 252 | if (!mask.contains("%1")) qFatal("Mask file name missing partition number place marker (%%1)"); | 252 | if (!mask.contains("%1")) qFatal("Mask file name missing partition number place marker (%%1)"); |
openbr/core/core.cpp
| @@ -69,7 +69,7 @@ struct AlgorithmCore | @@ -69,7 +69,7 @@ struct AlgorithmCore | ||
| 69 | QScopedPointer<Transform> trainingWrapper(br::wrapTransform(transform.data(), "Stream(readMode=DistributeFrames)")); | 69 | QScopedPointer<Transform> trainingWrapper(br::wrapTransform(transform.data(), "Stream(readMode=DistributeFrames)")); |
| 70 | TemplateList data(TemplateList::fromGallery(input)); | 70 | TemplateList data(TemplateList::fromGallery(input)); |
| 71 | 71 | ||
| 72 | - if (Globals->crossValidate > 1) | 72 | + if (abs(Globals->crossValidate) > 1) |
| 73 | for (int i=data.size()-1; i>=0; i--) | 73 | for (int i=data.size()-1; i>=0; i--) |
| 74 | if (data[i].file.get<bool>("allPartitions",false) || data[i].file.get<bool>("duplicatePartitions",false)) | 74 | if (data[i].file.get<bool>("allPartitions",false) || data[i].file.get<bool>("duplicatePartitions",false)) |
| 75 | data.removeAt(i); | 75 | data.removeAt(i); |
openbr/openbr_plugin.cpp
| @@ -28,6 +28,7 @@ | @@ -28,6 +28,7 @@ | ||
| 28 | #include <QtConcurrentRun> | 28 | #include <QtConcurrentRun> |
| 29 | #include <algorithm> | 29 | #include <algorithm> |
| 30 | #include <iostream> | 30 | #include <iostream> |
| 31 | +#include <RandomLib/Random.hpp> | ||
| 31 | 32 | ||
| 32 | #ifndef BR_EMBEDDED | 33 | #ifndef BR_EMBEDDED |
| 33 | #include <QApplication> | 34 | #include <QApplication> |
| @@ -433,7 +434,7 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) | @@ -433,7 +434,7 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) | ||
| 433 | if (gallery.getBool("reduce")) | 434 | if (gallery.getBool("reduce")) |
| 434 | newTemplates = newTemplates.reduced(); | 435 | newTemplates = newTemplates.reduced(); |
| 435 | 436 | ||
| 436 | - if (Globals->crossValidate > 1) | 437 | + if (abs(Globals->crossValidate) > 1) |
| 437 | newTemplates = newTemplates.partition("Label"); | 438 | newTemplates = newTemplates.partition("Label"); |
| 438 | 439 | ||
| 439 | if (!templates.isEmpty() && gallery.get<bool>("merge", false)) { | 440 | if (!templates.isEmpty() && gallery.get<bool>("merge", false)) { |
| @@ -528,12 +529,16 @@ QList<int> TemplateList::applyIndex(const QString &propName, const QHash<QString | @@ -528,12 +529,16 @@ QList<int> TemplateList::applyIndex(const QString &propName, const QHash<QString | ||
| 528 | return result; | 529 | return result; |
| 529 | } | 530 | } |
| 530 | 531 | ||
| 531 | -TemplateList TemplateList::partition(const QString &inputVariable) const | 532 | +TemplateList TemplateList::partition(const QString &inputVariable, unsigned int randomSeed) const |
| 532 | { | 533 | { |
| 533 | - const int crossValidate = Globals->crossValidate; | 534 | + const int crossValidate = std::abs(Globals->crossValidate); |
| 534 | if (crossValidate < 2) | 535 | if (crossValidate < 2) |
| 535 | return *this; | 536 | return *this; |
| 536 | 537 | ||
| 538 | + // Use separate RNG from Common::randN() to avoid re-seeding the global RNG. | ||
| 539 | + // rng is seeded with the inputVariable hash in order to maintain partition across br runs, given the same randomSeed seed. | ||
| 540 | + RandomLib::Random rng; | ||
| 541 | + | ||
| 537 | TemplateList partitioned = *this; | 542 | TemplateList partitioned = *this; |
| 538 | for (int i=partitioned.size()-1; i>=0; i--) { | 543 | for (int i=partitioned.size()-1; i>=0; i--) { |
| 539 | // See CrossValidateTransform for description of these variables | 544 | // See CrossValidateTransform for description of these variables |
| @@ -546,8 +551,12 @@ TemplateList TemplateList::partition(const QString &inputVariable) const | @@ -546,8 +551,12 @@ TemplateList TemplateList::partition(const QString &inputVariable) const | ||
| 546 | } else if (partitioned[i].file.getBool("allPartitions")) { | 551 | } else if (partitioned[i].file.getBool("allPartitions")) { |
| 547 | partitioned[i].file.set("Partition", -1); | 552 | partitioned[i].file.set("Partition", -1); |
| 548 | } else { | 553 | } else { |
| 549 | - if (!partitioned[i].file.contains(("Partition"))) { | ||
| 550 | - const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5); | 554 | + const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5); |
| 555 | + if (randomSeed) { | ||
| 556 | + quint64 labelSeed = md5.toHex().right(8).toULongLong(0, 16) + randomSeed; | ||
| 557 | + rng.Reseed(labelSeed); | ||
| 558 | + partitioned[i].file.set("Partition", rng.Integer() % crossValidate); | ||
| 559 | + } else if (!partitioned[i].file.contains(("Partition"))) { | ||
| 551 | // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow | 560 | // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow |
| 552 | partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); | 561 | partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); |
| 553 | } | 562 | } |
openbr/openbr_plugin.h
| @@ -459,7 +459,7 @@ struct TemplateList : public QList<Template> | @@ -459,7 +459,7 @@ struct TemplateList : public QList<Template> | ||
| 459 | BR_EXPORT static TemplateList relabel(const TemplateList &tl, const QString &propName, bool preserveIntegers); | 459 | BR_EXPORT static TemplateList relabel(const TemplateList &tl, const QString &propName, bool preserveIntegers); |
| 460 | 460 | ||
| 461 | /*!< \brief Assign templates to folds partitions. */ | 461 | /*!< \brief Assign templates to folds partitions. */ |
| 462 | - BR_EXPORT TemplateList partition(const QString &inputVariable) const; | 462 | + BR_EXPORT TemplateList partition(const QString &inputVariable, unsigned int randomSeed = 0) const; |
| 463 | 463 | ||
| 464 | BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> * valueMap=NULL,QHash<int, QVariant> * reverseLookup = NULL) const; | 464 | BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> * valueMap=NULL,QHash<int, QVariant> * reverseLookup = NULL) const; |
| 465 | BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> &valueMap, QHash<int, QVariant> &reverseLookup) const; | 465 | BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> &valueMap, QHash<int, QVariant> &reverseLookup) const; |
openbr/plugins/core/crossvalidate.cpp
| @@ -56,8 +56,10 @@ class CrossValidateTransform : public MetaTransform | @@ -56,8 +56,10 @@ class CrossValidateTransform : public MetaTransform | ||
| 56 | Q_OBJECT | 56 | Q_OBJECT |
| 57 | Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false) | 57 | Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false) |
| 58 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | 58 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) |
| 59 | + Q_PROPERTY(unsigned int randomSeed READ get_randomSeed WRITE set_randomSeed RESET reset_randomSeed STORED false) | ||
| 59 | BR_PROPERTY(QString, description, "Identity") | 60 | BR_PROPERTY(QString, description, "Identity") |
| 60 | BR_PROPERTY(QString, inputVariable, "Label") | 61 | BR_PROPERTY(QString, inputVariable, "Label") |
| 62 | + BR_PROPERTY(unsigned int, randomSeed, 0) | ||
| 61 | 63 | ||
| 62 | // numPartitions copies of transform specified by description. | 64 | // numPartitions copies of transform specified by description. |
| 63 | QList<br::Transform*> transforms; | 65 | QList<br::Transform*> transforms; |
| @@ -67,13 +69,14 @@ class CrossValidateTransform : public MetaTransform | @@ -67,13 +69,14 @@ class CrossValidateTransform : public MetaTransform | ||
| 67 | // is generally incorrect behavior. | 69 | // is generally incorrect behavior. |
| 68 | void train(const TemplateList &data) | 70 | void train(const TemplateList &data) |
| 69 | { | 71 | { |
| 70 | - QList<int> partitions = data.partition(inputVariable).files().crossValidationPartitions(); | ||
| 71 | - const int numPartitions = Common::Max(partitions)+1; | ||
| 72 | - | 72 | + QList<int> partitions = data.partition(inputVariable, randomSeed).files().crossValidationPartitions(); |
| 73 | + const int crossValidate = Globals->crossValidate; | ||
| 74 | + // Only train once based on the 0th partition if crossValidate is negative. | ||
| 75 | + const int numPartitions = (crossValidate < 0) ? 1 : Common::Max(partitions)+1; | ||
| 73 | while (transforms.size() < numPartitions) | 76 | while (transforms.size() < numPartitions) |
| 74 | transforms.append(make(description)); | 77 | transforms.append(make(description)); |
| 75 | 78 | ||
| 76 | - if (numPartitions < 2) { | 79 | + if (std::abs(crossValidate) < 2) { |
| 77 | transforms.first()->train(data); | 80 | transforms.first()->train(data); |
| 78 | return; | 81 | return; |
| 79 | } | 82 | } |
| @@ -101,8 +104,13 @@ class CrossValidateTransform : public MetaTransform | @@ -101,8 +104,13 @@ class CrossValidateTransform : public MetaTransform | ||
| 101 | 104 | ||
| 102 | void project(const TemplateList &src, TemplateList &dst) const | 105 | void project(const TemplateList &src, TemplateList &dst) const |
| 103 | { | 106 | { |
| 104 | - TemplateList partitioned = src.partition(inputVariable); | 107 | + TemplateList partitioned = src.partition(inputVariable, randomSeed); |
| 108 | + const int crossValidate = Globals->crossValidate; | ||
| 105 | 109 | ||
| 110 | + if (crossValidate < 0) { | ||
| 111 | + transforms[0]->project(partitioned, dst); | ||
| 112 | + return; | ||
| 113 | + } | ||
| 106 | for (int i=0; i<partitioned.size(); i++) { | 114 | for (int i=0; i<partitioned.size(); i++) { |
| 107 | int partition = partitioned[i].file.get<int>("Partition", 0); | 115 | int partition = partitioned[i].file.get<int>("Partition", 0); |
| 108 | transforms[partition]->project(partitioned, dst); | 116 | transforms[partition]->project(partitioned, dst); |
openbr/plugins/output/eval.cpp
| @@ -45,7 +45,7 @@ class evalOutput : public MatrixOutput | @@ -45,7 +45,7 @@ class evalOutput : public MatrixOutput | ||
| 45 | 45 | ||
| 46 | if (data.data) { | 46 | if (data.data) { |
| 47 | const QString csv = QString(file.name).replace(".eval", ".csv"); | 47 | const QString csv = QString(file.name).replace(".eval", ".csv"); |
| 48 | - if ((Globals->crossValidate == 0) || (!crossValidate)) { | 48 | + if ((Globals->crossValidate <= 0) || (!crossValidate)) { |
| 49 | Evaluate(data, targetFiles, queryFiles, csv); | 49 | Evaluate(data, targetFiles, queryFiles, csv); |
| 50 | } else { | 50 | } else { |
| 51 | QFutureSynchronizer<float> futures; | 51 | QFutureSynchronizer<float> futures; |