Commit 92f650219aa28d9eeae48adc4502280805927196
1 parent
51b329d6
Finished leaveOneOut crossValidation
Showing
4 changed files
with
59 additions
and
26 deletions
openbr/core/bee.cpp
| @@ -287,8 +287,8 @@ cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries, | @@ -287,8 +287,8 @@ cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries, | ||
| 287 | 287 | ||
| 288 | Mask_t val; | 288 | Mask_t val; |
| 289 | if (fileA == fileB) val = DontCare; | 289 | if (fileA == fileB) val = DontCare; |
| 290 | - else if (labelA == "-1") val = DontCare; | ||
| 291 | - else if (labelB == "-1") val = DontCare; | 290 | + else if (labelA == "-1") val = DontCare; |
| 291 | + else if (labelB == "-1") val = DontCare; | ||
| 292 | else if (partitionA != partition) val = DontCare; | 292 | else if (partitionA != partition) val = DontCare; |
| 293 | else if (partitionB == -1) val = NonMatch; | 293 | else if (partitionB == -1) val = NonMatch; |
| 294 | else if (partitionB != partition) val = DontCare; | 294 | else if (partitionB != partition) val = DontCare; |
openbr/openbr_plugin.cpp
| @@ -388,23 +388,31 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) | @@ -388,23 +388,31 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) | ||
| 388 | const int crossValidate = gallery.get<int>("crossValidate"); | 388 | const int crossValidate = gallery.get<int>("crossValidate"); |
| 389 | 389 | ||
| 390 | if (gallery.getBool("leaveOneOut")) { | 390 | if (gallery.getBool("leaveOneOut")) { |
| 391 | - QStringList subjects; | ||
| 392 | - for (int i = 0; i < newTemplates.size(); i++) { | ||
| 393 | - QString subject = newTemplates.at(i).file.get<QString>("Label"); | 391 | + QStringList labels; |
| 392 | + for (int i=newTemplates.size()-1; i>=0; i--) { | ||
| 393 | + newTemplates[i].file.set("Index", i+templates.size()); | ||
| 394 | + newTemplates[i].file.set("Gallery", gallery.name); | ||
| 395 | + | ||
| 396 | + QString label = newTemplates.at(i).file.get<QString>("Label"); | ||
| 394 | // Have we seen this subject before? | 397 | // Have we seen this subject before? |
| 395 | - if (subjects.contains(subject)) { | ||
| 396 | - subjects.append(subject); | 398 | + if (!labels.contains(label)) { |
| 399 | + labels.append(label); | ||
| 397 | // Get indices belonging to this subject | 400 | // Get indices belonging to this subject |
| 398 | - QList<int> subjectIndices = newTemplates.find("Label",subject); | ||
| 399 | - for (int j = 0; j < subjectIndices.size(); j++) { | 401 | + QList<int> labelIndices = newTemplates.find("Label",label); |
| 402 | + for (int j = 0; j < labelIndices.size(); j++) { | ||
| 400 | // Set subject partitions | 403 | // Set subject partitions |
| 401 | - newTemplates[subjectIndices[j]].file.set("Partition",j); | 404 | + newTemplates[labelIndices[j]].file.set("Partition",j%crossValidate); |
| 402 | } | 405 | } |
| 403 | - // Generate more templates if necessary | ||
| 404 | - for (int j=0; j<crossValidate-subjectIndices.size(); j++) { | ||
| 405 | - Template leaveOneOutTemplate = newTemplates[subjectIndices[j%subjectIndices.size()]]; | ||
| 406 | - leaveOneOutTemplate.file.set("Partition", j+subjectIndices.size()); | ||
| 407 | - newTemplates.append(leaveOneOutTemplate); | 406 | + // Extend the gallery for each partition |
| 407 | + for (int j=0; j<labelIndices.size(); j++) { | ||
| 408 | + for (int k=0; k<crossValidate; k++) { | ||
| 409 | + Template leaveOneOutTemplate = newTemplates[labelIndices[j]]; | ||
| 410 | + if (k!=leaveOneOutTemplate.file.get<int>("Partition")) { | ||
| 411 | + leaveOneOutTemplate.file.set("Partition", k); | ||
| 412 | + leaveOneOutTemplate.file.set("testOnly", true); | ||
| 413 | + newTemplates.insert(i+1,leaveOneOutTemplate); | ||
| 414 | + } | ||
| 415 | + } | ||
| 408 | } | 416 | } |
| 409 | } | 417 | } |
| 410 | } | 418 | } |
openbr/plugins/stasm4.cpp
| @@ -36,8 +36,8 @@ class StasmInitializer : public Initializer | @@ -36,8 +36,8 @@ class StasmInitializer : public Initializer | ||
| 36 | 36 | ||
| 37 | void initialize() const | 37 | void initialize() const |
| 38 | { | 38 | { |
| 39 | - Globals->abbreviations.insert("RectFromStasmEyes","RectFromPoints([29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],0.125,6.0)+Resize(44,164)"); | ||
| 40 | - Globals->abbreviations.insert("RectFromStasmBrow","RectFromPoints([17, 18, 19, 20, 21, 22, 23, 24],0.15,6)+Resize(28,132)"); | 39 | + Globals->abbreviations.insert("RectFromStasmEyes","RectFromPoints([29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],0.125,6.0)"); |
| 40 | + Globals->abbreviations.insert("RectFromStasmBrow","RectFromPoints([16,17,18,19,20,21,22,23,24,25,26,27],0.15,5)"); | ||
| 41 | Globals->abbreviations.insert("RectFromStasmNose","RectFromPoints([48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58],0.15,1.25)"); | 41 | Globals->abbreviations.insert("RectFromStasmNose","RectFromPoints([48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58],0.15,1.25)"); |
| 42 | Globals->abbreviations.insert("RectFromStasmMouth","RectFromPoints([59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76],0.3,2.5)"); | 42 | Globals->abbreviations.insert("RectFromStasmMouth","RectFromPoints([59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76],0.3,2.5)"); |
| 43 | } | 43 | } |
openbr/plugins/validate.cpp
| @@ -11,6 +11,7 @@ namespace br | @@ -11,6 +11,7 @@ namespace br | ||
| 11 | * \ingroup transforms | 11 | * \ingroup transforms |
| 12 | * \brief Cross validate a trainable transform. | 12 | * \brief Cross validate a trainable transform. |
| 13 | * \author Josh Klontz \cite jklontz | 13 | * \author Josh Klontz \cite jklontz |
| 14 | + * \author Scott Klum \cite sklum | ||
| 14 | * \note To use an extended gallery, add an allPartitions="true" flag to the gallery sigset for those images that should be compared | 15 | * \note To use an extended gallery, add an allPartitions="true" flag to the gallery sigset for those images that should be compared |
| 15 | * against for all testing partitions. | 16 | * against for all testing partitions. |
| 16 | */ | 17 | */ |
| @@ -43,22 +44,46 @@ class CrossValidateTransform : public MetaTransform | @@ -43,22 +44,46 @@ class CrossValidateTransform : public MetaTransform | ||
| 43 | 44 | ||
| 44 | QFutureSynchronizer<void> futures; | 45 | QFutureSynchronizer<void> futures; |
| 45 | for (int i=0; i<numPartitions; i++) { | 46 | for (int i=0; i<numPartitions; i++) { |
| 47 | + QList<int> partitionsBuffer = partitions; | ||
| 46 | TemplateList partitionedData = data; | 48 | TemplateList partitionedData = data; |
| 47 | - QList<int> removed; | ||
| 48 | - for (int j=partitionedData.size()-1; j>=0; j--) | 49 | + int j = partitionedData.size()-1; |
| 50 | + while (j>=0) { | ||
| 49 | // Remove all templates belonging to partition i | 51 | // Remove all templates belonging to partition i |
| 50 | // if leaveOneOut is true, | 52 | // if leaveOneOut is true, |
| 51 | // and i is greater than the number of images for a particular subject | 53 | // and i is greater than the number of images for a particular subject |
| 52 | // even if the partitions are different | 54 | // even if the partitions are different |
| 53 | if (leaveOneOut) { | 55 | if (leaveOneOut) { |
| 54 | - QList<int> subjectIndices = partitionedData.find("Subject",partitionedData.at(j).file.get<QString>("Subject")); | ||
| 55 | - if (i > subjectIndices.size()) removed.append(subjectIndices[i%subjectIndices.size()]); | ||
| 56 | - } else if (partitions[j] == i) | ||
| 57 | - removed.append(j); | ||
| 58 | - typedef QPair<int,int> Pair; | ||
| 59 | - foreach (const Pair &pair, Common::Sort(removed,true)) partitionedData.removeAt(pair.first); | 56 | + const QString label = partitionedData.at(j).file.get<QString>("Label"); |
| 57 | + QList<int> subjectIndices = partitionedData.find("Label",label); | ||
| 58 | + QList<int> removed; | ||
| 59 | + // Remove test only data | ||
| 60 | + for (int k=subjectIndices.size()-1; k>=0; k--) | ||
| 61 | + if (partitionedData[subjectIndices[k]].file.getBool("testOnly")) { | ||
| 62 | + removed.append(subjectIndices[k]); | ||
| 63 | + subjectIndices.removeAt(k); | ||
| 64 | + } | ||
| 65 | + // Remove template that was repeated to make the testOnly template | ||
| 66 | + if (subjectIndices.size() > 1 && subjectIndices.size() <= i) { | ||
| 67 | + removed.append(subjectIndices[i%subjectIndices.size()]); | ||
| 68 | + } | ||
| 69 | + else if (partitionsBuffer[j] == i) { | ||
| 70 | + removed.append(j); | ||
| 71 | + } | ||
| 72 | + | ||
| 73 | + if (!removed.empty()) { | ||
| 74 | + typedef QPair<int,int> Pair; | ||
| 75 | + foreach (Pair pair, Common::Sort(removed,true)) { | ||
| 76 | + partitionedData.removeAt(pair.first); partitionsBuffer.removeAt(pair.first); j--; | ||
| 77 | + } | ||
| 78 | + } else { | ||
| 79 | + j--; | ||
| 80 | + } | ||
| 81 | + } else if (partitions[j] == i) { | ||
| 82 | + partitionedData.removeAt(j); | ||
| 83 | + } else j--; | ||
| 84 | + } | ||
| 60 | // Train on the remaining templates | 85 | // Train on the remaining templates |
| 61 | - foreach (const Template &t, partitionedData) qDebug() << "Remaining data for partition " << i << ": " << t.file.baseName(); | 86 | + foreach (const Template &t, partitionedData) qDebug() << "Remaining data for partition " << i << t.file.baseName() << t.file.get<QString>("Label") << t.file.get<QString>("Partition"); |
| 62 | futures.addFuture(QtConcurrent::run(transforms[i], &Transform::train, partitionedData)); | 87 | futures.addFuture(QtConcurrent::run(transforms[i], &Transform::train, partitionedData)); |
| 63 | } | 88 | } |
| 64 | futures.waitForFinished(); | 89 | futures.waitForFinished(); |