Commit 92f650219aa28d9eeae48adc4502280805927196

Authored by Scott Klum
1 parent 51b329d6

Finished leaveOneOut crossValidation

openbr/core/bee.cpp
... ... @@ -287,8 +287,8 @@ cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries,
287 287  
288 288 Mask_t val;
289 289 if (fileA == fileB) val = DontCare;
290   - else if (labelA == "-1") val = DontCare;
291   - else if (labelB == "-1") val = DontCare;
  290 + else if (labelA == "-1") val = DontCare;
  291 + else if (labelB == "-1") val = DontCare;
292 292 else if (partitionA != partition) val = DontCare;
293 293 else if (partitionB == -1) val = NonMatch;
294 294 else if (partitionB != partition) val = DontCare;
... ...
openbr/openbr_plugin.cpp
... ... @@ -388,23 +388,31 @@ TemplateList TemplateList::fromGallery(const br::File &gallery)
388 388 const int crossValidate = gallery.get<int>("crossValidate");
389 389  
390 390 if (gallery.getBool("leaveOneOut")) {
391   - QStringList subjects;
392   - for (int i = 0; i < newTemplates.size(); i++) {
393   - QString subject = newTemplates.at(i).file.get<QString>("Label");
  391 + QStringList labels;
  392 + for (int i=newTemplates.size()-1; i>=0; i--) {
  393 + newTemplates[i].file.set("Index", i+templates.size());
  394 + newTemplates[i].file.set("Gallery", gallery.name);
  395 +
  396 + QString label = newTemplates.at(i).file.get<QString>("Label");
394 397 // Have we seen this subject before?
395   - if (subjects.contains(subject)) {
396   - subjects.append(subject);
  398 + if (!labels.contains(label)) {
  399 + labels.append(label);
397 400 // Get indices belonging to this subject
398   - QList<int> subjectIndices = newTemplates.find("Label",subject);
399   - for (int j = 0; j < subjectIndices.size(); j++) {
  401 + QList<int> labelIndices = newTemplates.find("Label",label);
  402 + for (int j = 0; j < labelIndices.size(); j++) {
400 403 // Set subject partitions
401   - newTemplates[subjectIndices[j]].file.set("Partition",j);
  404 + newTemplates[labelIndices[j]].file.set("Partition",j%crossValidate);
402 405 }
403   - // Generate more templates if necessary
404   - for (int j=0; j<crossValidate-subjectIndices.size(); j++) {
405   - Template leaveOneOutTemplate = newTemplates[subjectIndices[j%subjectIndices.size()]];
406   - leaveOneOutTemplate.file.set("Partition", j+subjectIndices.size());
407   - newTemplates.append(leaveOneOutTemplate);
  406 + // Extend the gallery for each partition
  407 + for (int j=0; j<labelIndices.size(); j++) {
  408 + for (int k=0; k<crossValidate; k++) {
  409 + Template leaveOneOutTemplate = newTemplates[labelIndices[j]];
  410 + if (k!=leaveOneOutTemplate.file.get<int>("Partition")) {
  411 + leaveOneOutTemplate.file.set("Partition", k);
  412 + leaveOneOutTemplate.file.set("testOnly", true);
  413 + newTemplates.insert(i+1,leaveOneOutTemplate);
  414 + }
  415 + }
408 416 }
409 417 }
410 418 }
... ...
openbr/plugins/stasm4.cpp
... ... @@ -36,8 +36,8 @@ class StasmInitializer : public Initializer
36 36  
37 37 void initialize() const
38 38 {
39   - Globals->abbreviations.insert("RectFromStasmEyes","RectFromPoints([29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],0.125,6.0)+Resize(44,164)");
40   - Globals->abbreviations.insert("RectFromStasmBrow","RectFromPoints([17, 18, 19, 20, 21, 22, 23, 24],0.15,6)+Resize(28,132)");
  39 + Globals->abbreviations.insert("RectFromStasmEyes","RectFromPoints([29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],0.125,6.0)");
  40 + Globals->abbreviations.insert("RectFromStasmBrow","RectFromPoints([16,17,18,19,20,21,22,23,24,25,26,27],0.15,5)");
41 41 Globals->abbreviations.insert("RectFromStasmNose","RectFromPoints([48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58],0.15,1.25)");
42 42 Globals->abbreviations.insert("RectFromStasmMouth","RectFromPoints([59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76],0.3,2.5)");
43 43 }
... ...
openbr/plugins/validate.cpp
... ... @@ -11,6 +11,7 @@ namespace br
11 11 * \ingroup transforms
12 12 * \brief Cross validate a trainable transform.
13 13 * \author Josh Klontz \cite jklontz
  14 + * \author Scott Klum \cite sklum
14 15 * \note To use an extended gallery, add an allPartitions="true" flag to the gallery sigset for those images that should be compared
15 16 * against for all testing partitions.
16 17 */
... ... @@ -43,22 +44,46 @@ class CrossValidateTransform : public MetaTransform
43 44  
44 45 QFutureSynchronizer<void> futures;
45 46 for (int i=0; i<numPartitions; i++) {
  47 + QList<int> partitionsBuffer = partitions;
46 48 TemplateList partitionedData = data;
47   - QList<int> removed;
48   - for (int j=partitionedData.size()-1; j>=0; j--)
  49 + int j = partitionedData.size()-1;
  50 + while (j>=0) {
49 51 // Remove all templates belonging to partition i
50 52 // if leaveOneOut is true,
51 53 // and i is greater than the number of images for a particular subject
52 54 // even if the partitions are different
53 55 if (leaveOneOut) {
54   - QList<int> subjectIndices = partitionedData.find("Subject",partitionedData.at(j).file.get<QString>("Subject"));
55   - if (i > subjectIndices.size()) removed.append(subjectIndices[i%subjectIndices.size()]);
56   - } else if (partitions[j] == i)
57   - removed.append(j);
58   - typedef QPair<int,int> Pair;
59   - foreach (const Pair &pair, Common::Sort(removed,true)) partitionedData.removeAt(pair.first);
  56 + const QString label = partitionedData.at(j).file.get<QString>("Label");
  57 + QList<int> subjectIndices = partitionedData.find("Label",label);
  58 + QList<int> removed;
  59 + // Remove test only data
  60 + for (int k=subjectIndices.size()-1; k>=0; k--)
  61 + if (partitionedData[subjectIndices[k]].file.getBool("testOnly")) {
  62 + removed.append(subjectIndices[k]);
  63 + subjectIndices.removeAt(k);
  64 + }
  65 + // Remove template that was repeated to make the testOnly template
  66 + if (subjectIndices.size() > 1 && subjectIndices.size() <= i) {
  67 + removed.append(subjectIndices[i%subjectIndices.size()]);
  68 + }
  69 + else if (partitionsBuffer[j] == i) {
  70 + removed.append(j);
  71 + }
  72 +
  73 + if (!removed.empty()) {
  74 + typedef QPair<int,int> Pair;
  75 + foreach (Pair pair, Common::Sort(removed,true)) {
  76 + partitionedData.removeAt(pair.first); partitionsBuffer.removeAt(pair.first); j--;
  77 + }
  78 + } else {
  79 + j--;
  80 + }
  81 + } else if (partitions[j] == i) {
  82 + partitionedData.removeAt(j);
  83 + } else j--;
  84 + }
60 85 // Train on the remaining templates
61   - foreach (const Template &t, partitionedData) qDebug() << "Remaining data for partition " << i << ": " << t.file.baseName();
  86 + foreach (const Template &t, partitionedData) qDebug() << "Remaining data for partition " << i << t.file.baseName() << t.file.get<QString>("Label") << t.file.get<QString>("Partition");
62 87 futures.addFuture(QtConcurrent::run(transforms[i], &Transform::train, partitionedData));
63 88 }
64 89 futures.waitForFinished();
... ...