Commit 6a4ba0bfe364387f737fd5ec11c2d7e4db953f49
1 parent: 1761523e
Randomly split data into training/testing set
Showing 3 changed files with 51 additions and 26 deletions
openbr/openbr_plugin.cpp
| @@ -28,6 +28,7 @@ | @@ -28,6 +28,7 @@ | ||
| 28 | #include <QtConcurrentRun> | 28 | #include <QtConcurrentRun> |
| 29 | #include <algorithm> | 29 | #include <algorithm> |
| 30 | #include <iostream> | 30 | #include <iostream> |
| 31 | +#include <RandomLib/Random.hpp> | ||
| 31 | 32 | ||
| 32 | #ifndef BR_EMBEDDED | 33 | #ifndef BR_EMBEDDED |
| 33 | #include <QApplication> | 34 | #include <QApplication> |
| @@ -528,28 +529,47 @@ QList<int> TemplateList::applyIndex(const QString &propName, const QHash<QString | @@ -528,28 +529,47 @@ QList<int> TemplateList::applyIndex(const QString &propName, const QHash<QString | ||
| 528 | return result; | 529 | return result; |
| 529 | } | 530 | } |
| 530 | 531 | ||
| 531 | -TemplateList TemplateList::partition(const QString &inputVariable) const | 532 | +TemplateList TemplateList::partition(const QString &inputVariable, int bootStrap) const |
| 532 | { | 533 | { |
| 533 | - const int crossValidate = Globals->crossValidate; | ||
| 534 | - if (crossValidate < 2) | ||
| 535 | - return *this; | ||
| 536 | - | ||
| 537 | TemplateList partitioned = *this; | 534 | TemplateList partitioned = *this; |
| 538 | - for (int i=partitioned.size()-1; i>=0; i--) { | ||
| 539 | - // See CrossValidateTransform for description of these variables | ||
| 540 | - if (partitioned[i].file.getBool("duplicatePartitions")) { | ||
| 541 | - for (int j=crossValidate-1; j>=0; j--) { | ||
| 542 | - Template duplicateTemplate = partitioned[i]; | ||
| 543 | - duplicateTemplate.file.set("Partition", j); | ||
| 544 | - partitioned.insert(i+1, duplicateTemplate); | ||
| 545 | - } | ||
| 546 | - } else if (partitioned[i].file.getBool("allPartitions")) { | ||
| 547 | - partitioned[i].file.set("Partition", -1); | ||
| 548 | - } else { | ||
| 549 | - if (!partitioned[i].file.contains(("Partition"))) { | ||
| 550 | - const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5); | ||
| 551 | - // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow | ||
| 552 | - partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); | 535 | + if (bootStrap > 0) { |
| 536 | + // Use separate RNG from Common::randN() to avoid re-seeding the global RNG. | ||
| 537 | + // rng is seeded with the inputVariable hash in order to maintain partition across br runs, given the same bootStrap seed. | ||
| 538 | + RandomLib::Random rng; | ||
| 539 | + for (int i=0; i < partitioned.size(); i++) { | ||
| 540 | + const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5); | ||
| 541 | + quint64 seed = md5.toHex().right(8).toULongLong(0, 16) + bootStrap; | ||
| 542 | + rng.Reseed(seed); | ||
| 543 | + float result = rng.FloatN(); | ||
| 544 | + | ||
| 545 | + // Roughly 2/3rd to training, 1/3rd to testing | ||
| 546 | + if (result <= (2.0/3.0)) | ||
| 547 | + // Training | ||
| 548 | + partitioned[i].file.set("Partition", QString::number(1)); | ||
| 549 | + else | ||
| 550 | + partitioned[i].file.set("Partition", QString::number(0)); | ||
| 551 | + } | ||
| 552 | + } else { | ||
| 553 | + const int crossValidate = Globals->crossValidate; | ||
| 554 | + if (crossValidate < 2) | ||
| 555 | + return *this; | ||
| 556 | + | ||
| 557 | + for (int i=partitioned.size()-1; i>=0; i--) { | ||
| 558 | + // See CrossValidateTransform for description of these variables | ||
| 559 | + if (partitioned[i].file.getBool("duplicatePartitions")) { | ||
| 560 | + for (int j=crossValidate-1; j>=0; j--) { | ||
| 561 | + Template duplicateTemplate = partitioned[i]; | ||
| 562 | + duplicateTemplate.file.set("Partition", j); | ||
| 563 | + partitioned.insert(i+1, duplicateTemplate); | ||
| 564 | + } | ||
| 565 | + } else if (partitioned[i].file.getBool("allPartitions")) { | ||
| 566 | + partitioned[i].file.set("Partition", -1); | ||
| 567 | + } else { | ||
| 568 | + if (!partitioned[i].file.contains(("Partition"))) { | ||
| 569 | + const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5); | ||
| 570 | + // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow | ||
| 571 | + partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); | ||
| 572 | + } | ||
| 553 | } | 573 | } |
| 554 | } | 574 | } |
| 555 | } | 575 | } |
openbr/openbr_plugin.h
| @@ -459,7 +459,7 @@ struct TemplateList : public QList<Template> | @@ -459,7 +459,7 @@ struct TemplateList : public QList<Template> | ||
| 459 | BR_EXPORT static TemplateList relabel(const TemplateList &tl, const QString &propName, bool preserveIntegers); | 459 | BR_EXPORT static TemplateList relabel(const TemplateList &tl, const QString &propName, bool preserveIntegers); |
| 460 | 460 | ||
| 461 | /*!< \brief Assign templates to fold partitions. */ | 461 | /*!< \brief Assign templates to fold partitions. */ |
| 462 | - BR_EXPORT TemplateList partition(const QString &inputVariable) const; | 462 | + BR_EXPORT TemplateList partition(const QString &inputVariable, int bootStrap = 0) const; |
| 463 | 463 | ||
| 464 | BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> * valueMap=NULL,QHash<int, QVariant> * reverseLookup = NULL) const; | 464 | BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> * valueMap=NULL,QHash<int, QVariant> * reverseLookup = NULL) const; |
| 465 | BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> &valueMap, QHash<int, QVariant> &reverseLookup) const; | 465 | BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> &valueMap, QHash<int, QVariant> &reverseLookup) const; |
openbr/plugins/core/crossvalidate.cpp
| @@ -56,8 +56,10 @@ class CrossValidateTransform : public MetaTransform | @@ -56,8 +56,10 @@ class CrossValidateTransform : public MetaTransform | ||
| 56 | Q_OBJECT | 56 | Q_OBJECT |
| 57 | Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false) | 57 | Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false) |
| 58 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | 58 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) |
| 59 | + Q_PROPERTY(int bootStrap READ get_bootStrap WRITE set_bootStrap RESET reset_bootStrap STORED false) | ||
| 59 | BR_PROPERTY(QString, description, "Identity") | 60 | BR_PROPERTY(QString, description, "Identity") |
| 60 | BR_PROPERTY(QString, inputVariable, "Label") | 61 | BR_PROPERTY(QString, inputVariable, "Label") |
| 62 | + BR_PROPERTY(int, bootStrap, 0) | ||
| 61 | 63 | ||
| 62 | // numPartitions copies of transform specified by description. | 64 | // numPartitions copies of transform specified by description. |
| 63 | QList<br::Transform*> transforms; | 65 | QList<br::Transform*> transforms; |
| @@ -67,13 +69,12 @@ class CrossValidateTransform : public MetaTransform | @@ -67,13 +69,12 @@ class CrossValidateTransform : public MetaTransform | ||
| 67 | // is generally incorrect behavior. | 69 | // is generally incorrect behavior. |
| 68 | void train(const TemplateList &data) | 70 | void train(const TemplateList &data) |
| 69 | { | 71 | { |
| 70 | - QList<int> partitions = data.partition(inputVariable).files().crossValidationPartitions(); | ||
| 71 | - const int numPartitions = Common::Max(partitions)+1; | ||
| 72 | - | 72 | + QList<int> partitions = data.partition(inputVariable, bootStrap).files().crossValidationPartitions(); |
| 73 | + const int numPartitions = (bootStrap > 0) ? 1 : Common::Max(partitions)+1; | ||
| 73 | while (transforms.size() < numPartitions) | 74 | while (transforms.size() < numPartitions) |
| 74 | transforms.append(make(description)); | 75 | transforms.append(make(description)); |
| 75 | 76 | ||
| 76 | - if (numPartitions < 2) { | 77 | + if (numPartitions < 2 && !(bootStrap > 0)) { |
| 77 | transforms.first()->train(data); | 78 | transforms.first()->train(data); |
| 78 | return; | 79 | return; |
| 79 | } | 80 | } |
| @@ -101,8 +102,12 @@ class CrossValidateTransform : public MetaTransform | @@ -101,8 +102,12 @@ class CrossValidateTransform : public MetaTransform | ||
| 101 | 102 | ||
| 102 | void project(const TemplateList &src, TemplateList &dst) const | 103 | void project(const TemplateList &src, TemplateList &dst) const |
| 103 | { | 104 | { |
| 104 | - TemplateList partitioned = src.partition(inputVariable); | 105 | + TemplateList partitioned = src.partition(inputVariable, bootStrap); |
| 105 | 106 | ||
| 107 | + if (bootStrap > 0) { | ||
| 108 | + transforms[0]->project(partitioned, dst); | ||
| 109 | + return; | ||
| 110 | + } | ||
| 106 | for (int i=0; i<partitioned.size(); i++) { | 111 | for (int i=0; i<partitioned.size(); i++) { |
| 107 | int partition = partitioned[i].file.get<int>("Partition", 0); | 112 | int partition = partitioned[i].file.get<int>("Partition", 0); |
| 108 | transforms[partition]->project(partitioned, dst); | 113 | transforms[partition]->project(partitioned, dst); |