diff --git a/openbr/openbr_plugin.cpp b/openbr/openbr_plugin.cpp index 6159e49..214baa4 100644 --- a/openbr/openbr_plugin.cpp +++ b/openbr/openbr_plugin.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #ifndef BR_EMBEDDED #include @@ -528,28 +529,47 @@ QList TemplateList::applyIndex(const QString &propName, const QHashcrossValidate; - if (crossValidate < 2) - return *this; - TemplateList partitioned = *this; - for (int i=partitioned.size()-1; i>=0; i--) { - // See CrossValidateTransform for description of these variables - if (partitioned[i].file.getBool("duplicatePartitions")) { - for (int j=crossValidate-1; j>=0; j--) { - Template duplicateTemplate = partitioned[i]; - duplicateTemplate.file.set("Partition", j); - partitioned.insert(i+1, duplicateTemplate); - } - } else if (partitioned[i].file.getBool("allPartitions")) { - partitioned[i].file.set("Partition", -1); - } else { - if (!partitioned[i].file.contains(("Partition"))) { - const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get(inputVariable).toLatin1(), QCryptographicHash::Md5); - // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow - partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); + if (bootStrap > 0) { + // Use separate RNG from Common::randN() to avoid re-seeding the global RNG. + // rng is seeded with the inputVariable hash in order to maintain partition across br runs, given the same bootStrap seed. + RandomLib::Random rng; + for (int i=0; i < partitioned.size(); i++) { + const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get(inputVariable).toLatin1(), QCryptographicHash::Md5); + quint64 seed = md5.toHex().right(8).toULongLong(0, 16) + bootStrap; + rng.Reseed(seed); + float result = rng.FloatN(); + + // Roughly 2/3rd to training, 1/3rd to testing + if (result <= (2.0/3.0)) + // Training + partitioned[i].file.set("Partition", QString::number(1)); + else + partitioned[i].file.set("Partition", QString::number(0)); + } + } else { + const int crossValidate = Globals->crossValidate; + if (crossValidate < 2) + return *this; + + for (int i=partitioned.size()-1; i>=0; i--) { + // See CrossValidateTransform for description of these variables + if (partitioned[i].file.getBool("duplicatePartitions")) { + for (int j=crossValidate-1; j>=0; j--) { + Template duplicateTemplate = partitioned[i]; + duplicateTemplate.file.set("Partition", j); + partitioned.insert(i+1, duplicateTemplate); + } + } else if (partitioned[i].file.getBool("allPartitions")) { + partitioned[i].file.set("Partition", -1); + } else { + if (!partitioned[i].file.contains(("Partition"))) { + const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get(inputVariable).toLatin1(), QCryptographicHash::Md5); + // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow + partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); + } } } } diff --git a/openbr/openbr_plugin.h b/openbr/openbr_plugin.h index d8bc9b2..a7424a2 100644 --- a/openbr/openbr_plugin.h +++ b/openbr/openbr_plugin.h @@ -459,7 +459,7 @@ struct TemplateList : public QList