diff --git a/openbr/core/bee.cpp b/openbr/core/bee.cpp index 9639657..9428b34 100644 --- a/openbr/core/bee.cpp +++ b/openbr/core/bee.cpp @@ -287,8 +287,8 @@ cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries, Mask_t val; if (fileA == fileB) val = DontCare; - else if (labelA == "-1") val = DontCare; - else if (labelB == "-1") val = DontCare; + else if (labelA == "-1") val = DontCare; + else if (labelB == "-1") val = DontCare; else if (partitionA != partition) val = DontCare; else if (partitionB == -1) val = NonMatch; else if (partitionB != partition) val = DontCare; diff --git a/openbr/openbr_plugin.cpp b/openbr/openbr_plugin.cpp index 8caf86f..565a2cd 100644 --- a/openbr/openbr_plugin.cpp +++ b/openbr/openbr_plugin.cpp @@ -386,36 +386,66 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) newTemplates = newTemplates.reduced(); const int crossValidate = gallery.get("crossValidate"); - if (crossValidate > 0) srand(0); - - for (int i=newTemplates.size()-1; i>=0; i--) { - newTemplates[i].file.set("Index", i+templates.size()); - newTemplates[i].file.set("Gallery", gallery.name); - - if (crossValidate > 0) { - if (newTemplates[i].file.getBool("duplicatePartitions")) { - // The duplicatePartitions flag is used to add target images - // crossValidate times to the simmat/mask - // when multiple training sets are being used - - // Set template to the first parition - newTemplates[i].file.set("Partition", QVariant(0)); - - // Insert templates for all the other partitions - for (int j=crossValidate-1; j>=1; j--) { - Template allPartitionTemplate = newTemplates[i]; - allPartitionTemplate.file.set("Partition", j); - newTemplates.insert(i+1, allPartitionTemplate); + + if (gallery.getBool("leaveOneOut")) { + QStringList labels; + for (int i=newTemplates.size()-1; i>=0; i--) { + newTemplates[i].file.set("Index", i+templates.size()); + newTemplates[i].file.set("Gallery", gallery.name); + + QString label = newTemplates.at(i).file.get("Label"); + // Have we seen this subject before? + if (!labels.contains(label)) { + labels.append(label); + // Get indices belonging to this subject + QList labelIndices = newTemplates.find("Label",label); + for (int j = 0; j < labelIndices.size(); j++) { + // Set subject partitions + newTemplates[labelIndices[j]].file.set("Partition",j%crossValidate); + } + // Extend the gallery for each partition + for (int j=0; j("Partition")) { + leaveOneOutTemplate.file.set("Partition", k); + leaveOneOutTemplate.file.set("testOnly", true); + newTemplates.insert(i+1,leaveOneOutTemplate); + } + } } - } else if (newTemplates[i].file.getBool("allPartitions")) { - // The allPartitions flag is used to add an extended set - // of target images to every partition - newTemplates[i].file.set("Partition", -1); - } else { + } + } + } else { + for (int i=newTemplates.size()-1; i>=0; i--) { + newTemplates[i].file.set("Index", i+templates.size()); + newTemplates[i].file.set("Gallery", gallery.name); + + if (crossValidate > 0) { + if (newTemplates[i].file.getBool("duplicatePartitions")) { + // The duplicatePartitions flag is used to add target images + // crossValidate times to the simmat/mask + // when multiple training sets are being used + + // Set template to the first parition + newTemplates[i].file.set("Partition", QVariant(0)); + + // Insert templates for all the other partitions + for (int j=crossValidate-1; j>0; j--) { + Template duplicatePartitionsTemplate = newTemplates[i]; + duplicatePartitionsTemplate.file.set("Partition", j); + newTemplates.insert(i+1, duplicatePartitionsTemplate); + } + } else if (newTemplates[i].file.getBool("allPartitions")) { + // The allPartitions flag is used to add an extended set + // of target images to every partition + newTemplates[i].file.set("Partition", -1); + } else { // Direct use of "Label" is not general -cao const QByteArray md5 = QCryptographicHash::hash(newTemplates[i].file.get("Label").toLatin1(), QCryptographicHash::Md5); // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow newTemplates[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); + } } } } diff --git a/openbr/openbr_plugin.h b/openbr/openbr_plugin.h index 0161799..aa808c6 100644 --- a/openbr/openbr_plugin.h +++ b/openbr/openbr_plugin.h @@ -503,6 +503,20 @@ struct TemplateList : public QList