Commit 5aec6fb362267eccc48eece9753d2848777c0ac4
Merge pull request #369 from biometrics/bootStrap
Random test/train split in TemplateList::Partition
Showing 7 changed files with 39 additions and 16 deletions
app/br/br.cpp
| ... | ... | @@ -69,6 +69,7 @@ public: |
| 69 | 69 | |
| 70 | 70 | bool daemon = false; |
| 71 | 71 | const char *daemon_pipe = NULL; |
| 72 | + bool isInt = false; | |
| 72 | 73 | while (daemon || (argc > 0)) { |
| 73 | 74 | const char *fun; |
| 74 | 75 | int parc; |
| ... | ... | @@ -78,7 +79,12 @@ public: |
| 78 | 79 | |
| 79 | 80 | fun = argv[0]; |
| 80 | 81 | if (fun[0] == '-') fun++; |
| 81 | - parc = 0; while ((parc+1 < argc) && (argv[parc+1][0] != '-')) parc++; | |
| 82 | + parc = 0; | |
| 83 | + QString(argv[parc+1]).toInt(&isInt); | |
| 84 | + while ((parc+1 < argc) && ((argv[parc+1][0] != '-') || isInt)) { | |
| 85 | + parc++; | |
| 86 | + QString(argv[parc+1]).toInt(&isInt); | |
| 87 | + } | |
| 82 | 88 | parv = (const char **)&argv[1]; |
| 83 | 89 | argc = argc - (parc+1); |
| 84 | 90 | argv = &argv[parc+1]; | ... | ... |
openbr/core/bee.cpp
| ... | ... | @@ -230,7 +230,7 @@ void makeMask(const QString &targetInput, const QString &queryInput, const QStri |
| 230 | 230 | const FileList targets = TemplateList::fromGallery(targetInput).files(); |
| 231 | 231 | const FileList queries = (queryInput == ".") ? targets : TemplateList::fromGallery(queryInput).files(); |
| 232 | 232 | const int partitions = targets.first().get<int>("crossValidate"); |
| 233 | - if (partitions == 0) { | |
| 233 | + if (partitions <= 0) { | |
| 234 | 234 | writeMatrix(makeMask(targets, queries), mask, targetInput, queryInput); |
| 235 | 235 | } else { |
| 236 | 236 | if (!mask.contains("%1")) qFatal("Mask file name missing partition number place marker (%%1)"); |
| ... | ... | @@ -246,7 +246,7 @@ void makePairwiseMask(const QString &targetInput, const QString &queryInput, con |
| 246 | 246 | const FileList targets = TemplateList::fromGallery(targetInput).files(); |
| 247 | 247 | const FileList queries = (queryInput == ".") ? targets : TemplateList::fromGallery(queryInput).files(); |
| 248 | 248 | const int partitions = targets.first().get<int>("crossValidate"); |
| 249 | - if (partitions == 0) { | |
| 249 | + if (partitions <= 0) { | |
| 250 | 250 | writeMatrix(makePairwiseMask(targets, queries), mask, targetInput, queryInput); |
| 251 | 251 | } else { |
| 252 | 252 | if (!mask.contains("%1")) qFatal("Mask file name missing partition number place marker (%%1)"); | ... | ... |
openbr/core/core.cpp
| ... | ... | @@ -69,7 +69,7 @@ struct AlgorithmCore |
| 69 | 69 | QScopedPointer<Transform> trainingWrapper(br::wrapTransform(transform.data(), "Stream(readMode=DistributeFrames)")); |
| 70 | 70 | TemplateList data(TemplateList::fromGallery(input)); |
| 71 | 71 | |
| 72 | - if (Globals->crossValidate > 1) | |
| 72 | + if (abs(Globals->crossValidate) > 1) | |
| 73 | 73 | for (int i=data.size()-1; i>=0; i--) |
| 74 | 74 | if (data[i].file.get<bool>("allPartitions",false) || data[i].file.get<bool>("duplicatePartitions",false)) |
| 75 | 75 | data.removeAt(i); | ... | ... |
openbr/openbr_plugin.cpp
| ... | ... | @@ -28,6 +28,7 @@ |
| 28 | 28 | #include <QtConcurrentRun> |
| 29 | 29 | #include <algorithm> |
| 30 | 30 | #include <iostream> |
| 31 | +#include <RandomLib/Random.hpp> | |
| 31 | 32 | |
| 32 | 33 | #ifndef BR_EMBEDDED |
| 33 | 34 | #include <QApplication> |
| ... | ... | @@ -433,7 +434,7 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) |
| 433 | 434 | if (gallery.getBool("reduce")) |
| 434 | 435 | newTemplates = newTemplates.reduced(); |
| 435 | 436 | |
| 436 | - if (Globals->crossValidate > 1) | |
| 437 | + if (abs(Globals->crossValidate) > 1) | |
| 437 | 438 | newTemplates = newTemplates.partition("Label"); |
| 438 | 439 | |
| 439 | 440 | if (!templates.isEmpty() && gallery.get<bool>("merge", false)) { |
| ... | ... | @@ -528,12 +529,16 @@ QList<int> TemplateList::applyIndex(const QString &propName, const QHash<QString |
| 528 | 529 | return result; |
| 529 | 530 | } |
| 530 | 531 | |
| 531 | -TemplateList TemplateList::partition(const QString &inputVariable) const | |
| 532 | +TemplateList TemplateList::partition(const QString &inputVariable, unsigned int randomSeed) const | |
| 532 | 533 | { |
| 533 | - const int crossValidate = Globals->crossValidate; | |
| 534 | + const int crossValidate = std::abs(Globals->crossValidate); | |
| 534 | 535 | if (crossValidate < 2) |
| 535 | 536 | return *this; |
| 536 | 537 | |
| 538 | + // Use separate RNG from Common::randN() to avoid re-seeding the global RNG. | |
| 539 | + // rng is reseeded per template with the hash of its inputVariable value plus randomSeed, so partitions are reproducible across br runs for the same randomSeed. | |
| 540 | + RandomLib::Random rng; | |
| 541 | + | |
| 537 | 542 | TemplateList partitioned = *this; |
| 538 | 543 | for (int i=partitioned.size()-1; i>=0; i--) { |
| 539 | 544 | // See CrossValidateTransform for description of these variables |
| ... | ... | @@ -546,8 +551,12 @@ TemplateList TemplateList::partition(const QString &inputVariable) const |
| 546 | 551 | } else if (partitioned[i].file.getBool("allPartitions")) { |
| 547 | 552 | partitioned[i].file.set("Partition", -1); |
| 548 | 553 | } else { |
| 549 | - if (!partitioned[i].file.contains(("Partition"))) { | |
| 550 | - const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5); | |
| 554 | + const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5); | |
| 555 | + if (randomSeed) { | |
| 556 | + quint64 labelSeed = md5.toHex().right(8).toULongLong(0, 16) + randomSeed; | |
| 557 | + rng.Reseed(labelSeed); | |
| 558 | + partitioned[i].file.set("Partition", rng.Integer() % crossValidate); | |
| 559 | + } else if (!partitioned[i].file.contains(("Partition"))) { | |
| 551 | 560 | // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow |
| 552 | 561 | partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); |
| 553 | 562 | } | ... | ... |
openbr/openbr_plugin.h
| ... | ... | @@ -459,7 +459,7 @@ struct TemplateList : public QList<Template> |
| 459 | 459 | BR_EXPORT static TemplateList relabel(const TemplateList &tl, const QString &propName, bool preserveIntegers); |
| 460 | 460 | |
| 461 | 461 | /*!< \brief Assign templates to cross-validation fold partitions. */ |
| 462 | - BR_EXPORT TemplateList partition(const QString &inputVariable) const; | |
| 462 | + BR_EXPORT TemplateList partition(const QString &inputVariable, unsigned int randomSeed = 0) const; | |
| 463 | 463 | |
| 464 | 464 | BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> * valueMap=NULL,QHash<int, QVariant> * reverseLookup = NULL) const; |
| 465 | 465 | BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> &valueMap, QHash<int, QVariant> &reverseLookup) const; | ... | ... |
openbr/plugins/core/crossvalidate.cpp
| ... | ... | @@ -56,8 +56,10 @@ class CrossValidateTransform : public MetaTransform |
| 56 | 56 | Q_OBJECT |
| 57 | 57 | Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false) |
| 58 | 58 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) |
| 59 | + Q_PROPERTY(unsigned int randomSeed READ get_randomSeed WRITE set_randomSeed RESET reset_randomSeed STORED false) | |
| 59 | 60 | BR_PROPERTY(QString, description, "Identity") |
| 60 | 61 | BR_PROPERTY(QString, inputVariable, "Label") |
| 62 | + BR_PROPERTY(unsigned int, randomSeed, 0) | |
| 61 | 63 | |
| 62 | 64 | // numPartitions copies of transform specified by description. |
| 63 | 65 | QList<br::Transform*> transforms; |
| ... | ... | @@ -67,13 +69,14 @@ class CrossValidateTransform : public MetaTransform |
| 67 | 69 | // is generally incorrect behavior. |
| 68 | 70 | void train(const TemplateList &data) |
| 69 | 71 | { |
| 70 | - QList<int> partitions = data.partition(inputVariable).files().crossValidationPartitions(); | |
| 71 | - const int numPartitions = Common::Max(partitions)+1; | |
| 72 | - | |
| 72 | + QList<int> partitions = data.partition(inputVariable, randomSeed).files().crossValidationPartitions(); | |
| 73 | + const int crossValidate = Globals->crossValidate; | |
| 74 | + // Only train once based on the 0th partition if crossValidate is negative. | |
| 75 | + const int numPartitions = (crossValidate < 0) ? 1 : Common::Max(partitions)+1; | |
| 73 | 76 | while (transforms.size() < numPartitions) |
| 74 | 77 | transforms.append(make(description)); |
| 75 | 78 | |
| 76 | - if (numPartitions < 2) { | |
| 79 | + if (std::abs(crossValidate) < 2) { | |
| 77 | 80 | transforms.first()->train(data); |
| 78 | 81 | return; |
| 79 | 82 | } |
| ... | ... | @@ -101,8 +104,13 @@ class CrossValidateTransform : public MetaTransform |
| 101 | 104 | |
| 102 | 105 | void project(const TemplateList &src, TemplateList &dst) const |
| 103 | 106 | { |
| 104 | - TemplateList partitioned = src.partition(inputVariable); | |
| 107 | + TemplateList partitioned = src.partition(inputVariable, randomSeed); | |
| 108 | + const int crossValidate = Globals->crossValidate; | |
| 105 | 109 | |
| 110 | + if (crossValidate < 0) { | |
| 111 | + transforms[0]->project(partitioned, dst); | |
| 112 | + return; | |
| 113 | + } | |
| 106 | 114 | for (int i=0; i<partitioned.size(); i++) { |
| 107 | 115 | int partition = partitioned[i].file.get<int>("Partition", 0); |
| 108 | 116 | transforms[partition]->project(partitioned, dst); | ... | ... |
openbr/plugins/output/eval.cpp
| ... | ... | @@ -45,7 +45,7 @@ class evalOutput : public MatrixOutput |
| 45 | 45 | |
| 46 | 46 | if (data.data) { |
| 47 | 47 | const QString csv = QString(file.name).replace(".eval", ".csv"); |
| 48 | - if ((Globals->crossValidate == 0) || (!crossValidate)) { | |
| 48 | + if ((Globals->crossValidate <= 0) || (!crossValidate)) { | |
| 49 | 49 | Evaluate(data, targetFiles, queryFiles, csv); |
| 50 | 50 | } else { |
| 51 | 51 | QFutureSynchronizer<float> futures; | ... | ... |