Commit c7010d940bf51784f94575037a9ae8a7bcc282e8
1 parent
91eff7dd
Added some comments about leaveOneOut crossValidation
Showing
1 changed file
with
37 additions
and
30 deletions
openbr/openbr_plugin.cpp
| @@ -392,41 +392,47 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) | @@ -392,41 +392,47 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) | ||
| 392 | 392 | ||
| 393 | const int crossValidate = gallery.get<int>("crossValidate"); | 393 | const int crossValidate = gallery.get<int>("crossValidate"); |
| 394 | 394 | ||
| 395 | - for (int i=newTemplates.size()-1; i>=0; i--) { | 395 | + // The leaveOneImageOut flag is used when we want to train on n-1 of a subject's images |
| 396 | + // Thus, we find all the images for a particular subject, and set their partitions based on | ||
| 397 | + // the crossValidate parameter | ||
| 398 | + // Note that when the number of images per subject varies from subject to subject | ||
| 399 | + // the number of subjects will decrease as the partition increases | ||
| 400 | + if (gallery.getBool("leaveOneImageOut") && crossValidate > 0) { | ||
| 401 | + QStringList labels; | ||
| 402 | + for (int i=newTemplates.size()-1; i>=0; i--) { | ||
| 396 | newTemplates[i].file.set("Index", i+templates.size()); | 403 | newTemplates[i].file.set("Index", i+templates.size()); |
| 397 | newTemplates[i].file.set("Gallery", gallery.name); | 404 | newTemplates[i].file.set("Gallery", gallery.name); |
| 398 | 405 | ||
| 399 | - if (crossValidate > 0) { | ||
| 400 | - if (gallery.getBool("leaveOneImageOut")) { | ||
| 401 | - QStringList labels; | ||
| 402 | - for (int i=newTemplates.size()-1; i>=0; i--) { | ||
| 403 | - newTemplates[i].file.set("Index", i+templates.size()); | ||
| 404 | - newTemplates[i].file.set("Gallery", gallery.name); | ||
| 405 | - | ||
| 406 | - QString label = newTemplates.at(i).file.get<QString>("Label"); | ||
| 407 | - // Have we seen this subject before? | ||
| 408 | - if (!labels.contains(label)) { | ||
| 409 | - labels.append(label); | ||
| 410 | - // Get indices belonging to this subject | ||
| 411 | - QList<int> labelIndices = newTemplates.find("Label",label); | ||
| 412 | - for (int j = 0; j < labelIndices.size(); j++) { | ||
| 413 | - // Set subject partitions | ||
| 414 | - newTemplates[labelIndices[j]].file.set("Partition",j%crossValidate); | ||
| 415 | - } | ||
| 416 | - // Extend the gallery for each partition | ||
| 417 | - for (int j=0; j<labelIndices.size(); j++) { | ||
| 418 | - for (int k=0; k<crossValidate; k++) { | ||
| 419 | - Template leaveOneImageOutTemplate = newTemplates[labelIndices[j]]; | ||
| 420 | - if (k!=leaveOneImageOutTemplate.file.get<int>("Partition")) { | ||
| 421 | - leaveOneImageOutTemplate.file.set("Partition", k); | ||
| 422 | - leaveOneImageOutTemplate.file.set("testOnly", true); | ||
| 423 | - newTemplates.insert(i+1,leaveOneImageOutTemplate); | ||
| 424 | - } | ||
| 425 | - } | ||
| 426 | - } | 406 | + QString label = newTemplates.at(i).file.get<QString>("Label"); |
| 407 | + // Have we seen this subject before? | ||
| 408 | + if (!labels.contains(label)) { | ||
| 409 | + labels.append(label); | ||
| 410 | + // Get indices belonging to this subject | ||
| 411 | + QList<int> labelIndices = newTemplates.find("Label",label); | ||
| 412 | + for (int j = 0; j < labelIndices.size(); j++) { | ||
| 413 | + // Set subject partitions | ||
| 414 | + newTemplates[labelIndices[j]].file.set("Partition",j%crossValidate); | ||
| 415 | + } | ||
| 416 | + // Extend the gallery for each partition | ||
| 417 | + for (int j=0; j<labelIndices.size(); j++) { | ||
| 418 | + for (int k=0; k<crossValidate; k++) { | ||
| 419 | + Template leaveOneImageOutTemplate = newTemplates[labelIndices[j]]; | ||
| 420 | + if (k!=leaveOneImageOutTemplate.file.get<int>("Partition")) { | ||
| 421 | + leaveOneImageOutTemplate.file.set("Partition", k); | ||
| 422 | + leaveOneImageOutTemplate.file.set("testOnly", true); | ||
| 423 | + newTemplates.insert(i+1,leaveOneImageOutTemplate); | ||
| 427 | } | 424 | } |
| 428 | } | 425 | } |
| 429 | - } else if (newTemplates[i].file.getBool("duplicatePartitions")) { | 426 | + } |
| 427 | + } | ||
| 428 | + } | ||
| 429 | + } else { | ||
| 430 | + for (int i=newTemplates.size()-1; i>=0; i--) { | ||
| 431 | + newTemplates[i].file.set("Index", i+templates.size()); | ||
| 432 | + newTemplates[i].file.set("Gallery", gallery.name); | ||
| 433 | + | ||
| 434 | + if (crossValidate > 0) { | ||
| 435 | + if (newTemplates[i].file.getBool("duplicatePartitions")) { | ||
| 430 | // The duplicatePartitions flag is used to add target images | 436 | // The duplicatePartitions flag is used to add target images |
| 431 | // crossValidate times to the simmat/mask | 437 | // crossValidate times to the simmat/mask |
| 432 | // when multiple training sets are being used | 438 | // when multiple training sets are being used |
| @@ -453,6 +459,7 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) | @@ -453,6 +459,7 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) | ||
| 453 | } | 459 | } |
| 454 | } | 460 | } |
| 455 | } | 461 | } |
| 462 | + } | ||
| 456 | } | 463 | } |
| 457 | 464 | ||
| 458 | if (!templates.isEmpty() && gallery.get<bool>("merge", false)) { | 465 | if (!templates.isEmpty() && gallery.get<bool>("merge", false)) { |