Commit 5aec6fb362267eccc48eece9753d2848777c0ac4

Authored by Ben Klein
2 parents 9d766e1e cb24123a

Merge pull request #369 from biometrics/bootStrap

Random test/train split in TemplateList::Partition
app/br/br.cpp
@@ -69,6 +69,7 @@ public: @@ -69,6 +69,7 @@ public:
69 69
70 bool daemon = false; 70 bool daemon = false;
71 const char *daemon_pipe = NULL; 71 const char *daemon_pipe = NULL;
  72 + bool isInt = false;
72 while (daemon || (argc > 0)) { 73 while (daemon || (argc > 0)) {
73 const char *fun; 74 const char *fun;
74 int parc; 75 int parc;
@@ -78,7 +79,12 @@ public: @@ -78,7 +79,12 @@ public:
78 79
79 fun = argv[0]; 80 fun = argv[0];
80 if (fun[0] == '-') fun++; 81 if (fun[0] == '-') fun++;
81 - parc = 0; while ((parc+1 < argc) && (argv[parc+1][0] != '-')) parc++; 82 + parc = 0;
  83 + QString(argv[parc+1]).toInt(&isInt);
  84 + while ((parc+1 < argc) && ((argv[parc+1][0] != '-') || isInt)) {
  85 + parc++;
  86 + QString(argv[parc+1]).toInt(&isInt);
  87 + }
82 parv = (const char **)&argv[1]; 88 parv = (const char **)&argv[1];
83 argc = argc - (parc+1); 89 argc = argc - (parc+1);
84 argv = &argv[parc+1]; 90 argv = &argv[parc+1];
openbr/core/bee.cpp
@@ -230,7 +230,7 @@ void makeMask(const QString &amp;targetInput, const QString &amp;queryInput, const QStri @@ -230,7 +230,7 @@ void makeMask(const QString &amp;targetInput, const QString &amp;queryInput, const QStri
230 const FileList targets = TemplateList::fromGallery(targetInput).files(); 230 const FileList targets = TemplateList::fromGallery(targetInput).files();
231 const FileList queries = (queryInput == ".") ? targets : TemplateList::fromGallery(queryInput).files(); 231 const FileList queries = (queryInput == ".") ? targets : TemplateList::fromGallery(queryInput).files();
232 const int partitions = targets.first().get<int>("crossValidate"); 232 const int partitions = targets.first().get<int>("crossValidate");
233 - if (partitions == 0) { 233 + if (partitions <= 0) {
234 writeMatrix(makeMask(targets, queries), mask, targetInput, queryInput); 234 writeMatrix(makeMask(targets, queries), mask, targetInput, queryInput);
235 } else { 235 } else {
236 if (!mask.contains("%1")) qFatal("Mask file name missing partition number place marker (%%1)"); 236 if (!mask.contains("%1")) qFatal("Mask file name missing partition number place marker (%%1)");
@@ -246,7 +246,7 @@ void makePairwiseMask(const QString &amp;targetInput, const QString &amp;queryInput, con @@ -246,7 +246,7 @@ void makePairwiseMask(const QString &amp;targetInput, const QString &amp;queryInput, con
246 const FileList targets = TemplateList::fromGallery(targetInput).files(); 246 const FileList targets = TemplateList::fromGallery(targetInput).files();
247 const FileList queries = (queryInput == ".") ? targets : TemplateList::fromGallery(queryInput).files(); 247 const FileList queries = (queryInput == ".") ? targets : TemplateList::fromGallery(queryInput).files();
248 const int partitions = targets.first().get<int>("crossValidate"); 248 const int partitions = targets.first().get<int>("crossValidate");
249 - if (partitions == 0) { 249 + if (partitions <= 0) {
250 writeMatrix(makePairwiseMask(targets, queries), mask, targetInput, queryInput); 250 writeMatrix(makePairwiseMask(targets, queries), mask, targetInput, queryInput);
251 } else { 251 } else {
252 if (!mask.contains("%1")) qFatal("Mask file name missing partition number place marker (%%1)"); 252 if (!mask.contains("%1")) qFatal("Mask file name missing partition number place marker (%%1)");
openbr/core/core.cpp
@@ -69,7 +69,7 @@ struct AlgorithmCore @@ -69,7 +69,7 @@ struct AlgorithmCore
69 QScopedPointer<Transform> trainingWrapper(br::wrapTransform(transform.data(), "Stream(readMode=DistributeFrames)")); 69 QScopedPointer<Transform> trainingWrapper(br::wrapTransform(transform.data(), "Stream(readMode=DistributeFrames)"));
70 TemplateList data(TemplateList::fromGallery(input)); 70 TemplateList data(TemplateList::fromGallery(input));
71 71
72 - if (Globals->crossValidate > 1) 72 + if (abs(Globals->crossValidate) > 1)
73 for (int i=data.size()-1; i>=0; i--) 73 for (int i=data.size()-1; i>=0; i--)
74 if (data[i].file.get<bool>("allPartitions",false) || data[i].file.get<bool>("duplicatePartitions",false)) 74 if (data[i].file.get<bool>("allPartitions",false) || data[i].file.get<bool>("duplicatePartitions",false))
75 data.removeAt(i); 75 data.removeAt(i);
openbr/openbr_plugin.cpp
@@ -28,6 +28,7 @@ @@ -28,6 +28,7 @@
28 #include <QtConcurrentRun> 28 #include <QtConcurrentRun>
29 #include <algorithm> 29 #include <algorithm>
30 #include <iostream> 30 #include <iostream>
  31 +#include <RandomLib/Random.hpp>
31 32
32 #ifndef BR_EMBEDDED 33 #ifndef BR_EMBEDDED
33 #include <QApplication> 34 #include <QApplication>
@@ -433,7 +434,7 @@ TemplateList TemplateList::fromGallery(const br::File &amp;gallery) @@ -433,7 +434,7 @@ TemplateList TemplateList::fromGallery(const br::File &amp;gallery)
433 if (gallery.getBool("reduce")) 434 if (gallery.getBool("reduce"))
434 newTemplates = newTemplates.reduced(); 435 newTemplates = newTemplates.reduced();
435 436
436 - if (Globals->crossValidate > 1) 437 + if (abs(Globals->crossValidate) > 1)
437 newTemplates = newTemplates.partition("Label"); 438 newTemplates = newTemplates.partition("Label");
438 439
439 if (!templates.isEmpty() && gallery.get<bool>("merge", false)) { 440 if (!templates.isEmpty() && gallery.get<bool>("merge", false)) {
@@ -528,12 +529,16 @@ QList&lt;int&gt; TemplateList::applyIndex(const QString &amp;propName, const QHash&lt;QString @@ -528,12 +529,16 @@ QList&lt;int&gt; TemplateList::applyIndex(const QString &amp;propName, const QHash&lt;QString
528 return result; 529 return result;
529 } 530 }
530 531
531 -TemplateList TemplateList::partition(const QString &inputVariable) const 532 +TemplateList TemplateList::partition(const QString &inputVariable, unsigned int randomSeed) const
532 { 533 {
533 - const int crossValidate = Globals->crossValidate; 534 + const int crossValidate = std::abs(Globals->crossValidate);
534 if (crossValidate < 2) 535 if (crossValidate < 2)
535 return *this; 536 return *this;
536 537
  538 + // Use separate RNG from Common::randN() to avoid re-seeding the global RNG.
  539 + // rng is seeded with the inputVariable hash in order to maintain partition across br runs, given the same randomSeed seed.
  540 + RandomLib::Random rng;
  541 +
537 TemplateList partitioned = *this; 542 TemplateList partitioned = *this;
538 for (int i=partitioned.size()-1; i>=0; i--) { 543 for (int i=partitioned.size()-1; i>=0; i--) {
539 // See CrossValidateTransform for description of these variables 544 // See CrossValidateTransform for description of these variables
@@ -546,8 +551,12 @@ TemplateList TemplateList::partition(const QString &amp;inputVariable) const @@ -546,8 +551,12 @@ TemplateList TemplateList::partition(const QString &amp;inputVariable) const
546 } else if (partitioned[i].file.getBool("allPartitions")) { 551 } else if (partitioned[i].file.getBool("allPartitions")) {
547 partitioned[i].file.set("Partition", -1); 552 partitioned[i].file.set("Partition", -1);
548 } else { 553 } else {
549 - if (!partitioned[i].file.contains(("Partition"))) {  
550 - const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5); 554 + const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5);
  555 + if (randomSeed) {
  556 + quint64 labelSeed = md5.toHex().right(8).toULongLong(0, 16) + randomSeed;
  557 + rng.Reseed(labelSeed);
  558 + partitioned[i].file.set("Partition", rng.Integer() % crossValidate);
  559 + } else if (!partitioned[i].file.contains(("Partition"))) {
551 // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow 560 // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow
552 partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); 561 partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate);
553 } 562 }
openbr/openbr_plugin.h
@@ -459,7 +459,7 @@ struct TemplateList : public QList&lt;Template&gt; @@ -459,7 +459,7 @@ struct TemplateList : public QList&lt;Template&gt;
459 BR_EXPORT static TemplateList relabel(const TemplateList &tl, const QString &propName, bool preserveIntegers); 459 BR_EXPORT static TemplateList relabel(const TemplateList &tl, const QString &propName, bool preserveIntegers);
460 460
461 /*!< \brief Assign templates to folds partitions. */ 461 /*!< \brief Assign templates to folds partitions. */
462 - BR_EXPORT TemplateList partition(const QString &inputVariable) const; 462 + BR_EXPORT TemplateList partition(const QString &inputVariable, unsigned int randomSeed = 0) const;
463 463
464 BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> * valueMap=NULL,QHash<int, QVariant> * reverseLookup = NULL) const; 464 BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> * valueMap=NULL,QHash<int, QVariant> * reverseLookup = NULL) const;
465 BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> &valueMap, QHash<int, QVariant> &reverseLookup) const; 465 BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> &valueMap, QHash<int, QVariant> &reverseLookup) const;
openbr/plugins/core/crossvalidate.cpp
@@ -56,8 +56,10 @@ class CrossValidateTransform : public MetaTransform @@ -56,8 +56,10 @@ class CrossValidateTransform : public MetaTransform
56 Q_OBJECT 56 Q_OBJECT
57 Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false) 57 Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false)
58 Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) 58 Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
  59 + Q_PROPERTY(unsigned int randomSeed READ get_randomSeed WRITE set_randomSeed RESET reset_randomSeed STORED false)
59 BR_PROPERTY(QString, description, "Identity") 60 BR_PROPERTY(QString, description, "Identity")
60 BR_PROPERTY(QString, inputVariable, "Label") 61 BR_PROPERTY(QString, inputVariable, "Label")
  62 + BR_PROPERTY(unsigned int, randomSeed, 0)
61 63
62 // numPartitions copies of transform specified by description. 64 // numPartitions copies of transform specified by description.
63 QList<br::Transform*> transforms; 65 QList<br::Transform*> transforms;
@@ -67,13 +69,14 @@ class CrossValidateTransform : public MetaTransform @@ -67,13 +69,14 @@ class CrossValidateTransform : public MetaTransform
67 // is generally incorrect behavior. 69 // is generally incorrect behavior.
68 void train(const TemplateList &data) 70 void train(const TemplateList &data)
69 { 71 {
70 - QList<int> partitions = data.partition(inputVariable).files().crossValidationPartitions();  
71 - const int numPartitions = Common::Max(partitions)+1;  
72 - 72 + QList<int> partitions = data.partition(inputVariable, randomSeed).files().crossValidationPartitions();
  73 + const int crossValidate = Globals->crossValidate;
  74 + // Only train once based on the 0th partition if crossValidate is negative.
  75 + const int numPartitions = (crossValidate < 0) ? 1 : Common::Max(partitions)+1;
73 while (transforms.size() < numPartitions) 76 while (transforms.size() < numPartitions)
74 transforms.append(make(description)); 77 transforms.append(make(description));
75 78
76 - if (numPartitions < 2) { 79 + if (std::abs(crossValidate) < 2) {
77 transforms.first()->train(data); 80 transforms.first()->train(data);
78 return; 81 return;
79 } 82 }
@@ -101,8 +104,13 @@ class CrossValidateTransform : public MetaTransform @@ -101,8 +104,13 @@ class CrossValidateTransform : public MetaTransform
101 104
102 void project(const TemplateList &src, TemplateList &dst) const 105 void project(const TemplateList &src, TemplateList &dst) const
103 { 106 {
104 - TemplateList partitioned = src.partition(inputVariable); 107 + TemplateList partitioned = src.partition(inputVariable, randomSeed);
  108 + const int crossValidate = Globals->crossValidate;
105 109
  110 + if (crossValidate < 0) {
  111 + transforms[0]->project(partitioned, dst);
  112 + return;
  113 + }
106 for (int i=0; i<partitioned.size(); i++) { 114 for (int i=0; i<partitioned.size(); i++) {
107 int partition = partitioned[i].file.get<int>("Partition", 0); 115 int partition = partitioned[i].file.get<int>("Partition", 0);
108 transforms[partition]->project(partitioned, dst); 116 transforms[partition]->project(partitioned, dst);
openbr/plugins/output/eval.cpp
@@ -45,7 +45,7 @@ class evalOutput : public MatrixOutput @@ -45,7 +45,7 @@ class evalOutput : public MatrixOutput
45 45
46 if (data.data) { 46 if (data.data) {
47 const QString csv = QString(file.name).replace(".eval", ".csv"); 47 const QString csv = QString(file.name).replace(".eval", ".csv");
48 - if ((Globals->crossValidate == 0) || (!crossValidate)) { 48 + if ((Globals->crossValidate <= 0) || (!crossValidate)) {
49 Evaluate(data, targetFiles, queryFiles, csv); 49 Evaluate(data, targetFiles, queryFiles, csv);
50 } else { 50 } else {
51 QFutureSynchronizer<float> futures; 51 QFutureSynchronizer<float> futures;