Commit 6a4ba0bfe364387f737fd5ec11c2d7e4db953f49

Authored by bhklein
1 parent 1761523e

Randomly split data into training/testing set

openbr/openbr_plugin.cpp
... ... @@ -28,6 +28,7 @@
28 28 #include <QtConcurrentRun>
29 29 #include <algorithm>
30 30 #include <iostream>
  31 +#include <RandomLib/Random.hpp>
31 32  
32 33 #ifndef BR_EMBEDDED
33 34 #include <QApplication>
... ... @@ -528,28 +529,47 @@ QList&lt;int&gt; TemplateList::applyIndex(const QString &amp;propName, const QHash&lt;QString
528 529 return result;
529 530 }
530 531  
531   -TemplateList TemplateList::partition(const QString &inputVariable) const
  532 +TemplateList TemplateList::partition(const QString &inputVariable, int bootStrap) const
532 533 {
533   - const int crossValidate = Globals->crossValidate;
534   - if (crossValidate < 2)
535   - return *this;
536   -
537 534 TemplateList partitioned = *this;
538   - for (int i=partitioned.size()-1; i>=0; i--) {
539   - // See CrossValidateTransform for description of these variables
540   - if (partitioned[i].file.getBool("duplicatePartitions")) {
541   - for (int j=crossValidate-1; j>=0; j--) {
542   - Template duplicateTemplate = partitioned[i];
543   - duplicateTemplate.file.set("Partition", j);
544   - partitioned.insert(i+1, duplicateTemplate);
545   - }
546   - } else if (partitioned[i].file.getBool("allPartitions")) {
547   - partitioned[i].file.set("Partition", -1);
548   - } else {
549   - if (!partitioned[i].file.contains(("Partition"))) {
550   - const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5);
551   - // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow
552   - partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate);
  535 + if (bootStrap > 0) {
  536 + // Use separate RNG from Common::randN() to avoid re-seeding the global RNG.
  537 + // rng is seeded with the inputVariable hash in order to maintain partition across br runs, given the same bootStrap seed.
  538 + RandomLib::Random rng;
  539 + for (int i=0; i < partitioned.size(); i++) {
  540 + const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5);
  541 + quint64 seed = md5.toHex().right(8).toULongLong(0, 16) + bootStrap;
  542 + rng.Reseed(seed);
  543 + float result = rng.FloatN();
  544 +
  545 + // Roughly 2/3rd to training, 1/3rd to testing
  546 + if (result <= (2.0/3.0))
  547 + // Training
  548 + partitioned[i].file.set("Partition", QString::number(1));
  549 + else
  550 + partitioned[i].file.set("Partition", QString::number(0));
  551 + }
  552 + } else {
  553 + const int crossValidate = Globals->crossValidate;
  554 + if (crossValidate < 2)
  555 + return *this;
  556 +
  557 + for (int i=partitioned.size()-1; i>=0; i--) {
  558 + // See CrossValidateTransform for description of these variables
  559 + if (partitioned[i].file.getBool("duplicatePartitions")) {
  560 + for (int j=crossValidate-1; j>=0; j--) {
  561 + Template duplicateTemplate = partitioned[i];
  562 + duplicateTemplate.file.set("Partition", j);
  563 + partitioned.insert(i+1, duplicateTemplate);
  564 + }
  565 + } else if (partitioned[i].file.getBool("allPartitions")) {
  566 + partitioned[i].file.set("Partition", -1);
  567 + } else {
  568 + if (!partitioned[i].file.contains(("Partition"))) {
  569 + const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5);
  570 + // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow
  571 + partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate);
  572 + }
553 573 }
554 574 }
555 575 }
... ...
openbr/openbr_plugin.h
... ... @@ -459,7 +459,7 @@ struct TemplateList : public QList&lt;Template&gt;
459 459 BR_EXPORT static TemplateList relabel(const TemplateList &tl, const QString &propName, bool preserveIntegers);
460 460  
461 461 /*!< \brief Assign templates to folds partitions. */
462   - BR_EXPORT TemplateList partition(const QString &inputVariable) const;
  462 + BR_EXPORT TemplateList partition(const QString &inputVariable, int bootStrap = 0) const;
463 463  
464 464 BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> * valueMap=NULL,QHash<int, QVariant> * reverseLookup = NULL) const;
465 465 BR_EXPORT QList<int> indexProperty(const QString &propName, QHash<QString, int> &valueMap, QHash<int, QVariant> &reverseLookup) const;
... ...
openbr/plugins/core/crossvalidate.cpp
... ... @@ -56,8 +56,10 @@ class CrossValidateTransform : public MetaTransform
56 56 Q_OBJECT
57 57 Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false)
58 58 Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
  59 + Q_PROPERTY(int bootStrap READ get_bootStrap WRITE set_bootStrap RESET reset_bootStrap STORED false)
59 60 BR_PROPERTY(QString, description, "Identity")
60 61 BR_PROPERTY(QString, inputVariable, "Label")
  62 + BR_PROPERTY(int, bootStrap, 0)
61 63  
62 64 // numPartitions copies of transform specified by description.
63 65 QList<br::Transform*> transforms;
... ... @@ -67,13 +69,12 @@ class CrossValidateTransform : public MetaTransform
67 69 // is generally incorrect behavior.
68 70 void train(const TemplateList &data)
69 71 {
70   - QList<int> partitions = data.partition(inputVariable).files().crossValidationPartitions();
71   - const int numPartitions = Common::Max(partitions)+1;
72   -
  72 + QList<int> partitions = data.partition(inputVariable, bootStrap).files().crossValidationPartitions();
  73 + const int numPartitions = (bootStrap > 0) ? 1 : Common::Max(partitions)+1;
73 74 while (transforms.size() < numPartitions)
74 75 transforms.append(make(description));
75 76  
76   - if (numPartitions < 2) {
  77 + if (numPartitions < 2 && !(bootStrap > 0)) {
77 78 transforms.first()->train(data);
78 79 return;
79 80 }
... ... @@ -101,8 +102,12 @@ class CrossValidateTransform : public MetaTransform
101 102  
102 103 void project(const TemplateList &src, TemplateList &dst) const
103 104 {
104   - TemplateList partitioned = src.partition(inputVariable);
  105 + TemplateList partitioned = src.partition(inputVariable, bootStrap);
105 106  
  107 + if (bootStrap > 0) {
  108 + transforms[0]->project(partitioned, dst);
  109 + return;
  110 + }
106 111 for (int i=0; i<partitioned.size(); i++) {
107 112 int partition = partitioned[i].file.get<int>("Partition", 0);
108 113 transforms[partition]->project(partitioned, dst);
... ...