Commit 92f650219aa28d9eeae48adc4502280805927196

Authored by Scott Klum
1 parent 51b329d6

Finished leaveOneOut crossValidation

openbr/core/bee.cpp
@@ -287,8 +287,8 @@ cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries, @@ -287,8 +287,8 @@ cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries,
287 287
288 Mask_t val; 288 Mask_t val;
289 if (fileA == fileB) val = DontCare; 289 if (fileA == fileB) val = DontCare;
290 - else if (labelA == "-1") val = DontCare;  
291 - else if (labelB == "-1") val = DontCare; 290 + else if (labelA == "-1") val = DontCare;
  291 + else if (labelB == "-1") val = DontCare;
292 else if (partitionA != partition) val = DontCare; 292 else if (partitionA != partition) val = DontCare;
293 else if (partitionB == -1) val = NonMatch; 293 else if (partitionB == -1) val = NonMatch;
294 else if (partitionB != partition) val = DontCare; 294 else if (partitionB != partition) val = DontCare;
openbr/openbr_plugin.cpp
@@ -388,23 +388,31 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) @@ -388,23 +388,31 @@ TemplateList TemplateList::fromGallery(const br::File &gallery)
388 const int crossValidate = gallery.get<int>("crossValidate"); 388 const int crossValidate = gallery.get<int>("crossValidate");
389 389
390 if (gallery.getBool("leaveOneOut")) { 390 if (gallery.getBool("leaveOneOut")) {
391 - QStringList subjects;  
392 - for (int i = 0; i < newTemplates.size(); i++) {  
393 - QString subject = newTemplates.at(i).file.get<QString>("Label"); 391 + QStringList labels;
  392 + for (int i=newTemplates.size()-1; i>=0; i--) {
  393 + newTemplates[i].file.set("Index", i+templates.size());
  394 + newTemplates[i].file.set("Gallery", gallery.name);
  395 +
  396 + QString label = newTemplates.at(i).file.get<QString>("Label");
394 // Have we seen this subject before? 397 // Have we seen this subject before?
395 - if (subjects.contains(subject)) {  
396 - subjects.append(subject); 398 + if (!labels.contains(label)) {
  399 + labels.append(label);
397 // Get indices belonging to this subject 400 // Get indices belonging to this subject
398 - QList<int> subjectIndices = newTemplates.find("Label",subject);  
399 - for (int j = 0; j < subjectIndices.size(); j++) { 401 + QList<int> labelIndices = newTemplates.find("Label",label);
  402 + for (int j = 0; j < labelIndices.size(); j++) {
400 // Set subject partitions 403 // Set subject partitions
401 - newTemplates[subjectIndices[j]].file.set("Partition",j); 404 + newTemplates[labelIndices[j]].file.set("Partition",j%crossValidate);
402 } 405 }
403 - // Generate more templates if necessary  
404 - for (int j=0; j<crossValidate-subjectIndices.size(); j++) {  
405 - Template leaveOneOutTemplate = newTemplates[subjectIndices[j%subjectIndices.size()]];  
406 - leaveOneOutTemplate.file.set("Partition", j+subjectIndices.size());  
407 - newTemplates.append(leaveOneOutTemplate); 406 + // Extend the gallery for each partition
  407 + for (int j=0; j<labelIndices.size(); j++) {
  408 + for (int k=0; k<crossValidate; k++) {
  409 + Template leaveOneOutTemplate = newTemplates[labelIndices[j]];
  410 + if (k!=leaveOneOutTemplate.file.get<int>("Partition")) {
  411 + leaveOneOutTemplate.file.set("Partition", k);
  412 + leaveOneOutTemplate.file.set("testOnly", true);
  413 + newTemplates.insert(i+1,leaveOneOutTemplate);
  414 + }
  415 + }
408 } 416 }
409 } 417 }
410 } 418 }
openbr/plugins/stasm4.cpp
@@ -36,8 +36,8 @@ class StasmInitializer : public Initializer @@ -36,8 +36,8 @@ class StasmInitializer : public Initializer
36 36
37 void initialize() const 37 void initialize() const
38 { 38 {
39 - Globals->abbreviations.insert("RectFromStasmEyes","RectFromPoints([29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],0.125,6.0)+Resize(44,164)");  
40 - Globals->abbreviations.insert("RectFromStasmBrow","RectFromPoints([17, 18, 19, 20, 21, 22, 23, 24],0.15,6)+Resize(28,132)"); 39 + Globals->abbreviations.insert("RectFromStasmEyes","RectFromPoints([29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],0.125,6.0)");
  40 + Globals->abbreviations.insert("RectFromStasmBrow","RectFromPoints([16,17,18,19,20,21,22,23,24,25,26,27],0.15,5)");
41 Globals->abbreviations.insert("RectFromStasmNose","RectFromPoints([48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58],0.15,1.25)"); 41 Globals->abbreviations.insert("RectFromStasmNose","RectFromPoints([48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58],0.15,1.25)");
42 Globals->abbreviations.insert("RectFromStasmMouth","RectFromPoints([59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76],0.3,2.5)"); 42 Globals->abbreviations.insert("RectFromStasmMouth","RectFromPoints([59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76],0.3,2.5)");
43 } 43 }
openbr/plugins/validate.cpp
@@ -11,6 +11,7 @@ namespace br @@ -11,6 +11,7 @@ namespace br
11 * \ingroup transforms 11 * \ingroup transforms
12 * \brief Cross validate a trainable transform. 12 * \brief Cross validate a trainable transform.
13 * \author Josh Klontz \cite jklontz 13 * \author Josh Klontz \cite jklontz
  14 + * \author Scott Klum \cite sklum
14 * \note To use an extended gallery, add an allPartitions="true" flag to the gallery sigset for those images that should be compared 15 * \note To use an extended gallery, add an allPartitions="true" flag to the gallery sigset for those images that should be compared
15 * against for all testing partitions. 16 * against for all testing partitions.
16 */ 17 */
@@ -43,22 +44,46 @@ class CrossValidateTransform : public MetaTransform @@ -43,22 +44,46 @@ class CrossValidateTransform : public MetaTransform
43 44
44 QFutureSynchronizer<void> futures; 45 QFutureSynchronizer<void> futures;
45 for (int i=0; i<numPartitions; i++) { 46 for (int i=0; i<numPartitions; i++) {
  47 + QList<int> partitionsBuffer = partitions;
46 TemplateList partitionedData = data; 48 TemplateList partitionedData = data;
47 - QList<int> removed;  
48 - for (int j=partitionedData.size()-1; j>=0; j--) 49 + int j = partitionedData.size()-1;
  50 + while (j>=0) {
49 // Remove all templates belonging to partition i 51 // Remove all templates belonging to partition i
50 // if leaveOneOut is true, 52 // if leaveOneOut is true,
51 // and i is greater than the number of images for a particular subject 53 // and i is greater than the number of images for a particular subject
52 // even if the partitions are different 54 // even if the partitions are different
53 if (leaveOneOut) { 55 if (leaveOneOut) {
54 - QList<int> subjectIndices = partitionedData.find("Subject",partitionedData.at(j).file.get<QString>("Subject"));  
55 - if (i > subjectIndices.size()) removed.append(subjectIndices[i%subjectIndices.size()]);  
56 - } else if (partitions[j] == i)  
57 - removed.append(j);  
58 - typedef QPair<int,int> Pair;  
59 - foreach (const Pair &pair, Common::Sort(removed,true)) partitionedData.removeAt(pair.first); 56 + const QString label = partitionedData.at(j).file.get<QString>("Label");
  57 + QList<int> subjectIndices = partitionedData.find("Label",label);
  58 + QList<int> removed;
  59 + // Remove test only data
  60 + for (int k=subjectIndices.size()-1; k>=0; k--)
  61 + if (partitionedData[subjectIndices[k]].file.getBool("testOnly")) {
  62 + removed.append(subjectIndices[k]);
  63 + subjectIndices.removeAt(k);
  64 + }
  65 + // Remove template that was repeated to make the testOnly template
  66 + if (subjectIndices.size() > 1 && subjectIndices.size() <= i) {
  67 + removed.append(subjectIndices[i%subjectIndices.size()]);
  68 + }
  69 + else if (partitionsBuffer[j] == i) {
  70 + removed.append(j);
  71 + }
  72 +
  73 + if (!removed.empty()) {
  74 + typedef QPair<int,int> Pair;
  75 + foreach (Pair pair, Common::Sort(removed,true)) {
  76 + partitionedData.removeAt(pair.first); partitionsBuffer.removeAt(pair.first); j--;
  77 + }
  78 + } else {
  79 + j--;
  80 + }
  81 + } else if (partitions[j] == i) {
  82 + partitionedData.removeAt(j);
  83 + } else j--;
  84 + }
60 // Train on the remaining templates 85 // Train on the remaining templates
61 - foreach (const Template &t, partitionedData) qDebug() << "Remaining data for partition " << i << ": " << t.file.baseName(); 86 + foreach (const Template &t, partitionedData) qDebug() << "Remaining data for partition " << i << t.file.baseName() << t.file.get<QString>("Label") << t.file.get<QString>("Partition");
62 futures.addFuture(QtConcurrent::run(transforms[i], &Transform::train, partitionedData)); 87 futures.addFuture(QtConcurrent::run(transforms[i], &Transform::train, partitionedData));
63 } 88 }
64 futures.waitForFinished(); 89 futures.waitForFinished();