diff --git a/app/br/br.cpp b/app/br/br.cpp index 2e2293a..75dc390 100644 --- a/app/br/br.cpp +++ b/app/br/br.cpp @@ -69,6 +69,7 @@ public: bool daemon = false; const char *daemon_pipe = NULL; + bool isInt = false; while (daemon || (argc > 0)) { const char *fun; int parc; @@ -78,7 +79,12 @@ public: fun = argv[0]; if (fun[0] == '-') fun++; - parc = 0; while ((parc+1 < argc) && (argv[parc+1][0] != '-')) parc++; + parc = 0; + QString(argv[parc+1]).toInt(&isInt); + while ((parc+1 < argc) && ((argv[parc+1][0] != '-') || isInt)) { + parc++; + QString(argv[parc+1]).toInt(&isInt); + } parv = (const char **)&argv[1]; argc = argc - (parc+1); argv = &argv[parc+1]; diff --git a/openbr/core/bee.cpp b/openbr/core/bee.cpp index 2f1e1be..1611ef8 100644 --- a/openbr/core/bee.cpp +++ b/openbr/core/bee.cpp @@ -230,7 +230,7 @@ void makeMask(const QString &targetInput, const QString &queryInput, const QStri const FileList targets = TemplateList::fromGallery(targetInput).files(); const FileList queries = (queryInput == ".") ? targets : TemplateList::fromGallery(queryInput).files(); const int partitions = targets.first().get("crossValidate"); - if (partitions == 0) { + if (partitions <= 0) { writeMatrix(makeMask(targets, queries), mask, targetInput, queryInput); } else { if (!mask.contains("%1")) qFatal("Mask file name missing partition number place marker (%%1)"); @@ -246,7 +246,7 @@ void makePairwiseMask(const QString &targetInput, const QString &queryInput, con const FileList targets = TemplateList::fromGallery(targetInput).files(); const FileList queries = (queryInput == ".") ? targets : TemplateList::fromGallery(queryInput).files(); const int partitions = targets.first().get("crossValidate"); - if (partitions == 0) { + if (partitions <= 0) { writeMatrix(makePairwiseMask(targets, queries), mask, targetInput, queryInput); } else { if (!mask.contains("%1")) qFatal("Mask file name missing partition number place marker (%%1)"); diff --git a/openbr/core/core.cpp b/openbr/core/core.cpp index d1ed798..bd84b93 100644 --- a/openbr/core/core.cpp +++ b/openbr/core/core.cpp @@ -69,7 +69,7 @@ struct AlgorithmCore QScopedPointer trainingWrapper(br::wrapTransform(transform.data(), "Stream(readMode=DistributeFrames)")); TemplateList data(TemplateList::fromGallery(input)); - if (Globals->crossValidate > 1) + if (abs(Globals->crossValidate) > 1) for (int i=data.size()-1; i>=0; i--) if (data[i].file.get("allPartitions",false) || data[i].file.get("duplicatePartitions",false)) data.removeAt(i); diff --git a/openbr/openbr_plugin.cpp b/openbr/openbr_plugin.cpp index 6159e49..b9f4fb4 100644 --- a/openbr/openbr_plugin.cpp +++ b/openbr/openbr_plugin.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #ifndef BR_EMBEDDED #include @@ -433,7 +434,7 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) if (gallery.getBool("reduce")) newTemplates = newTemplates.reduced(); - if (Globals->crossValidate > 1) + if (abs(Globals->crossValidate) > 1) newTemplates = newTemplates.partition("Label"); if (!templates.isEmpty() && gallery.get("merge", false)) { @@ -528,12 +529,16 @@ QList TemplateList::applyIndex(const QString &propName, const QHashcrossValidate; + const int crossValidate = std::abs(Globals->crossValidate); if (crossValidate < 2) return *this; + // Use separate RNG from Common::randN() to avoid re-seeding the global RNG. + // rng is seeded with the inputVariable hash in order to maintain partition across br runs, given the same randomSeed seed. + RandomLib::Random rng; + TemplateList partitioned = *this; for (int i=partitioned.size()-1; i>=0; i--) { // See CrossValidateTransform for description of these variables @@ -546,8 +551,12 @@ TemplateList TemplateList::partition(const QString &inputVariable) const } else if (partitioned[i].file.getBool("allPartitions")) { partitioned[i].file.set("Partition", -1); } else { - if (!partitioned[i].file.contains(("Partition"))) { - const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get(inputVariable).toLatin1(), QCryptographicHash::Md5); + const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get(inputVariable).toLatin1(), QCryptographicHash::Md5); + if (randomSeed) { + quint64 labelSeed = md5.toHex().right(8).toULongLong(0, 16) + randomSeed; + rng.Reseed(labelSeed); + partitioned[i].file.set("Partition", rng.Integer() % crossValidate); + } else if (!partitioned[i].file.contains(("Partition"))) { // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); } diff --git a/openbr/openbr_plugin.h b/openbr/openbr_plugin.h index d8bc9b2..77dfbb1 100644 --- a/openbr/openbr_plugin.h +++ b/openbr/openbr_plugin.h @@ -459,7 +459,7 @@ struct TemplateList : public QList