Commit c7010d940bf51784f94575037a9ae8a7bcc282e8

Authored by Scott Klum
1 parent 91eff7dd

Added some comments about leaveOneOut crossValidation

Showing 1 changed file with 37 additions and 30 deletions
openbr/openbr_plugin.cpp
@@ -392,41 +392,47 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) @@ -392,41 +392,47 @@ TemplateList TemplateList::fromGallery(const br::File &gallery)
392 392
393 const int crossValidate = gallery.get<int>("crossValidate"); 393 const int crossValidate = gallery.get<int>("crossValidate");
394 394
395 - for (int i=newTemplates.size()-1; i>=0; i--) { 395 + // The leaveOneImageOut flag is used when we want to train on n-1 of a subject's images
  396 + // Thus, we find all the images for a particular subject, and set their partitions based on
  397 + // the crossValidate parameter
  398 + // Note that when the number of images per subject varies from subject to subject
  399 + // the number of subjects will decrease as the partition increases
  400 + if (gallery.getBool("leaveOneImageOut") && crossValidate > 0) {
  401 + QStringList labels;
  402 + for (int i=newTemplates.size()-1; i>=0; i--) {
396 newTemplates[i].file.set("Index", i+templates.size()); 403 newTemplates[i].file.set("Index", i+templates.size());
397 newTemplates[i].file.set("Gallery", gallery.name); 404 newTemplates[i].file.set("Gallery", gallery.name);
398 405
399 - if (crossValidate > 0) {  
400 - if (gallery.getBool("leaveOneImageOut")) {  
401 - QStringList labels;  
402 - for (int i=newTemplates.size()-1; i>=0; i--) {  
403 - newTemplates[i].file.set("Index", i+templates.size());  
404 - newTemplates[i].file.set("Gallery", gallery.name);  
405 -  
406 - QString label = newTemplates.at(i).file.get<QString>("Label");  
407 - // Have we seen this subject before?  
408 - if (!labels.contains(label)) {  
409 - labels.append(label);  
410 - // Get indices belonging to this subject  
411 - QList<int> labelIndices = newTemplates.find("Label",label);  
412 - for (int j = 0; j < labelIndices.size(); j++) {  
413 - // Set subject partitions  
414 - newTemplates[labelIndices[j]].file.set("Partition",j%crossValidate);  
415 - }  
416 - // Extend the gallery for each partition  
417 - for (int j=0; j<labelIndices.size(); j++) {  
418 - for (int k=0; k<crossValidate; k++) {  
419 - Template leaveOneImageOutTemplate = newTemplates[labelIndices[j]];  
420 - if (k!=leaveOneImageOutTemplate.file.get<int>("Partition")) {  
421 - leaveOneImageOutTemplate.file.set("Partition", k);  
422 - leaveOneImageOutTemplate.file.set("testOnly", true);  
423 - newTemplates.insert(i+1,leaveOneImageOutTemplate);  
424 - }  
425 - }  
426 - } 406 + QString label = newTemplates.at(i).file.get<QString>("Label");
  407 + // Have we seen this subject before?
  408 + if (!labels.contains(label)) {
  409 + labels.append(label);
  410 + // Get indices belonging to this subject
  411 + QList<int> labelIndices = newTemplates.find("Label",label);
  412 + for (int j = 0; j < labelIndices.size(); j++) {
  413 + // Set subject partitions
  414 + newTemplates[labelIndices[j]].file.set("Partition",j%crossValidate);
  415 + }
  416 + // Extend the gallery for each partition
  417 + for (int j=0; j<labelIndices.size(); j++) {
  418 + for (int k=0; k<crossValidate; k++) {
  419 + Template leaveOneImageOutTemplate = newTemplates[labelIndices[j]];
  420 + if (k!=leaveOneImageOutTemplate.file.get<int>("Partition")) {
  421 + leaveOneImageOutTemplate.file.set("Partition", k);
  422 + leaveOneImageOutTemplate.file.set("testOnly", true);
  423 + newTemplates.insert(i+1,leaveOneImageOutTemplate);
427 } 424 }
428 } 425 }
429 - } else if (newTemplates[i].file.getBool("duplicatePartitions")) { 426 + }
  427 + }
  428 + }
  429 + } else {
  430 + for (int i=newTemplates.size()-1; i>=0; i--) {
  431 + newTemplates[i].file.set("Index", i+templates.size());
  432 + newTemplates[i].file.set("Gallery", gallery.name);
  433 +
  434 + if (crossValidate > 0) {
  435 + if (newTemplates[i].file.getBool("duplicatePartitions")) {
430 // The duplicatePartitions flag is used to add target images 436 // The duplicatePartitions flag is used to add target images
431 // crossValidate times to the simmat/mask 437 // crossValidate times to the simmat/mask
432 // when multiple training sets are being used 438 // when multiple training sets are being used
@@ -453,6 +459,7 @@ TemplateList TemplateList::fromGallery(const br::File &amp;gallery) @@ -453,6 +459,7 @@ TemplateList TemplateList::fromGallery(const br::File &amp;gallery)
453 } 459 }
454 } 460 }
455 } 461 }
  462 + }
456 } 463 }
457 464
458 if (!templates.isEmpty() && gallery.get<bool>("merge", false)) { 465 if (!templates.isEmpty() && gallery.get<bool>("merge", false)) {