Commit 6a4ba0bfe364387f737fd5ec11c2d7e4db953f49

Authored by bhklein
1 parent 1761523e

Randomly split data into training/testing set

openbr/openbr_plugin.cpp
@@ -28,6 +28,7 @@ @@ -28,6 +28,7 @@
28 #include <QtConcurrentRun> 28 #include <QtConcurrentRun>
29 #include <algorithm> 29 #include <algorithm>
30 #include <iostream> 30 #include <iostream>
  31 +#include <RandomLib/Random.hpp>
31 32
32 #ifndef BR_EMBEDDED 33 #ifndef BR_EMBEDDED
33 #include <QApplication> 34 #include <QApplication>
@@ -528,28 +529,47 @@ QList&lt;int&gt; TemplateList::applyIndex(const QString &amp;propName, const QHash&lt;QString @@ -528,28 +529,47 @@ QList&lt;int&gt; TemplateList::applyIndex(const QString &amp;propName, const QHash&lt;QString
528 return result; 529 return result;
529 } 530 }
530 531
531 -TemplateList TemplateList::partition(const QString &inputVariable) const 532 +TemplateList TemplateList::partition(const QString &inputVariable, int bootStrap) const
532 { 533 {
533 - const int crossValidate = Globals->crossValidate;  
534 - if (crossValidate < 2)  
535 - return *this;  
536 -  
537 TemplateList partitioned = *this; 534 TemplateList partitioned = *this;
538 - for (int i=partitioned.size()-1; i>=0; i--) {  
539 - // See CrossValidateTransform for description of these variables  
540 - if (partitioned[i].file.getBool("duplicatePartitions")) {  
541 - for (int j=crossValidate-1; j>=0; j--) {  
542 - Template duplicateTemplate = partitioned[i];  
543 - duplicateTemplate.file.set("Partition", j);  
544 - partitioned.insert(i+1, duplicateTemplate);  
545 - }  
546 - } else if (partitioned[i].file.getBool("allPartitions")) {  
547 - partitioned[i].file.set("Partition", -1);  
548 - } else {  
549 - if (!partitioned[i].file.contains(("Partition"))) {  
550 - const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5);  
551 - // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow  
552 - partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); 535 + if (bootStrap > 0) {
  536 + // Use separate RNG from Common::randN() to avoid re-seeding the global RNG.
  537 + // rng is seeded with the inputVariable hash in order to maintain partition across br runs, given the same bootStrap seed.
  538 + RandomLib::Random rng;
  539 + for (int i=0; i < partitioned.size(); i++) {
  540 + const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5);
  541 + quint64 seed = md5.toHex().right(8).toULongLong(0, 16) + bootStrap;
  542 + rng.Reseed(seed);
  543 + float result = rng.FloatN();
  544 +
  545 + // Roughly 2/3rd to training, 1/3rd to testing
  546 + if (result <= (2.0/3.0))
  547 + // Training
  548 + partitioned[i].file.set("Partition", QString::number(1));
  549 + else
  550 + partitioned[i].file.set("Partition", QString::number(0));
  551 + }
  552 + } else {
  553 + const int crossValidate = Globals->crossValidate;
  554 + if (crossValidate < 2)
  555 + return *this;
  556 +
  557 + for (int i=partitioned.size()-1; i>=0; i--) {
  558 + // See CrossValidateTransform for description of these variables
  559 + if (partitioned[i].file.getBool("duplicatePartitions")) {
  560 + for (int j=crossValidate-1; j>=0; j--) {
  561 + Template duplicateTemplate = partitioned[i];
  562 + duplicateTemplate.file.set("Partition", j);
  563 + partitioned.insert(i+1, duplicateTemplate);
  564 + }
  565 + } else if (partitioned[i].file.getBool("allPartitions")) {
  566 + partitioned[i].file.set("Partition", -1);
  567 + } else {
  568 + if (!partitioned[i].file.contains(("Partition"))) {
  569 + const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5);
  570 + // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow
  571 + partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate);
  572 + }
553 } 573 }
554 } 574 }
555 } 575 }
openbr/openbr_plugin.h
@@ -459,7 +459,7 @@ struct TemplateList : public QList&lt;Template&gt; @@ -459,7 +459,7 @@ struct TemplateList : public QList&lt;Template&gt;
459 BR_EXPORT static TemplateList relabel(const TemplateList &tl, const QString &propName, bool preserveIntegers); 459 BR_EXPORT static TemplateList relabel(const TemplateList &tl, const QString &propName, bool preserveIntegers);
460 460
461 /*!< \brief Assign templates to folds partitions. */ 461 /*!< \brief Assign templates to folds partitions. */
462 - BR_EXPORT TemplateList partition(const QString &inputVariable) const; 462 + BR_EXPORT TemplateList partition(const QString &inputVariable, int bootStrap = 0) const;
463 463
464 BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> * valueMap=NULL,QHash<int, QVariant> * reverseLookup = NULL) const; 464 BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> * valueMap=NULL,QHash<int, QVariant> * reverseLookup = NULL) const;
465 BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> &valueMap, QHash<int, QVariant> &reverseLookup) const; 465 BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> &valueMap, QHash<int, QVariant> &reverseLookup) const;
openbr/plugins/core/crossvalidate.cpp
@@ -56,8 +56,10 @@ class CrossValidateTransform : public MetaTransform @@ -56,8 +56,10 @@ class CrossValidateTransform : public MetaTransform
56 Q_OBJECT 56 Q_OBJECT
57 Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false) 57 Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false)
58 Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) 58 Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
  59 + Q_PROPERTY(int bootStrap READ get_bootStrap WRITE set_bootStrap RESET reset_bootStrap STORED false)
59 BR_PROPERTY(QString, description, "Identity") 60 BR_PROPERTY(QString, description, "Identity")
60 BR_PROPERTY(QString, inputVariable, "Label") 61 BR_PROPERTY(QString, inputVariable, "Label")
  62 + BR_PROPERTY(int, bootStrap, 0)
61 63
62 // numPartitions copies of transform specified by description. 64 // numPartitions copies of transform specified by description.
63 QList<br::Transform*> transforms; 65 QList<br::Transform*> transforms;
@@ -67,13 +69,12 @@ class CrossValidateTransform : public MetaTransform @@ -67,13 +69,12 @@ class CrossValidateTransform : public MetaTransform
67 // is generally incorrect behavior. 69 // is generally incorrect behavior.
68 void train(const TemplateList &data) 70 void train(const TemplateList &data)
69 { 71 {
70 - QList<int> partitions = data.partition(inputVariable).files().crossValidationPartitions();  
71 - const int numPartitions = Common::Max(partitions)+1;  
72 - 72 + QList<int> partitions = data.partition(inputVariable, bootStrap).files().crossValidationPartitions();
  73 + const int numPartitions = (bootStrap > 0) ? 1 : Common::Max(partitions)+1;
73 while (transforms.size() < numPartitions) 74 while (transforms.size() < numPartitions)
74 transforms.append(make(description)); 75 transforms.append(make(description));
75 76
76 - if (numPartitions < 2) { 77 + if (numPartitions < 2 && !(bootStrap > 0)) {
77 transforms.first()->train(data); 78 transforms.first()->train(data);
78 return; 79 return;
79 } 80 }
@@ -101,8 +102,12 @@ class CrossValidateTransform : public MetaTransform @@ -101,8 +102,12 @@ class CrossValidateTransform : public MetaTransform
101 102
102 void project(const TemplateList &src, TemplateList &dst) const 103 void project(const TemplateList &src, TemplateList &dst) const
103 { 104 {
104 - TemplateList partitioned = src.partition(inputVariable); 105 + TemplateList partitioned = src.partition(inputVariable, bootStrap);
105 106
  107 + if (bootStrap > 0) {
  108 + transforms[0]->project(partitioned, dst);
  109 + return;
  110 + }
106 for (int i=0; i<partitioned.size(); i++) { 111 for (int i=0; i<partitioned.size(); i++) {
107 int partition = partitioned[i].file.get<int>("Partition", 0); 112 int partition = partitioned[i].file.get<int>("Partition", 0);
108 transforms[partition]->project(partitioned, dst); 113 transforms[partition]->project(partitioned, dst);