diff --git a/openbr/openbr_plugin.cpp b/openbr/openbr_plugin.cpp index 53192f5..7e19644 100644 --- a/openbr/openbr_plugin.cpp +++ b/openbr/openbr_plugin.cpp @@ -529,20 +529,22 @@ QList TemplateList::applyIndex(const QString &propName, const QHash 0) { - // Use separate RNG from Common::randN() to avoid re-seeding the global RNG. - // rng is seeded with the inputVariable hash in order to maintain partition across br runs, given the same bootStrap seed. - RandomLib::Random rng; + const int crossValidate = Globals->crossValidate; + + // Use separate RNG from Common::randN() to avoid re-seeding the global RNG. + // rng is seeded with the inputVariable hash in order to maintain partition across br runs, given the same randomSeed seed. + RandomLib::Random rng; + if (crossValidate == 1) { for (int i=0; i < partitioned.size(); i++) { const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get(inputVariable).toLatin1(), QCryptographicHash::Md5); - quint64 seed = md5.toHex().right(8).toULongLong(0, 16) + bootStrap; - rng.Reseed(seed); - float result = rng.FloatN(); + quint64 labelSeed = md5.toHex().right(8).toULongLong(0, 16) + randomSeed; + rng.Reseed(labelSeed); - // Roughly 2/3rd to training, 1/3rd to testing + // Roughly 2/3rd to training, 1/3rd to testing for the special single split case + float result = rng.FloatN(); if (result <= (2.0/3.0)) // Training partitioned[i].file.set("Partition", QString::number(1)); @@ -550,10 +552,6 @@ TemplateList TemplateList::partition(const QString &inputVariable, unsigned int partitioned[i].file.set("Partition", QString::number(0)); } } else { - const int crossValidate = Globals->crossValidate; - if (crossValidate < 2) - return *this; - for (int i=partitioned.size()-1; i>=0; i--) { // See CrossValidateTransform for description of these variables if (partitioned[i].file.getBool("duplicatePartitions")) { @@ -565,8 +563,12 @@ TemplateList TemplateList::partition(const QString &inputVariable, unsigned int } else if (partitioned[i].file.getBool("allPartitions")) { partitioned[i].file.set("Partition", -1); } else { - if (!partitioned[i].file.contains(("Partition"))) { - const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get(inputVariable).toLatin1(), QCryptographicHash::Md5); + const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get(inputVariable).toLatin1(), QCryptographicHash::Md5); + if (randomSeed) { + quint64 labelSeed = md5.toHex().right(8).toULongLong(0, 16) + randomSeed; + rng.Reseed(labelSeed); + partitioned[i].file.set("Partition", rng.Integer() % crossValidate); + } else if (!partitioned[i].file.contains(("Partition"))) { // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); } diff --git a/openbr/openbr_plugin.h b/openbr/openbr_plugin.h index 74bc050..77dfbb1 100644 --- a/openbr/openbr_plugin.h +++ b/openbr/openbr_plugin.h @@ -459,7 +459,7 @@ struct TemplateList : public QList