Commit 9102a476b9dc27b4776d09b0b282fd1a4a2788d7

Authored by bhklein
1 parent 89556ca3

randomize crossValidation partitions

openbr/openbr_plugin.cpp
... ... @@ -529,20 +529,22 @@ QList<int> TemplateList::applyIndex(const QString &propName, const QHash<QString
529 529 return result;
530 530 }
531 531  
532   -TemplateList TemplateList::partition(const QString &inputVariable, unsigned int bootStrap) const
  532 +TemplateList TemplateList::partition(const QString &inputVariable, unsigned int randomSeed) const
533 533 {
534 534 TemplateList partitioned = *this;
535   - if (bootStrap > 0) {
536   - // Use separate RNG from Common::randN() to avoid re-seeding the global RNG.
537   - // rng is seeded with the inputVariable hash in order to maintain partition across br runs, given the same bootStrap seed.
538   - RandomLib::Random rng;
  535 + const int crossValidate = Globals->crossValidate;
  536 +
  537 + // Use separate RNG from Common::randN() to avoid re-seeding the global RNG.
  538 + // rng is seeded with the inputVariable hash in order to maintain partition across br runs, given the same randomSeed seed.
  539 + RandomLib::Random rng;
  540 + if (crossValidate == 1) {
539 541 for (int i=0; i < partitioned.size(); i++) {
540 542 const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5);
541   - quint64 seed = md5.toHex().right(8).toULongLong(0, 16) + bootStrap;
542   - rng.Reseed(seed);
543   - float result = rng.FloatN();
  543 + quint64 labelSeed = md5.toHex().right(8).toULongLong(0, 16) + randomSeed;
  544 + rng.Reseed(labelSeed);
544 545  
545   - // Roughly 2/3rd to training, 1/3rd to testing
  546 + // Roughly 2/3rd to training, 1/3rd to testing for the special single split case
  547 + float result = rng.FloatN();
546 548 if (result <= (2.0/3.0))
547 549 // Training
548 550 partitioned[i].file.set("Partition", QString::number(1));
... ... @@ -550,10 +552,6 @@ TemplateList TemplateList::partition(const QString &amp;inputVariable, unsigned int
550 552 partitioned[i].file.set("Partition", QString::number(0));
551 553 }
552 554 } else {
553   - const int crossValidate = Globals->crossValidate;
554   - if (crossValidate < 2)
555   - return *this;
556   -
557 555 for (int i=partitioned.size()-1; i>=0; i--) {
558 556 // See CrossValidateTransform for description of these variables
559 557 if (partitioned[i].file.getBool("duplicatePartitions")) {
... ... @@ -565,8 +563,12 @@ TemplateList TemplateList::partition(const QString &amp;inputVariable, unsigned int
565 563 } else if (partitioned[i].file.getBool("allPartitions")) {
566 564 partitioned[i].file.set("Partition", -1);
567 565 } else {
568   - if (!partitioned[i].file.contains(("Partition"))) {
569   - const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5);
  566 + const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5);
  567 + if (randomSeed) {
  568 + quint64 labelSeed = md5.toHex().right(8).toULongLong(0, 16) + randomSeed;
  569 + rng.Reseed(labelSeed);
  570 + partitioned[i].file.set("Partition", rng.Integer() % crossValidate);
  571 + } else if (!partitioned[i].file.contains(("Partition"))) {
570 572 // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow
571 573 partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate);
572 574 }
... ...
openbr/openbr_plugin.h
... ... @@ -459,7 +459,7 @@ struct TemplateList : public QList&lt;Template&gt;
459 459 BR_EXPORT static TemplateList relabel(const TemplateList &tl, const QString &propName, bool preserveIntegers);
460 460  
461 461 /*!< \brief Assign templates to folds partitions. */
462   - BR_EXPORT TemplateList partition(const QString &inputVariable, unsigned int bootStrap = 0) const;
  462 + BR_EXPORT TemplateList partition(const QString &inputVariable, unsigned int randomSeed = 0) const;
463 463  
464 464 BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> * valueMap=NULL,QHash<int, QVariant> * reverseLookup = NULL) const;
465 465 BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> &valueMap, QHash<int, QVariant> &reverseLookup) const;
... ...
openbr/plugins/core/crossvalidate.cpp
... ... @@ -56,10 +56,10 @@ class CrossValidateTransform : public MetaTransform
56 56 Q_OBJECT
57 57 Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false)
58 58 Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
59   - Q_PROPERTY(unsigned int bootStrap READ get_bootStrap WRITE set_bootStrap RESET reset_bootStrap STORED false)
  59 + Q_PROPERTY(unsigned int randomSeed READ get_randomSeed WRITE set_randomSeed RESET reset_randomSeed STORED false)
60 60 BR_PROPERTY(QString, description, "Identity")
61 61 BR_PROPERTY(QString, inputVariable, "Label")
62   - BR_PROPERTY(unsigned int, bootStrap, 0)
  62 + BR_PROPERTY(unsigned int, randomSeed, 0)
63 63  
64 64 // numPartitions copies of transform specified by description.
65 65 QList<br::Transform*> transforms;
... ... @@ -69,16 +69,12 @@ class CrossValidateTransform : public MetaTransform
69 69 // is generally incorrect behavior.
70 70 void train(const TemplateList &data)
71 71 {
72   - QList<int> partitions = data.partition(inputVariable, bootStrap).files().crossValidationPartitions();
73   - const int numPartitions = (bootStrap > 0) ? 1 : Common::Max(partitions)+1;
  72 + QList<int> partitions = data.partition(inputVariable, randomSeed).files().crossValidationPartitions();
  73 + const int crossValidate = Globals->crossValidate;
  74 + const int numPartitions = (crossValidate == 1) ? 1 : Common::Max(partitions)+1;
74 75 while (transforms.size() < numPartitions)
75 76 transforms.append(make(description));
76 77  
77   - if (numPartitions < 2 && !bootStrap) {
78   - transforms.first()->train(data);
79   - return;
80   - }
81   -
82 78 QFutureSynchronizer<void> futures;
83 79 for (int i=0; i<numPartitions; i++) {
84 80 TemplateList partitionedData = data;
... ... @@ -102,9 +98,10 @@ class CrossValidateTransform : public MetaTransform
102 98  
103 99 void project(const TemplateList &src, TemplateList &dst) const
104 100 {
105   - TemplateList partitioned = src.partition(inputVariable, bootStrap);
  101 + TemplateList partitioned = src.partition(inputVariable, randomSeed);
  102 + const int crossValidate = Globals->crossValidate;
106 103  
107   - if (bootStrap > 0) {
  104 + if (crossValidate == 1) {
108 105 transforms[0]->project(partitioned, dst);
109 106 return;
110 107 }
... ...