Commit 5aec6fb362267eccc48eece9753d2848777c0ac4

Authored by Ben Klein
2 parents 9d766e1e cb24123a

Merge pull request #369 from biometrics/bootStrap

Random test/train split in TemplateList::Partition
app/br/br.cpp
... ... @@ -69,6 +69,7 @@ public:
69 69  
70 70 bool daemon = false;
71 71 const char *daemon_pipe = NULL;
  72 + bool isInt = false;
72 73 while (daemon || (argc > 0)) {
73 74 const char *fun;
74 75 int parc;
... ... @@ -78,7 +79,12 @@ public:
78 79  
79 80 fun = argv[0];
80 81 if (fun[0] == '-') fun++;
81   - parc = 0; while ((parc+1 < argc) && (argv[parc+1][0] != '-')) parc++;
  82 + parc = 0;
  83 + QString(argv[parc+1]).toInt(&isInt);
  84 + while ((parc+1 < argc) && ((argv[parc+1][0] != '-') || isInt)) {
  85 + parc++;
  86 + QString(argv[parc+1]).toInt(&isInt);
  87 + }
82 88 parv = (const char **)&argv[1];
83 89 argc = argc - (parc+1);
84 90 argv = &argv[parc+1];
... ...
openbr/core/bee.cpp
... ... @@ -230,7 +230,7 @@ void makeMask(const QString &amp;targetInput, const QString &amp;queryInput, const QStri
230 230 const FileList targets = TemplateList::fromGallery(targetInput).files();
231 231 const FileList queries = (queryInput == ".") ? targets : TemplateList::fromGallery(queryInput).files();
232 232 const int partitions = targets.first().get<int>("crossValidate");
233   - if (partitions == 0) {
  233 + if (partitions <= 0) {
234 234 writeMatrix(makeMask(targets, queries), mask, targetInput, queryInput);
235 235 } else {
236 236 if (!mask.contains("%1")) qFatal("Mask file name missing partition number place marker (%%1)");
... ... @@ -246,7 +246,7 @@ void makePairwiseMask(const QString &amp;targetInput, const QString &amp;queryInput, con
246 246 const FileList targets = TemplateList::fromGallery(targetInput).files();
247 247 const FileList queries = (queryInput == ".") ? targets : TemplateList::fromGallery(queryInput).files();
248 248 const int partitions = targets.first().get<int>("crossValidate");
249   - if (partitions == 0) {
  249 + if (partitions <= 0) {
250 250 writeMatrix(makePairwiseMask(targets, queries), mask, targetInput, queryInput);
251 251 } else {
252 252 if (!mask.contains("%1")) qFatal("Mask file name missing partition number place marker (%%1)");
... ...
openbr/core/core.cpp
... ... @@ -69,7 +69,7 @@ struct AlgorithmCore
69 69 QScopedPointer<Transform> trainingWrapper(br::wrapTransform(transform.data(), "Stream(readMode=DistributeFrames)"));
70 70 TemplateList data(TemplateList::fromGallery(input));
71 71  
72   - if (Globals->crossValidate > 1)
  72 + if (abs(Globals->crossValidate) > 1)
73 73 for (int i=data.size()-1; i>=0; i--)
74 74 if (data[i].file.get<bool>("allPartitions",false) || data[i].file.get<bool>("duplicatePartitions",false))
75 75 data.removeAt(i);
... ...
openbr/openbr_plugin.cpp
... ... @@ -28,6 +28,7 @@
28 28 #include <QtConcurrentRun>
29 29 #include <algorithm>
30 30 #include <iostream>
  31 +#include <RandomLib/Random.hpp>
31 32  
32 33 #ifndef BR_EMBEDDED
33 34 #include <QApplication>
... ... @@ -433,7 +434,7 @@ TemplateList TemplateList::fromGallery(const br::File &amp;gallery)
433 434 if (gallery.getBool("reduce"))
434 435 newTemplates = newTemplates.reduced();
435 436  
436   - if (Globals->crossValidate > 1)
  437 + if (abs(Globals->crossValidate) > 1)
437 438 newTemplates = newTemplates.partition("Label");
438 439  
439 440 if (!templates.isEmpty() && gallery.get<bool>("merge", false)) {
... ... @@ -528,12 +529,16 @@ QList&lt;int&gt; TemplateList::applyIndex(const QString &amp;propName, const QHash&lt;QString
528 529 return result;
529 530 }
530 531  
531   -TemplateList TemplateList::partition(const QString &inputVariable) const
  532 +TemplateList TemplateList::partition(const QString &inputVariable, unsigned int randomSeed) const
532 533 {
533   - const int crossValidate = Globals->crossValidate;
  534 + const int crossValidate = std::abs(Globals->crossValidate);
534 535 if (crossValidate < 2)
535 536 return *this;
536 537  
  538 + // Use separate RNG from Common::randN() to avoid re-seeding the global RNG.
  539 + // rng is seeded with the inputVariable hash in order to maintain partition across br runs, given the same randomSeed seed.
  540 + RandomLib::Random rng;
  541 +
537 542 TemplateList partitioned = *this;
538 543 for (int i=partitioned.size()-1; i>=0; i--) {
539 544 // See CrossValidateTransform for description of these variables
... ... @@ -546,8 +551,12 @@ TemplateList TemplateList::partition(const QString &amp;inputVariable) const
546 551 } else if (partitioned[i].file.getBool("allPartitions")) {
547 552 partitioned[i].file.set("Partition", -1);
548 553 } else {
549   - if (!partitioned[i].file.contains(("Partition"))) {
550   - const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5);
  554 + const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5);
  555 + if (randomSeed) {
  556 + quint64 labelSeed = md5.toHex().right(8).toULongLong(0, 16) + randomSeed;
  557 + rng.Reseed(labelSeed);
  558 + partitioned[i].file.set("Partition", rng.Integer() % crossValidate);
  559 + } else if (!partitioned[i].file.contains(("Partition"))) {
551 560 // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow
552 561 partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate);
553 562 }
... ...
openbr/openbr_plugin.h
... ... @@ -459,7 +459,7 @@ struct TemplateList : public QList&lt;Template&gt;
459 459 BR_EXPORT static TemplateList relabel(const TemplateList &tl, const QString &propName, bool preserveIntegers);
460 460  
461 461 /*!< \brief Assign templates to folds partitions. */
462   - BR_EXPORT TemplateList partition(const QString &inputVariable) const;
  462 + BR_EXPORT TemplateList partition(const QString &inputVariable, unsigned int randomSeed = 0) const;
463 463  
464 464 BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> * valueMap=NULL,QHash<int, QVariant> * reverseLookup = NULL) const;
465 465 BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> &valueMap, QHash<int, QVariant> &reverseLookup) const;
... ...
openbr/plugins/core/crossvalidate.cpp
... ... @@ -56,8 +56,10 @@ class CrossValidateTransform : public MetaTransform
56 56 Q_OBJECT
57 57 Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false)
58 58 Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
  59 + Q_PROPERTY(unsigned int randomSeed READ get_randomSeed WRITE set_randomSeed RESET reset_randomSeed STORED false)
59 60 BR_PROPERTY(QString, description, "Identity")
60 61 BR_PROPERTY(QString, inputVariable, "Label")
  62 + BR_PROPERTY(unsigned int, randomSeed, 0)
61 63  
62 64 // numPartitions copies of transform specified by description.
63 65 QList<br::Transform*> transforms;
... ... @@ -67,13 +69,14 @@ class CrossValidateTransform : public MetaTransform
67 69 // is generally incorrect behavior.
68 70 void train(const TemplateList &data)
69 71 {
70   - QList<int> partitions = data.partition(inputVariable).files().crossValidationPartitions();
71   - const int numPartitions = Common::Max(partitions)+1;
72   -
  72 + QList<int> partitions = data.partition(inputVariable, randomSeed).files().crossValidationPartitions();
  73 + const int crossValidate = Globals->crossValidate;
  74 + // Only train once based on the 0th partition if crossValidate is negative.
  75 + const int numPartitions = (crossValidate < 0) ? 1 : Common::Max(partitions)+1;
73 76 while (transforms.size() < numPartitions)
74 77 transforms.append(make(description));
75 78  
76   - if (numPartitions < 2) {
  79 + if (std::abs(crossValidate) < 2) {
77 80 transforms.first()->train(data);
78 81 return;
79 82 }
... ... @@ -101,8 +104,13 @@ class CrossValidateTransform : public MetaTransform
101 104  
102 105 void project(const TemplateList &src, TemplateList &dst) const
103 106 {
104   - TemplateList partitioned = src.partition(inputVariable);
  107 + TemplateList partitioned = src.partition(inputVariable, randomSeed);
  108 + const int crossValidate = Globals->crossValidate;
105 109  
  110 + if (crossValidate < 0) {
  111 + transforms[0]->project(partitioned, dst);
  112 + return;
  113 + }
106 114 for (int i=0; i<partitioned.size(); i++) {
107 115 int partition = partitioned[i].file.get<int>("Partition", 0);
108 116 transforms[partition]->project(partitioned, dst);
... ...
openbr/plugins/output/eval.cpp
... ... @@ -45,7 +45,7 @@ class evalOutput : public MatrixOutput
45 45  
46 46 if (data.data) {
47 47 const QString csv = QString(file.name).replace(".eval", ".csv");
48   - if ((Globals->crossValidate == 0) || (!crossValidate)) {
  48 + if ((Globals->crossValidate <= 0) || (!crossValidate)) {
49 49 Evaluate(data, targetFiles, queryFiles, csv);
50 50 } else {
51 51 QFutureSynchronizer<float> futures;
... ...