Commit 9102a476b9dc27b4776d09b0b282fd1a4a2788d7
1 parent
89556ca3
randomize crossValidation partitions
Showing
3 changed files
with
26 additions
and
27 deletions
openbr/openbr_plugin.cpp
| @@ -529,20 +529,22 @@ QList<int> TemplateList::applyIndex(const QString &propName, const QHash<QString | @@ -529,20 +529,22 @@ QList<int> TemplateList::applyIndex(const QString &propName, const QHash<QString | ||
| 529 | return result; | 529 | return result; |
| 530 | } | 530 | } |
| 531 | 531 | ||
| 532 | -TemplateList TemplateList::partition(const QString &inputVariable, unsigned int bootStrap) const | 532 | +TemplateList TemplateList::partition(const QString &inputVariable, unsigned int randomSeed) const |
| 533 | { | 533 | { |
| 534 | TemplateList partitioned = *this; | 534 | TemplateList partitioned = *this; |
| 535 | - if (bootStrap > 0) { | ||
| 536 | - // Use separate RNG from Common::randN() to avoid re-seeding the global RNG. | ||
| 537 | - // rng is seeded with the inputVariable hash in order to maintain partition across br runs, given the same bootStrap seed. | ||
| 538 | - RandomLib::Random rng; | 535 | + const int crossValidate = Globals->crossValidate; |
| 536 | + | ||
| 537 | + // Use separate RNG from Common::randN() to avoid re-seeding the global RNG. | ||
| 538 | + // rng is seeded with the inputVariable hash in order to maintain partition across br runs, given the same randomSeed seed. | ||
| 539 | + RandomLib::Random rng; | ||
| 540 | + if (crossValidate == 1) { | ||
| 539 | for (int i=0; i < partitioned.size(); i++) { | 541 | for (int i=0; i < partitioned.size(); i++) { |
| 540 | const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5); | 542 | const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5); |
| 541 | - quint64 seed = md5.toHex().right(8).toULongLong(0, 16) + bootStrap; | ||
| 542 | - rng.Reseed(seed); | ||
| 543 | - float result = rng.FloatN(); | 543 | + quint64 labelSeed = md5.toHex().right(8).toULongLong(0, 16) + randomSeed; |
| 544 | + rng.Reseed(labelSeed); | ||
| 544 | 545 | ||
| 545 | - // Roughly 2/3rd to training, 1/3rd to testing | 546 | + // Roughly 2/3rd to training, 1/3rd to testing for the special single split case |
| 547 | + float result = rng.FloatN(); | ||
| 546 | if (result <= (2.0/3.0)) | 548 | if (result <= (2.0/3.0)) |
| 547 | // Training | 549 | // Training |
| 548 | partitioned[i].file.set("Partition", QString::number(1)); | 550 | partitioned[i].file.set("Partition", QString::number(1)); |
| @@ -550,10 +552,6 @@ TemplateList TemplateList::partition(const QString &inputVariable, unsigned int | @@ -550,10 +552,6 @@ TemplateList TemplateList::partition(const QString &inputVariable, unsigned int | ||
| 550 | partitioned[i].file.set("Partition", QString::number(0)); | 552 | partitioned[i].file.set("Partition", QString::number(0)); |
| 551 | } | 553 | } |
| 552 | } else { | 554 | } else { |
| 553 | - const int crossValidate = Globals->crossValidate; | ||
| 554 | - if (crossValidate < 2) | ||
| 555 | - return *this; | ||
| 556 | - | ||
| 557 | for (int i=partitioned.size()-1; i>=0; i--) { | 555 | for (int i=partitioned.size()-1; i>=0; i--) { |
| 558 | // See CrossValidateTransform for description of these variables | 556 | // See CrossValidateTransform for description of these variables |
| 559 | if (partitioned[i].file.getBool("duplicatePartitions")) { | 557 | if (partitioned[i].file.getBool("duplicatePartitions")) { |
| @@ -565,8 +563,12 @@ TemplateList TemplateList::partition(const QString &inputVariable, unsigned int | @@ -565,8 +563,12 @@ TemplateList TemplateList::partition(const QString &inputVariable, unsigned int | ||
| 565 | } else if (partitioned[i].file.getBool("allPartitions")) { | 563 | } else if (partitioned[i].file.getBool("allPartitions")) { |
| 566 | partitioned[i].file.set("Partition", -1); | 564 | partitioned[i].file.set("Partition", -1); |
| 567 | } else { | 565 | } else { |
| 568 | - if (!partitioned[i].file.contains(("Partition"))) { | ||
| 569 | - const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5); | 566 | + const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5); |
| 567 | + if (randomSeed) { | ||
| 568 | + quint64 labelSeed = md5.toHex().right(8).toULongLong(0, 16) + randomSeed; | ||
| 569 | + rng.Reseed(labelSeed); | ||
| 570 | + partitioned[i].file.set("Partition", rng.Integer() % crossValidate); | ||
| 571 | + } else if (!partitioned[i].file.contains(("Partition"))) { | ||
| 570 | // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow | 572 | // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow |
| 571 | partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); | 573 | partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); |
| 572 | } | 574 | } |
openbr/openbr_plugin.h
| @@ -459,7 +459,7 @@ struct TemplateList : public QList<Template> | @@ -459,7 +459,7 @@ struct TemplateList : public QList<Template> | ||
| 459 | BR_EXPORT static TemplateList relabel(const TemplateList &tl, const QString &propName, bool preserveIntegers); | 459 | BR_EXPORT static TemplateList relabel(const TemplateList &tl, const QString &propName, bool preserveIntegers); |
| 460 | 460 | ||
| 461 | /*!< \brief Assign templates to folds partitions. */ | 461 | /*!< \brief Assign templates to folds partitions. */ |
| 462 | - BR_EXPORT TemplateList partition(const QString &inputVariable, unsigned int bootStrap = 0) const; | 462 | + BR_EXPORT TemplateList partition(const QString &inputVariable, unsigned int randomSeed = 0) const; |
| 463 | 463 | ||
| 464 | BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> * valueMap=NULL,QHash<int, QVariant> * reverseLookup = NULL) const; | 464 | BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> * valueMap=NULL,QHash<int, QVariant> * reverseLookup = NULL) const; |
| 465 | BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> &valueMap, QHash<int, QVariant> &reverseLookup) const; | 465 | BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> &valueMap, QHash<int, QVariant> &reverseLookup) const; |
openbr/plugins/core/crossvalidate.cpp
| @@ -56,10 +56,10 @@ class CrossValidateTransform : public MetaTransform | @@ -56,10 +56,10 @@ class CrossValidateTransform : public MetaTransform | ||
| 56 | Q_OBJECT | 56 | Q_OBJECT |
| 57 | Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false) | 57 | Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false) |
| 58 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | 58 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) |
| 59 | - Q_PROPERTY(unsigned int bootStrap READ get_bootStrap WRITE set_bootStrap RESET reset_bootStrap STORED false) | 59 | + Q_PROPERTY(unsigned int randomSeed READ get_randomSeed WRITE set_randomSeed RESET reset_randomSeed STORED false) |
| 60 | BR_PROPERTY(QString, description, "Identity") | 60 | BR_PROPERTY(QString, description, "Identity") |
| 61 | BR_PROPERTY(QString, inputVariable, "Label") | 61 | BR_PROPERTY(QString, inputVariable, "Label") |
| 62 | - BR_PROPERTY(unsigned int, bootStrap, 0) | 62 | + BR_PROPERTY(unsigned int, randomSeed, 0) |
| 63 | 63 | ||
| 64 | // numPartitions copies of transform specified by description. | 64 | // numPartitions copies of transform specified by description. |
| 65 | QList<br::Transform*> transforms; | 65 | QList<br::Transform*> transforms; |
| @@ -69,16 +69,12 @@ class CrossValidateTransform : public MetaTransform | @@ -69,16 +69,12 @@ class CrossValidateTransform : public MetaTransform | ||
| 69 | // is generally incorrect behavior. | 69 | // is generally incorrect behavior. |
| 70 | void train(const TemplateList &data) | 70 | void train(const TemplateList &data) |
| 71 | { | 71 | { |
| 72 | - QList<int> partitions = data.partition(inputVariable, bootStrap).files().crossValidationPartitions(); | ||
| 73 | - const int numPartitions = (bootStrap > 0) ? 1 : Common::Max(partitions)+1; | 72 | + QList<int> partitions = data.partition(inputVariable, randomSeed).files().crossValidationPartitions(); |
| 73 | + const int crossValidate = Globals->crossValidate; | ||
| 74 | + const int numPartitions = (crossValidate == 1) ? 1 : Common::Max(partitions)+1; | ||
| 74 | while (transforms.size() < numPartitions) | 75 | while (transforms.size() < numPartitions) |
| 75 | transforms.append(make(description)); | 76 | transforms.append(make(description)); |
| 76 | 77 | ||
| 77 | - if (numPartitions < 2 && !bootStrap) { | ||
| 78 | - transforms.first()->train(data); | ||
| 79 | - return; | ||
| 80 | - } | ||
| 81 | - | ||
| 82 | QFutureSynchronizer<void> futures; | 78 | QFutureSynchronizer<void> futures; |
| 83 | for (int i=0; i<numPartitions; i++) { | 79 | for (int i=0; i<numPartitions; i++) { |
| 84 | TemplateList partitionedData = data; | 80 | TemplateList partitionedData = data; |
| @@ -102,9 +98,10 @@ class CrossValidateTransform : public MetaTransform | @@ -102,9 +98,10 @@ class CrossValidateTransform : public MetaTransform | ||
| 102 | 98 | ||
| 103 | void project(const TemplateList &src, TemplateList &dst) const | 99 | void project(const TemplateList &src, TemplateList &dst) const |
| 104 | { | 100 | { |
| 105 | - TemplateList partitioned = src.partition(inputVariable, bootStrap); | 101 | + TemplateList partitioned = src.partition(inputVariable, randomSeed); |
| 102 | + const int crossValidate = Globals->crossValidate; | ||
| 106 | 103 | ||
| 107 | - if (bootStrap > 0) { | 104 | + if (crossValidate == 1) { |
| 108 | transforms[0]->project(partitioned, dst); | 105 | transforms[0]->project(partitioned, dst); |
| 109 | return; | 106 | return; |
| 110 | } | 107 | } |