Commit 9102a476b9dc27b4776d09b0b282fd1a4a2788d7

Authored by bhklein
1 parent 89556ca3

randomize crossValidation partitions

openbr/openbr_plugin.cpp
@@ -529,20 +529,22 @@ QList<int> TemplateList::applyIndex(const QString &propName, const QHash<QString @@ -529,20 +529,22 @@ QList<int> TemplateList::applyIndex(const QString &propName, const QHash<QString
529 return result; 529 return result;
530 } 530 }
531 531
532 -TemplateList TemplateList::partition(const QString &inputVariable, unsigned int bootStrap) const 532 +TemplateList TemplateList::partition(const QString &inputVariable, unsigned int randomSeed) const
533 { 533 {
534 TemplateList partitioned = *this; 534 TemplateList partitioned = *this;
535 - if (bootStrap > 0) {  
536 - // Use separate RNG from Common::randN() to avoid re-seeding the global RNG.  
537 - // rng is seeded with the inputVariable hash in order to maintain partition across br runs, given the same bootStrap seed.  
538 - RandomLib::Random rng; 535 + const int crossValidate = Globals->crossValidate;
  536 +
  537 + // Use separate RNG from Common::randN() to avoid re-seeding the global RNG.
  538 + // rng is seeded with the inputVariable hash in order to maintain partition across br runs, given the same randomSeed seed.
  539 + RandomLib::Random rng;
  540 + if (crossValidate == 1) {
539 for (int i=0; i < partitioned.size(); i++) { 541 for (int i=0; i < partitioned.size(); i++) {
540 const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5); 542 const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5);
541 - quint64 seed = md5.toHex().right(8).toULongLong(0, 16) + bootStrap;  
542 - rng.Reseed(seed);  
543 - float result = rng.FloatN(); 543 + quint64 labelSeed = md5.toHex().right(8).toULongLong(0, 16) + randomSeed;
  544 + rng.Reseed(labelSeed);
544 545
545 - // Roughly 2/3rd to training, 1/3rd to testing 546 + // Roughly 2/3rd to training, 1/3rd to testing for the special single split case
  547 + float result = rng.FloatN();
546 if (result <= (2.0/3.0)) 548 if (result <= (2.0/3.0))
547 // Training 549 // Training
548 partitioned[i].file.set("Partition", QString::number(1)); 550 partitioned[i].file.set("Partition", QString::number(1));
@@ -550,10 +552,6 @@ TemplateList TemplateList::partition(const QString &inputVariable, unsigned int @@ -550,10 +552,6 @@ TemplateList TemplateList::partition(const QString &inputVariable, unsigned int
550 partitioned[i].file.set("Partition", QString::number(0)); 552 partitioned[i].file.set("Partition", QString::number(0));
551 } 553 }
552 } else { 554 } else {
553 - const int crossValidate = Globals->crossValidate;  
554 - if (crossValidate < 2)  
555 - return *this;  
556 -  
557 for (int i=partitioned.size()-1; i>=0; i--) { 555 for (int i=partitioned.size()-1; i>=0; i--) {
558 // See CrossValidateTransform for description of these variables 556 // See CrossValidateTransform for description of these variables
559 if (partitioned[i].file.getBool("duplicatePartitions")) { 557 if (partitioned[i].file.getBool("duplicatePartitions")) {
@@ -565,8 +563,12 @@ TemplateList TemplateList::partition(const QString &inputVariable, unsigned int @@ -565,8 +563,12 @@ TemplateList TemplateList::partition(const QString &inputVariable, unsigned int
565 } else if (partitioned[i].file.getBool("allPartitions")) { 563 } else if (partitioned[i].file.getBool("allPartitions")) {
566 partitioned[i].file.set("Partition", -1); 564 partitioned[i].file.set("Partition", -1);
567 } else { 565 } else {
568 - if (!partitioned[i].file.contains(("Partition"))) {  
569 - const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5); 566 + const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5);
  567 + if (randomSeed) {
  568 + quint64 labelSeed = md5.toHex().right(8).toULongLong(0, 16) + randomSeed;
  569 + rng.Reseed(labelSeed);
  570 + partitioned[i].file.set("Partition", rng.Integer() % crossValidate);
  571 + } else if (!partitioned[i].file.contains(("Partition"))) {
570 // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow 572 // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow
571 partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); 573 partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate);
572 } 574 }
openbr/openbr_plugin.h
@@ -459,7 +459,7 @@ struct TemplateList : public QList<Template> @@ -459,7 +459,7 @@ struct TemplateList : public QList<Template>
459 BR_EXPORT static TemplateList relabel(const TemplateList &tl, const QString &propName, bool preserveIntegers); 459 BR_EXPORT static TemplateList relabel(const TemplateList &tl, const QString &propName, bool preserveIntegers);
460 460
461 /*!< \brief Assign templates to folds partitions. */ 461 /*!< \brief Assign templates to folds partitions. */
462 - BR_EXPORT TemplateList partition(const QString &inputVariable, unsigned int bootStrap = 0) const; 462 + BR_EXPORT TemplateList partition(const QString &inputVariable, unsigned int randomSeed = 0) const;
463 463
464 BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> * valueMap=NULL,QHash<int, QVariant> * reverseLookup = NULL) const; 464 BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> * valueMap=NULL,QHash<int, QVariant> * reverseLookup = NULL) const;
465 BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> &valueMap, QHash<int, QVariant> &reverseLookup) const; 465 BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> &valueMap, QHash<int, QVariant> &reverseLookup) const;
openbr/plugins/core/crossvalidate.cpp
@@ -56,10 +56,10 @@ class CrossValidateTransform : public MetaTransform @@ -56,10 +56,10 @@ class CrossValidateTransform : public MetaTransform
56 Q_OBJECT 56 Q_OBJECT
57 Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false) 57 Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false)
58 Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) 58 Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
59 - Q_PROPERTY(unsigned int bootStrap READ get_bootStrap WRITE set_bootStrap RESET reset_bootStrap STORED false) 59 + Q_PROPERTY(unsigned int randomSeed READ get_randomSeed WRITE set_randomSeed RESET reset_randomSeed STORED false)
60 BR_PROPERTY(QString, description, "Identity") 60 BR_PROPERTY(QString, description, "Identity")
61 BR_PROPERTY(QString, inputVariable, "Label") 61 BR_PROPERTY(QString, inputVariable, "Label")
62 - BR_PROPERTY(unsigned int, bootStrap, 0) 62 + BR_PROPERTY(unsigned int, randomSeed, 0)
63 63
64 // numPartitions copies of transform specified by description. 64 // numPartitions copies of transform specified by description.
65 QList<br::Transform*> transforms; 65 QList<br::Transform*> transforms;
@@ -69,16 +69,12 @@ class CrossValidateTransform : public MetaTransform @@ -69,16 +69,12 @@ class CrossValidateTransform : public MetaTransform
69 // is generally incorrect behavior. 69 // is generally incorrect behavior.
70 void train(const TemplateList &data) 70 void train(const TemplateList &data)
71 { 71 {
72 - QList<int> partitions = data.partition(inputVariable, bootStrap).files().crossValidationPartitions();  
73 - const int numPartitions = (bootStrap > 0) ? 1 : Common::Max(partitions)+1; 72 + QList<int> partitions = data.partition(inputVariable, randomSeed).files().crossValidationPartitions();
  73 + const int crossValidate = Globals->crossValidate;
  74 + const int numPartitions = (crossValidate == 1) ? 1 : Common::Max(partitions)+1;
74 while (transforms.size() < numPartitions) 75 while (transforms.size() < numPartitions)
75 transforms.append(make(description)); 76 transforms.append(make(description));
76 77
77 - if (numPartitions < 2 && !bootStrap) {  
78 - transforms.first()->train(data);  
79 - return;  
80 - }  
81 -  
82 QFutureSynchronizer<void> futures; 78 QFutureSynchronizer<void> futures;
83 for (int i=0; i<numPartitions; i++) { 79 for (int i=0; i<numPartitions; i++) {
84 TemplateList partitionedData = data; 80 TemplateList partitionedData = data;
@@ -102,9 +98,10 @@ class CrossValidateTransform : public MetaTransform @@ -102,9 +98,10 @@ class CrossValidateTransform : public MetaTransform
102 98
103 void project(const TemplateList &src, TemplateList &dst) const 99 void project(const TemplateList &src, TemplateList &dst) const
104 { 100 {
105 - TemplateList partitioned = src.partition(inputVariable, bootStrap); 101 + TemplateList partitioned = src.partition(inputVariable, randomSeed);
  102 + const int crossValidate = Globals->crossValidate;
106 103
107 - if (bootStrap > 0) { 104 + if (crossValidate == 1) {
108 transforms[0]->project(partitioned, dst); 105 transforms[0]->project(partitioned, dst);
109 return; 106 return;
110 } 107 }