Commit 9102a476b9dc27b4776d09b0b282fd1a4a2788d7
1 parent: 89556ca3
randomize crossValidation partitions
Showing 3 changed files with 26 additions and 27 deletions
openbr/openbr_plugin.cpp
| ... | ... | @@ -529,20 +529,22 @@ QList<int> TemplateList::applyIndex(const QString &propName, const QHash<QString |
| 529 | 529 | return result; |
| 530 | 530 | } |
| 531 | 531 | |
| 532 | -TemplateList TemplateList::partition(const QString &inputVariable, unsigned int bootStrap) const | |
| 532 | +TemplateList TemplateList::partition(const QString &inputVariable, unsigned int randomSeed) const | |
| 533 | 533 | { |
| 534 | 534 | TemplateList partitioned = *this; |
| 535 | - if (bootStrap > 0) { | |
| 536 | - // Use separate RNG from Common::randN() to avoid re-seeding the global RNG. | |
| 537 | - // rng is seeded with the inputVariable hash in order to maintain partition across br runs, given the same bootStrap seed. | |
| 538 | - RandomLib::Random rng; | |
| 535 | + const int crossValidate = Globals->crossValidate; | |
| 536 | + | |
| 537 | + // Use separate RNG from Common::randN() to avoid re-seeding the global RNG. | |
| 538 | + // rng is seeded with the inputVariable hash in order to maintain partition across br runs, given the same randomSeed seed. | |
| 539 | + RandomLib::Random rng; | |
| 540 | + if (crossValidate == 1) { | |
| 539 | 541 | for (int i=0; i < partitioned.size(); i++) { |
| 540 | 542 | const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5); |
| 541 | - quint64 seed = md5.toHex().right(8).toULongLong(0, 16) + bootStrap; | |
| 542 | - rng.Reseed(seed); | |
| 543 | - float result = rng.FloatN(); | |
| 543 | + quint64 labelSeed = md5.toHex().right(8).toULongLong(0, 16) + randomSeed; | |
| 544 | + rng.Reseed(labelSeed); | |
| 544 | 545 | |
| 545 | - // Roughly 2/3rd to training, 1/3rd to testing | |
| 546 | + // Roughly 2/3rd to training, 1/3rd to testing for the special single split case | |
| 547 | + float result = rng.FloatN(); | |
| 546 | 548 | if (result <= (2.0/3.0)) |
| 547 | 549 | // Training |
| 548 | 550 | partitioned[i].file.set("Partition", QString::number(1)); |
| ... | ... | @@ -550,10 +552,6 @@ TemplateList TemplateList::partition(const QString &inputVariable, unsigned int |
| 550 | 552 | partitioned[i].file.set("Partition", QString::number(0)); |
| 551 | 553 | } |
| 552 | 554 | } else { |
| 553 | - const int crossValidate = Globals->crossValidate; | |
| 554 | - if (crossValidate < 2) | |
| 555 | - return *this; | |
| 556 | - | |
| 557 | 555 | for (int i=partitioned.size()-1; i>=0; i--) { |
| 558 | 556 | // See CrossValidateTransform for description of these variables |
| 559 | 557 | if (partitioned[i].file.getBool("duplicatePartitions")) { |
| ... | ... | @@ -565,8 +563,12 @@ TemplateList TemplateList::partition(const QString &inputVariable, unsigned int |
| 565 | 563 | } else if (partitioned[i].file.getBool("allPartitions")) { |
| 566 | 564 | partitioned[i].file.set("Partition", -1); |
| 567 | 565 | } else { |
| 568 | - if (!partitioned[i].file.contains(("Partition"))) { | |
| 569 | - const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5); | |
| 566 | + const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5); | |
| 567 | + if (randomSeed) { | |
| 568 | + quint64 labelSeed = md5.toHex().right(8).toULongLong(0, 16) + randomSeed; | |
| 569 | + rng.Reseed(labelSeed); | |
| 570 | + partitioned[i].file.set("Partition", rng.Integer() % crossValidate); | |
| 571 | + } else if (!partitioned[i].file.contains(("Partition"))) { | |
| 570 | 572 | // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow |
| 571 | 573 | partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); |
| 572 | 574 | } | ... | ... |
openbr/openbr_plugin.h
| ... | ... | @@ -459,7 +459,7 @@ struct TemplateList : public QList<Template> |
| 459 | 459 | BR_EXPORT static TemplateList relabel(const TemplateList &tl, const QString &propName, bool preserveIntegers); |
| 460 | 460 | |
| 461 | 461 | /*!< \brief Assign templates to folds partitions. */ |
| 462 | - BR_EXPORT TemplateList partition(const QString &inputVariable, unsigned int bootStrap = 0) const; | |
| 462 | + BR_EXPORT TemplateList partition(const QString &inputVariable, unsigned int randomSeed = 0) const; | |
| 463 | 463 | |
| 464 | 464 | BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> * valueMap=NULL,QHash<int, QVariant> * reverseLookup = NULL) const; |
| 465 | 465 | BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> &valueMap, QHash<int, QVariant> &reverseLookup) const; | ... | ... |
openbr/plugins/core/crossvalidate.cpp
| ... | ... | @@ -56,10 +56,10 @@ class CrossValidateTransform : public MetaTransform |
| 56 | 56 | Q_OBJECT |
| 57 | 57 | Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false) |
| 58 | 58 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) |
| 59 | - Q_PROPERTY(unsigned int bootStrap READ get_bootStrap WRITE set_bootStrap RESET reset_bootStrap STORED false) | |
| 59 | + Q_PROPERTY(unsigned int randomSeed READ get_randomSeed WRITE set_randomSeed RESET reset_randomSeed STORED false) | |
| 60 | 60 | BR_PROPERTY(QString, description, "Identity") |
| 61 | 61 | BR_PROPERTY(QString, inputVariable, "Label") |
| 62 | - BR_PROPERTY(unsigned int, bootStrap, 0) | |
| 62 | + BR_PROPERTY(unsigned int, randomSeed, 0) | |
| 63 | 63 | |
| 64 | 64 | // numPartitions copies of transform specified by description. |
| 65 | 65 | QList<br::Transform*> transforms; |
| ... | ... | @@ -69,16 +69,12 @@ class CrossValidateTransform : public MetaTransform |
| 69 | 69 | // is generally incorrect behavior. |
| 70 | 70 | void train(const TemplateList &data) |
| 71 | 71 | { |
| 72 | - QList<int> partitions = data.partition(inputVariable, bootStrap).files().crossValidationPartitions(); | |
| 73 | - const int numPartitions = (bootStrap > 0) ? 1 : Common::Max(partitions)+1; | |
| 72 | + QList<int> partitions = data.partition(inputVariable, randomSeed).files().crossValidationPartitions(); | |
| 73 | + const int crossValidate = Globals->crossValidate; | |
| 74 | + const int numPartitions = (crossValidate == 1) ? 1 : Common::Max(partitions)+1; | |
| 74 | 75 | while (transforms.size() < numPartitions) |
| 75 | 76 | transforms.append(make(description)); |
| 76 | 77 | |
| 77 | - if (numPartitions < 2 && !bootStrap) { | |
| 78 | - transforms.first()->train(data); | |
| 79 | - return; | |
| 80 | - } | |
| 81 | - | |
| 82 | 78 | QFutureSynchronizer<void> futures; |
| 83 | 79 | for (int i=0; i<numPartitions; i++) { |
| 84 | 80 | TemplateList partitionedData = data; |
| ... | ... | @@ -102,9 +98,10 @@ class CrossValidateTransform : public MetaTransform |
| 102 | 98 | |
| 103 | 99 | void project(const TemplateList &src, TemplateList &dst) const |
| 104 | 100 | { |
| 105 | - TemplateList partitioned = src.partition(inputVariable, bootStrap); | |
| 101 | + TemplateList partitioned = src.partition(inputVariable, randomSeed); | |
| 102 | + const int crossValidate = Globals->crossValidate; | |
| 106 | 103 | |
| 107 | - if (bootStrap > 0) { | |
| 104 | + if (crossValidate == 1) { | |
| 108 | 105 | transforms[0]->project(partitioned, dst); |
| 109 | 106 | return; |
| 110 | 107 | } | ... | ... |