Commit 6df13fa090bbcbd676eda6a51d6e580c032e1f09

Authored by bhklein
1 parent 9102a476

only train/test on 0th partition if crossValidate is neg. accept neg. numbers from CL

app/br/br.cpp
@@ -69,6 +69,7 @@ public: @@ -69,6 +69,7 @@ public:
69 69
70 bool daemon = false; 70 bool daemon = false;
71 const char *daemon_pipe = NULL; 71 const char *daemon_pipe = NULL;
  72 + bool isInt = false;
72 while (daemon || (argc > 0)) { 73 while (daemon || (argc > 0)) {
73 const char *fun; 74 const char *fun;
74 int parc; 75 int parc;
@@ -78,7 +79,12 @@ public: @@ -78,7 +79,12 @@ public:
78 79
79 fun = argv[0]; 80 fun = argv[0];
80 if (fun[0] == '-') fun++; 81 if (fun[0] == '-') fun++;
81 - parc = 0; while ((parc+1 < argc) && (argv[parc+1][0] != '-')) parc++; 82 + parc = 0;
  83 + QString(argv[parc+1]).toInt(&isInt);
  84 + while ((parc+1 < argc) && ((argv[parc+1][0] != '-') || isInt)) {
  85 + parc++;
  86 + QString(argv[parc+1]).toInt(&isInt);
  87 + }
82 parv = (const char **)&argv[1]; 88 parv = (const char **)&argv[1];
83 argc = argc - (parc+1); 89 argc = argc - (parc+1);
84 argv = &argv[parc+1]; 90 argv = &argv[parc+1];
openbr/openbr_plugin.cpp
@@ -531,47 +531,34 @@ QList&lt;int&gt; TemplateList::applyIndex(const QString &amp;propName, const QHash&lt;QString @@ -531,47 +531,34 @@ QList&lt;int&gt; TemplateList::applyIndex(const QString &amp;propName, const QHash&lt;QString
531 531
532 TemplateList TemplateList::partition(const QString &inputVariable, unsigned int randomSeed) const 532 TemplateList TemplateList::partition(const QString &inputVariable, unsigned int randomSeed) const
533 { 533 {
534 - TemplateList partitioned = *this;  
535 - const int crossValidate = Globals->crossValidate; 534 + const int crossValidate = std::abs(Globals->crossValidate);
  535 + if (crossValidate < 2)
  536 + return *this;
536 537
537 // Use separate RNG from Common::randN() to avoid re-seeding the global RNG. 538 // Use separate RNG from Common::randN() to avoid re-seeding the global RNG.
538 // rng is seeded with the inputVariable hash in order to maintain partition across br runs, given the same randomSeed seed. 539 // rng is seeded with the inputVariable hash in order to maintain partition across br runs, given the same randomSeed seed.
539 RandomLib::Random rng; 540 RandomLib::Random rng;
540 - if (crossValidate == 1) {  
541 - for (int i=0; i < partitioned.size(); i++) { 541 +
  542 + TemplateList partitioned = *this;
  543 + for (int i=partitioned.size()-1; i>=0; i--) {
  544 + // See CrossValidateTransform for description of these variables
  545 + if (partitioned[i].file.getBool("duplicatePartitions")) {
  546 + for (int j=crossValidate-1; j>=0; j--) {
  547 + Template duplicateTemplate = partitioned[i];
  548 + duplicateTemplate.file.set("Partition", j);
  549 + partitioned.insert(i+1, duplicateTemplate);
  550 + }
  551 + } else if (partitioned[i].file.getBool("allPartitions")) {
  552 + partitioned[i].file.set("Partition", -1);
  553 + } else {
542 const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5); 554 const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5);
543 - quint64 labelSeed = md5.toHex().right(8).toULongLong(0, 16) + randomSeed;  
544 - rng.Reseed(labelSeed);  
545 -  
546 - // Roughly 2/3rd to training, 1/3rd to testing for the special single split case  
547 - float result = rng.FloatN();  
548 - if (result <= (2.0/3.0))  
549 - // Training  
550 - partitioned[i].file.set("Partition", QString::number(1));  
551 - else  
552 - partitioned[i].file.set("Partition", QString::number(0));  
553 - }  
554 - } else {  
555 - for (int i=partitioned.size()-1; i>=0; i--) {  
556 - // See CrossValidateTransform for description of these variables  
557 - if (partitioned[i].file.getBool("duplicatePartitions")) {  
558 - for (int j=crossValidate-1; j>=0; j--) {  
559 - Template duplicateTemplate = partitioned[i];  
560 - duplicateTemplate.file.set("Partition", j);  
561 - partitioned.insert(i+1, duplicateTemplate);  
562 - }  
563 - } else if (partitioned[i].file.getBool("allPartitions")) {  
564 - partitioned[i].file.set("Partition", -1);  
565 - } else {  
566 - const QByteArray md5 = QCryptographicHash::hash(partitioned[i].file.get<QString>(inputVariable).toLatin1(), QCryptographicHash::Md5);  
567 - if (randomSeed) {  
568 - quint64 labelSeed = md5.toHex().right(8).toULongLong(0, 16) + randomSeed;  
569 - rng.Reseed(labelSeed);  
570 - partitioned[i].file.set("Partition", rng.Integer() % crossValidate);  
571 - } else if (!partitioned[i].file.contains(("Partition"))) {  
572 - // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow  
573 - partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate);  
574 - } 555 + if (randomSeed) {
  556 + quint64 labelSeed = md5.toHex().right(8).toULongLong(0, 16) + randomSeed;
  557 + rng.Reseed(labelSeed);
  558 + partitioned[i].file.set("Partition", rng.Integer() % crossValidate);
  559 + } else if (!partitioned[i].file.contains(("Partition"))) {
  560 + // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow
  561 + partitioned[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate);
575 } 562 }
576 } 563 }
577 } 564 }
openbr/plugins/core/crossvalidate.cpp
@@ -71,10 +71,16 @@ class CrossValidateTransform : public MetaTransform @@ -71,10 +71,16 @@ class CrossValidateTransform : public MetaTransform
71 { 71 {
72 QList<int> partitions = data.partition(inputVariable, randomSeed).files().crossValidationPartitions(); 72 QList<int> partitions = data.partition(inputVariable, randomSeed).files().crossValidationPartitions();
73 const int crossValidate = Globals->crossValidate; 73 const int crossValidate = Globals->crossValidate;
74 - const int numPartitions = (crossValidate == 1) ? 1 : Common::Max(partitions)+1; 74 + // Only train once based on the 0th partition if crossValidate is negative.
  75 + const int numPartitions = (crossValidate < 0) ? 1 : Common::Max(partitions)+1;
75 while (transforms.size() < numPartitions) 76 while (transforms.size() < numPartitions)
76 transforms.append(make(description)); 77 transforms.append(make(description));
77 78
  79 + if (std::abs(crossValidate) < 2) {
  80 + transforms.first()->train(data);
  81 + return;
  82 + }
  83 +
78 QFutureSynchronizer<void> futures; 84 QFutureSynchronizer<void> futures;
79 for (int i=0; i<numPartitions; i++) { 85 for (int i=0; i<numPartitions; i++) {
80 TemplateList partitionedData = data; 86 TemplateList partitionedData = data;
@@ -101,7 +107,7 @@ class CrossValidateTransform : public MetaTransform @@ -101,7 +107,7 @@ class CrossValidateTransform : public MetaTransform
101 TemplateList partitioned = src.partition(inputVariable, randomSeed); 107 TemplateList partitioned = src.partition(inputVariable, randomSeed);
102 const int crossValidate = Globals->crossValidate; 108 const int crossValidate = Globals->crossValidate;
103 109
104 - if (crossValidate == 1) { 110 + if (crossValidate < 0) {
105 transforms[0]->project(partitioned, dst); 111 transforms[0]->project(partitioned, dst);
106 return; 112 return;
107 } 113 }