Commit 6df13fa090bbcbd676eda6a51d6e580c032e1f09
1 parent
9102a476
only train/test on 0th partition if crossValidate is neg. accept neg. numbers from CL
Showing
3 changed files
with
38 additions
and
39 deletions
app/br/br.cpp
| ... | ... | @@ -69,6 +69,7 @@ public: |
| 69 | 69 | |
| 70 | 70 | bool daemon = false; |
| 71 | 71 | const char *daemon_pipe = NULL; |
| 72 | + bool isInt = false; | |
| 72 | 73 | while (daemon || (argc > 0)) { |
| 73 | 74 | const char *fun; |
| 74 | 75 | int parc; |
| ... | ... | @@ -78,7 +79,12 @@ public: |
| 78 | 79 | |
| 79 | 80 | fun = argv[0]; |
| 80 | 81 | if (fun[0] == '-') fun++; |
| 81 | - parc = 0; while ((parc+1 < argc) && (argv[parc+1][0] != '-')) parc++; | |
| 82 | + parc = 0; | |
| 83 | + QString(argv[parc+1]).toInt(&isInt); | |
| 84 | + while ((parc+1 < argc) && ((argv[parc+1][0] != '-') || isInt)) { | |
| 85 | + parc++; | |
| 86 | + QString(argv[parc+1]).toInt(&isInt); | |
| 87 | + } | |
| 82 | 88 | parv = (const char **)&argv[1]; |
| 83 | 89 | argc = argc - (parc+1); |
| 84 | 90 | argv = &argv[parc+1]; | ... | ... |
openbr/openbr_plugin.cpp
| ... | ... | @@ -531,47 +531,34 @@ QList<int> TemplateList::applyIndex(const QString &propName, const QHash<QString |
| 531 | 531 | |
| 532 | 532 | TemplateList TemplateList::partition(const QString &inputVariable, unsigned int randomSeed) const |
| 533 | 533 | { |
| 534 | - TemplateList partitioned = *this; | |
| 535 | - const int crossValidate = Globals->crossValidate; | |
| 534 | + const int crossValidate = std::abs(Globals->crossValidate); | |
| 535 | + if (crossValidate < 2) | |
| 536 | + return *this; | |
| 536 | 537 | |
| 537 | 538 | // Use separate RNG from Common::randN() to avoid re-seeding the global RNG. |
| 538 | 539 | // rng is seeded with the inputVariable hash in order to maintain partition across br runs, given the same randomSeed seed. |
| 539 | 540 | RandomLib::Random rng; |
| 540 | - if (crossValidate == 1) { | |
| 541 | - for (int i=0; i < partitioned.size(); i++) { | |
| 541 | + | |
| 542 | + TemplateList partitioned = *this; | |
| 543 | + for (int i=partitioned.size()-1; i>=0; i--) { | |
| 544 | + // See CrossValidateTransform for description of these variables | |
| 545 | + if (partitioned[i].file.getBool("duplicatePartitions")) { | |
| 546 | + for (int j=crossValidate-1; j>=0; j--) { | |
| 547 | + Template duplicateTemplate = partitioned[i]; | |
| 548 | + duplicateTemplate.file.set("Partition", j); | |
| 549 | + partitioned.insert(i+1, duplicateTemplate); | |
| 550 | + } | |
| 551 | + } else if (partitioned[i].file.getBool("allPartitions")) { | |
| 552 | + partitioned[i].file.set("Partition", -1); | |
| 553 | + } else { | |
| 542 | 554 | const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5); |
| 543 | - quint64 labelSeed = md5.toHex().right(8).toULongLong(0, 16) + randomSeed; | |
| 544 | - rng.Reseed(labelSeed); | |
| 545 | - | |
| 546 | - // Roughly 2/3rd to training, 1/3rd to testing for the special single split case | |
| 547 | - float result = rng.FloatN(); | |
| 548 | - if (result <= (2.0/3.0)) | |
| 549 | - // Training | |
| 550 | - partitioned[i].file.set("Partition", QString::number(1)); | |
| 551 | - else | |
| 552 | - partitioned[i].file.set("Partition", QString::number(0)); | |
| 553 | - } | |
| 554 | - } else { | |
| 555 | - for (int i=partitioned.size()-1; i>=0; i--) { | |
| 556 | - // See CrossValidateTransform for description of these variables | |
| 557 | - if (partitioned[i].file.getBool("duplicatePartitions")) { | |
| 558 | - for (int j=crossValidate-1; j>=0; j--) { | |
| 559 | - Template duplicateTemplate = partitioned[i]; | |
| 560 | - duplicateTemplate.file.set("Partition", j); | |
| 561 | - partitioned.insert(i+1, duplicateTemplate); | |
| 562 | - } | |
| 563 | - } else if (partitioned[i].file.getBool("allPartitions")) { | |
| 564 | - partitioned[i].file.set("Partition", -1); | |
| 565 | - } else { | |
| 566 | - const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5); | |
| 567 | - if (randomSeed) { | |
| 568 | - quint64 labelSeed = md5.toHex().right(8).toULongLong(0, 16) + randomSeed; | |
| 569 | - rng.Reseed(labelSeed); | |
| 570 | - partitioned[i].file.set("Partition", rng.Integer() % crossValidate); | |
| 571 | - } else if (!partitioned[i].file.contains(("Partition"))) { | |
| 572 | - // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow | |
| 573 | - partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); | |
| 574 | - } | |
| 555 | + if (randomSeed) { | |
| 556 | + quint64 labelSeed = md5.toHex().right(8).toULongLong(0, 16) + randomSeed; | |
| 557 | + rng.Reseed(labelSeed); | |
| 558 | + partitioned[i].file.set("Partition", rng.Integer() % crossValidate); | |
| 559 | + } else if (!partitioned[i].file.contains(("Partition"))) { | |
| 560 | + // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow | |
| 561 | + partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); | |
| 575 | 562 | } |
| 576 | 563 | } |
| 577 | 564 | } | ... | ... |
openbr/plugins/core/crossvalidate.cpp
| ... | ... | @@ -71,10 +71,16 @@ class CrossValidateTransform : public MetaTransform |
| 71 | 71 | { |
| 72 | 72 | QList<int> partitions = data.partition(inputVariable, randomSeed).files().crossValidationPartitions(); |
| 73 | 73 | const int crossValidate = Globals->crossValidate; |
| 74 | - const int numPartitions = (crossValidate == 1) ? 1 : Common::Max(partitions)+1; | |
| 74 | + // Only train once based on the 0th partition if crossValidate is negative. | |
| 75 | + const int numPartitions = (crossValidate < 0) ? 1 : Common::Max(partitions)+1; | |
| 75 | 76 | while (transforms.size() < numPartitions) |
| 76 | 77 | transforms.append(make(description)); |
| 77 | 78 | |
| 79 | + if (std::abs(crossValidate) < 2) { | |
| 80 | + transforms.first()->train(data); | |
| 81 | + return; | |
| 82 | + } | |
| 83 | + | |
| 78 | 84 | QFutureSynchronizer<void> futures; |
| 79 | 85 | for (int i=0; i<numPartitions; i++) { |
| 80 | 86 | TemplateList partitionedData = data; |
| ... | ... | @@ -101,7 +107,7 @@ class CrossValidateTransform : public MetaTransform |
| 101 | 107 | TemplateList partitioned = src.partition(inputVariable, randomSeed); |
| 102 | 108 | const int crossValidate = Globals->crossValidate; |
| 103 | 109 | |
| 104 | - if (crossValidate == 1) { | |
| 110 | + if (crossValidate < 0) { | |
| 105 | 111 | transforms[0]->project(partitioned, dst); |
| 106 | 112 | return; |
| 107 | 113 | } | ... | ... |