Commit 6a4ba0bfe364387f737fd5ec11c2d7e4db953f49
1 parent
1761523e
Randomly split data into training/testing set
Showing
3 changed files
with
51 additions
and
26 deletions
openbr/openbr_plugin.cpp
| ... | ... | @@ -28,6 +28,7 @@ |
| 28 | 28 | #include <QtConcurrentRun> |
| 29 | 29 | #include <algorithm> |
| 30 | 30 | #include <iostream> |
| 31 | +#include <RandomLib/Random.hpp> | |
| 31 | 32 | |
| 32 | 33 | #ifndef BR_EMBEDDED |
| 33 | 34 | #include <QApplication> |
| ... | ... | @@ -528,28 +529,47 @@ QList<int> TemplateList::applyIndex(const QString &propName, const QHash<QString |
| 528 | 529 | return result; |
| 529 | 530 | } |
| 530 | 531 | |
| 531 | -TemplateList TemplateList::partition(const QString &inputVariable) const | |
| 532 | +TemplateList TemplateList::partition(const QString &inputVariable, int bootStrap) const | |
| 532 | 533 | { |
| 533 | - const int crossValidate = Globals->crossValidate; | |
| 534 | - if (crossValidate < 2) | |
| 535 | - return *this; | |
| 536 | - | |
| 537 | 534 | TemplateList partitioned = *this; |
| 538 | - for (int i=partitioned.size()-1; i>=0; i--) { | |
| 539 | - // See CrossValidateTransform for description of these variables | |
| 540 | - if (partitioned[i].file.getBool("duplicatePartitions")) { | |
| 541 | - for (int j=crossValidate-1; j>=0; j--) { | |
| 542 | - Template duplicateTemplate = partitioned[i]; | |
| 543 | - duplicateTemplate.file.set("Partition", j); | |
| 544 | - partitioned.insert(i+1, duplicateTemplate); | |
| 545 | - } | |
| 546 | - } else if (partitioned[i].file.getBool("allPartitions")) { | |
| 547 | - partitioned[i].file.set("Partition", -1); | |
| 548 | - } else { | |
| 549 | - if (!partitioned[i].file.contains(("Partition"))) { | |
| 550 | - const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5); | |
| 551 | - // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow | |
| 552 | - partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); | |
| 535 | + if (bootStrap > 0) { | |
| 536 | + // Use separate RNG from Common::randN() to avoid re-seeding the global RNG. | |
| 537 | + // rng is seeded with the inputVariable hash in order to maintain partition across br runs, given the same bootStrap seed. | |
| 538 | + RandomLib::Random rng; | |
| 539 | + for (int i=0; i < partitioned.size(); i++) { | |
| 540 | + const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5); | |
| 541 | + quint64 seed = md5.toHex().right(8).toULongLong(0, 16) + bootStrap; | |
| 542 | + rng.Reseed(seed); | |
| 543 | + float result = rng.FloatN(); | |
| 544 | + | |
| 545 | + // Roughly 2/3rd to training, 1/3rd to testing | |
| 546 | + if (result <= (2.0/3.0)) | |
| 547 | + // Training | |
| 548 | + partitioned[i].file.set("Partition", QString::number(1)); | |
| 549 | + else | |
| 550 | + partitioned[i].file.set("Partition", QString::number(0)); | |
| 551 | + } | |
| 552 | + } else { | |
| 553 | + const int crossValidate = Globals->crossValidate; | |
| 554 | + if (crossValidate < 2) | |
| 555 | + return *this; | |
| 556 | + | |
| 557 | + for (int i=partitioned.size()-1; i>=0; i--) { | |
| 558 | + // See CrossValidateTransform for description of these variables | |
| 559 | + if (partitioned[i].file.getBool("duplicatePartitions")) { | |
| 560 | + for (int j=crossValidate-1; j>=0; j--) { | |
| 561 | + Template duplicateTemplate = partitioned[i]; | |
| 562 | + duplicateTemplate.file.set("Partition", j); | |
| 563 | + partitioned.insert(i+1, duplicateTemplate); | |
| 564 | + } | |
| 565 | + } else if (partitioned[i].file.getBool("allPartitions")) { | |
| 566 | + partitioned[i].file.set("Partition", -1); | |
| 567 | + } else { | |
| 568 | + if (!partitioned[i].file.contains(("Partition"))) { | |
| 569 | + const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5); | |
| 570 | + // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow | |
| 571 | + partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); | |
| 572 | + } | |
| 553 | 573 | } |
| 554 | 574 | } |
| 555 | 575 | } | ... | ... |
openbr/openbr_plugin.h
| ... | ... | @@ -459,7 +459,7 @@ struct TemplateList : public QList<Template> |
| 459 | 459 | BR_EXPORT static TemplateList relabel(const TemplateList &tl, const QString &propName, bool preserveIntegers); |
| 460 | 460 | |
| 461 | 461 | /*!< \brief Assign templates to folds partitions. */ |
| 462 | - BR_EXPORT TemplateList partition(const QString &inputVariable) const; | |
| 462 | + BR_EXPORT TemplateList partition(const QString &inputVariable, int bootStrap = 0) const; | |
| 463 | 463 | |
| 464 | 464 | BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> * valueMap=NULL,QHash<int, QVariant> * reverseLookup = NULL) const; |
| 465 | 465 | BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> &valueMap, QHash<int, QVariant> &reverseLookup) const; | ... | ... |
openbr/plugins/core/crossvalidate.cpp
| ... | ... | @@ -56,8 +56,10 @@ class CrossValidateTransform : public MetaTransform |
| 56 | 56 | Q_OBJECT |
| 57 | 57 | Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false) |
| 58 | 58 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) |
| 59 | + Q_PROPERTY(int bootStrap READ get_bootStrap WRITE set_bootStrap RESET reset_bootStrap STORED false) | |
| 59 | 60 | BR_PROPERTY(QString, description, "Identity") |
| 60 | 61 | BR_PROPERTY(QString, inputVariable, "Label") |
| 62 | + BR_PROPERTY(int, bootStrap, 0) | |
| 61 | 63 | |
| 62 | 64 | // numPartitions copies of transform specified by description. |
| 63 | 65 | QList<br::Transform*> transforms; |
| ... | ... | @@ -67,13 +69,12 @@ class CrossValidateTransform : public MetaTransform |
| 67 | 69 | // is generally incorrect behavior. |
| 68 | 70 | void train(const TemplateList &data) |
| 69 | 71 | { |
| 70 | - QList<int> partitions = data.partition(inputVariable).files().crossValidationPartitions(); | |
| 71 | - const int numPartitions = Common::Max(partitions)+1; | |
| 72 | - | |
| 72 | + QList<int> partitions = data.partition(inputVariable, bootStrap).files().crossValidationPartitions(); | |
| 73 | + const int numPartitions = (bootStrap > 0) ? 1 : Common::Max(partitions)+1; | |
| 73 | 74 | while (transforms.size() < numPartitions) |
| 74 | 75 | transforms.append(make(description)); |
| 75 | 76 | |
| 76 | - if (numPartitions < 2) { | |
| 77 | + if (numPartitions < 2 && !(bootStrap > 0)) { | |
| 77 | 78 | transforms.first()->train(data); |
| 78 | 79 | return; |
| 79 | 80 | } |
| ... | ... | @@ -101,8 +102,12 @@ class CrossValidateTransform : public MetaTransform |
| 101 | 102 | |
| 102 | 103 | void project(const TemplateList &src, TemplateList &dst) const |
| 103 | 104 | { |
| 104 | - TemplateList partitioned = src.partition(inputVariable); | |
| 105 | + TemplateList partitioned = src.partition(inputVariable, bootStrap); | |
| 105 | 106 | |
| 107 | + if (bootStrap > 0) { | |
| 108 | + transforms[0]->project(partitioned, dst); | |
| 109 | + return; | |
| 110 | + } | |
| 106 | 111 | for (int i=0; i<partitioned.size(); i++) { |
| 107 | 112 | int partition = partitioned[i].file.get<int>("Partition", 0); |
| 108 | 113 | transforms[partition]->project(partitioned, dst); | ... | ... |