Commit 8f49129679ea023faa9b97e0de31ff27f01515e2

Authored by Scott Klum
1 parent fca1772a

Using the Train flag to skip project in crossValidation

This should eliminate crossValidation's need for trainable distances entirely.
Showing 1 changed file with 6 additions and 9 deletions
openbr/plugins/validate.cpp
@@ -74,9 +74,7 @@ class CrossValidateTransform : public MetaTransform @@ -74,9 +74,7 @@ class CrossValidateTransform : public MetaTransform
74 // Remove template that was repeated to make the testOnly template 74 // Remove template that was repeated to make the testOnly template
75 if (subjectIndices.size() > 1 && subjectIndices.size() <= i) { 75 if (subjectIndices.size() > 1 && subjectIndices.size() <= i) {
76 removed.append(subjectIndices[i%subjectIndices.size()]); 76 removed.append(subjectIndices[i%subjectIndices.size()]);
77 - }  
78 - // For the time being, we don't support addition training data added to every fold in the case of leaveOneImageOut  
79 - else if (partitionsBuffer[j] == i) { 77 + } else if (partitionsBuffer[j] == i) {
80 removed.append(j); 78 removed.append(j);
81 } 79 }
82 80
@@ -88,10 +86,6 @@ class CrossValidateTransform : public MetaTransform @@ -88,10 +86,6 @@ class CrossValidateTransform : public MetaTransform
88 } else { 86 } else {
89 j--; 87 j--;
90 } 88 }
91 - } else if (partitions[j] == -1) {  
92 - // Keep data for training, but modify the partition so we project into the correct space  
93 - partitionedData[j].file.set("Partition",i);  
94 - j--;  
95 } else if (partitions[j] == i) { 89 } else if (partitions[j] == i) {
96 // Remove data, it's designated for testing 90 // Remove data, it's designated for testing
97 partitionedData.removeAt(j); 91 partitionedData.removeAt(j);
@@ -106,8 +100,8 @@ class CrossValidateTransform : public MetaTransform @@ -106,8 +100,8 @@ class CrossValidateTransform : public MetaTransform
106 100
107 void project(const Template &src, Template &dst) const 101 void project(const Template &src, Template &dst) const
108 { 102 {
109 - qDebug() << src.file.get<int>("Partition", 0);  
110 - transforms[src.file.get<int>("Partition", 0)]->project(src, dst); 103 + if (src.file.getBool("Train", true)) dst = src;
  104 + else transforms[src.file.get<int>("Partition", 0)]->project(src, dst);
111 } 105 }
112 106
113 void store(QDataStream &stream) const 107 void store(QDataStream &stream) const
@@ -146,6 +140,9 @@ class CrossValidateDistance : public Distance @@ -146,6 +140,9 @@ class CrossValidateDistance : public Distance
146 const int partitionB = b.file.get<int>(key, 0); 140 const int partitionB = b.file.get<int>(key, 0);
147 return (partitionA != partitionB) ? -std::numeric_limits<float>::max() : 0; 141 return (partitionA != partitionB) ? -std::numeric_limits<float>::max() : 0;
148 } 142 }
  143 +
  144 +public:
  145 + CrossValidateDistance() : Distance(false) {}
149 }; 146 };
150 147
151 BR_REGISTER(Distance, CrossValidateDistance) 148 BR_REGISTER(Distance, CrossValidateDistance)