diff --git a/app/br/br.cpp b/app/br/br.cpp index 2e2293a..75dc390 100644 --- a/app/br/br.cpp +++ b/app/br/br.cpp @@ -69,6 +69,7 @@ public: bool daemon = false; const char *daemon_pipe = NULL; + bool isInt = false; while (daemon || (argc > 0)) { const char *fun; int parc; @@ -78,7 +79,12 @@ public: fun = argv[0]; if (fun[0] == '-') fun++; - parc = 0; while ((parc+1 < argc) && (argv[parc+1][0] != '-')) parc++; + parc = 0; + QString(argv[parc+1]).toInt(&isInt); + while ((parc+1 < argc) && ((argv[parc+1][0] != '-') || isInt)) { + parc++; + QString(argv[parc+1]).toInt(&isInt); + } parv = (const char **)&argv[1]; argc = argc - (parc+1); argv = &argv[parc+1]; diff --git a/openbr/openbr_plugin.cpp b/openbr/openbr_plugin.cpp index 7e19644..fec49b8 100644 --- a/openbr/openbr_plugin.cpp +++ b/openbr/openbr_plugin.cpp @@ -531,47 +531,34 @@ QList TemplateList::applyIndex(const QString &propName, const QHashcrossValidate; + const int crossValidate = std::abs(Globals->crossValidate); + if (crossValidate < 2) + return *this; // Use separate RNG from Common::randN() to avoid re-seeding the global RNG. // rng is seeded with the inputVariable hash in order to maintain partition across br runs, given the same randomSeed seed. RandomLib::Random rng; - if (crossValidate == 1) { - for (int i=0; i < partitioned.size(); i++) { + + TemplateList partitioned = *this; + for (int i=partitioned.size()-1; i>=0; i--) { + // See CrossValidateTransform for description of these variables + if (partitioned[i].file.getBool("duplicatePartitions")) { + for (int j=crossValidate-1; j>=0; j--) { + Template duplicateTemplate = partitioned[i]; + duplicateTemplate.file.set("Partition", j); + partitioned.insert(i+1, duplicateTemplate); + } + } else if (partitioned[i].file.getBool("allPartitions")) { + partitioned[i].file.set("Partition", -1); + } else { const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get(inputVariable).toLatin1(), QCryptographicHash::Md5); - quint64 labelSeed = md5.toHex().right(8).toULongLong(0, 16) + randomSeed; - rng.Reseed(labelSeed); - - // Roughly 2/3rd to training, 1/3rd to testing for the special single split case - float result = rng.FloatN(); - if (result <= (2.0/3.0)) - // Training - partitioned[i].file.set("Partition", QString::number(1)); - else - partitioned[i].file.set("Partition", QString::number(0)); - } - } else { - for (int i=partitioned.size()-1; i>=0; i--) { - // See CrossValidateTransform for description of these variables - if (partitioned[i].file.getBool("duplicatePartitions")) { - for (int j=crossValidate-1; j>=0; j--) { - Template duplicateTemplate = partitioned[i]; - duplicateTemplate.file.set("Partition", j); - partitioned.insert(i+1, duplicateTemplate); - } - } else if (partitioned[i].file.getBool("allPartitions")) { - partitioned[i].file.set("Partition", -1); - } else { - const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get(inputVariable).toLatin1(), QCryptographicHash::Md5); - if (randomSeed) { - quint64 labelSeed = md5.toHex().right(8).toULongLong(0, 16) + randomSeed; - rng.Reseed(labelSeed); - partitioned[i].file.set("Partition", rng.Integer() % crossValidate); - } else if (!partitioned[i].file.contains(("Partition"))) { - // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow - partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); - } + if (randomSeed) { + quint64 labelSeed = md5.toHex().right(8).toULongLong(0, 16) + randomSeed; + rng.Reseed(labelSeed); + partitioned[i].file.set("Partition", rng.Integer() % crossValidate); + } else if (!partitioned[i].file.contains(("Partition"))) { + // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow + partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); } } } diff --git a/openbr/plugins/core/crossvalidate.cpp b/openbr/plugins/core/crossvalidate.cpp index d533b22..8afa83f 100644 --- a/openbr/plugins/core/crossvalidate.cpp +++ b/openbr/plugins/core/crossvalidate.cpp @@ -71,10 +71,16 @@ class CrossValidateTransform : public MetaTransform { QList partitions = data.partition(inputVariable, randomSeed).files().crossValidationPartitions(); const int crossValidate = Globals->crossValidate; - const int numPartitions = (crossValidate == 1) ? 1 : Common::Max(partitions)+1; + // Only train once based on the 0th partition if crossValidate is negative. + const int numPartitions = (crossValidate < 0) ? 1 : Common::Max(partitions)+1; while (transforms.size() < numPartitions) transforms.append(make(description)); + if (std::abs(crossValidate) < 2) { + transforms.first()->train(data); + return; + } + QFutureSynchronizer futures; for (int i=0; icrossValidate; - if (crossValidate == 1) { + if (crossValidate < 0) { transforms[0]->project(partitioned, dst); return; }