Commit 6df13fa090bbcbd676eda6a51d6e580c032e1f09

Authored by bhklein
1 parent 9102a476

only train/test on 0th partition if crossValidate is neg. accept neg. numbers from CL

app/br/br.cpp
... ... @@ -69,6 +69,7 @@ public:
69 69  
70 70 bool daemon = false;
71 71 const char *daemon_pipe = NULL;
  72 + bool isInt = false;
72 73 while (daemon || (argc > 0)) {
73 74 const char *fun;
74 75 int parc;
... ... @@ -78,7 +79,12 @@ public:
78 79  
79 80 fun = argv[0];
80 81 if (fun[0] == '-') fun++;
81   - parc = 0; while ((parc+1 < argc) && (argv[parc+1][0] != '-')) parc++;
  82 + parc = 0;
  83 + QString(argv[parc+1]).toInt(&isInt);
  84 + while ((parc+1 < argc) && ((argv[parc+1][0] != '-') || isInt)) {
  85 + parc++;
  86 + QString(argv[parc+1]).toInt(&isInt);
  87 + }
82 88 parv = (const char **)&argv[1];
83 89 argc = argc - (parc+1);
84 90 argv = &argv[parc+1];
... ...
openbr/openbr_plugin.cpp
... ... @@ -531,47 +531,34 @@ QList&lt;int&gt; TemplateList::applyIndex(const QString &amp;propName, const QHash&lt;QString
531 531  
532 532 TemplateList TemplateList::partition(const QString &inputVariable, unsigned int randomSeed) const
533 533 {
534   - TemplateList partitioned = *this;
535   - const int crossValidate = Globals->crossValidate;
  534 + const int crossValidate = std::abs(Globals->crossValidate);
  535 + if (crossValidate < 2)
  536 + return *this;
536 537  
537 538 // Use separate RNG from Common::randN() to avoid re-seeding the global RNG.
538 539 // rng is seeded with the inputVariable hash in order to maintain partition across br runs, given the same randomSeed seed.
539 540 RandomLib::Random rng;
540   - if (crossValidate == 1) {
541   - for (int i=0; i < partitioned.size(); i++) {
  541 +
  542 + TemplateList partitioned = *this;
  543 + for (int i=partitioned.size()-1; i>=0; i--) {
  544 + // See CrossValidateTransform for description of these variables
  545 + if (partitioned[i].file.getBool("duplicatePartitions")) {
  546 + for (int j=crossValidate-1; j>=0; j--) {
  547 + Template duplicateTemplate = partitioned[i];
  548 + duplicateTemplate.file.set("Partition", j);
  549 + partitioned.insert(i+1, duplicateTemplate);
  550 + }
  551 + } else if (partitioned[i].file.getBool("allPartitions")) {
  552 + partitioned[i].file.set("Partition", -1);
  553 + } else {
542 554 const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5);
543   - quint64 labelSeed = md5.toHex().right(8).toULongLong(0, 16) + randomSeed;
544   - rng.Reseed(labelSeed);
545   -
546   - // Roughly 2/3rd to training, 1/3rd to testing for the special single split case
547   - float result = rng.FloatN();
548   - if (result <= (2.0/3.0))
549   - // Training
550   - partitioned[i].file.set("Partition", QString::number(1));
551   - else
552   - partitioned[i].file.set("Partition", QString::number(0));
553   - }
554   - } else {
555   - for (int i=partitioned.size()-1; i>=0; i--) {
556   - // See CrossValidateTransform for description of these variables
557   - if (partitioned[i].file.getBool("duplicatePartitions")) {
558   - for (int j=crossValidate-1; j>=0; j--) {
559   - Template duplicateTemplate = partitioned[i];
560   - duplicateTemplate.file.set("Partition", j);
561   - partitioned.insert(i+1, duplicateTemplate);
562   - }
563   - } else if (partitioned[i].file.getBool("allPartitions")) {
564   - partitioned[i].file.set("Partition", -1);
565   - } else {
566   - const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5);
567   - if (randomSeed) {
568   - quint64 labelSeed = md5.toHex().right(8).toULongLong(0, 16) + randomSeed;
569   - rng.Reseed(labelSeed);
570   - partitioned[i].file.set("Partition", rng.Integer() % crossValidate);
571   - } else if (!partitioned[i].file.contains(("Partition"))) {
572   - // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow
573   - partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate);
574   - }
  555 + if (randomSeed) {
  556 + quint64 labelSeed = md5.toHex().right(8).toULongLong(0, 16) + randomSeed;
  557 + rng.Reseed(labelSeed);
  558 + partitioned[i].file.set("Partition", rng.Integer() % crossValidate);
  559 + } else if (!partitioned[i].file.contains(("Partition"))) {
  560 + // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow
  561 + partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate);
575 562 }
576 563 }
577 564 }
... ...
openbr/plugins/core/crossvalidate.cpp
... ... @@ -71,10 +71,16 @@ class CrossValidateTransform : public MetaTransform
71 71 {
72 72 QList<int> partitions = data.partition(inputVariable, randomSeed).files().crossValidationPartitions();
73 73 const int crossValidate = Globals->crossValidate;
74   - const int numPartitions = (crossValidate == 1) ? 1 : Common::Max(partitions)+1;
  74 + // Only train once based on the 0th partition if crossValidate is negative.
  75 + const int numPartitions = (crossValidate < 0) ? 1 : Common::Max(partitions)+1;
75 76 while (transforms.size() < numPartitions)
76 77 transforms.append(make(description));
77 78  
  79 + if (std::abs(crossValidate) < 2) {
  80 + transforms.first()->train(data);
  81 + return;
  82 + }
  83 +
78 84 QFutureSynchronizer<void> futures;
79 85 for (int i=0; i<numPartitions; i++) {
80 86 TemplateList partitionedData = data;
... ... @@ -101,7 +107,7 @@ class CrossValidateTransform : public MetaTransform
101 107 TemplateList partitioned = src.partition(inputVariable, randomSeed);
102 108 const int crossValidate = Globals->crossValidate;
103 109  
104   - if (crossValidate == 1) {
  110 + if (crossValidate < 0) {
105 111 transforms[0]->project(partitioned, dst);
106 112 return;
107 113 }
... ...