Commit 8f49129679ea023faa9b97e0de31ff27f01515e2

Authored by Scott Klum
1 parent fca1772a

Using the Train flag to skip project in crossValidation

This should eliminate crossValidation's need for trainable distances entirely.
Showing 1 changed file with 6 additions and 9 deletions
openbr/plugins/validate.cpp
... ... @@ -74,9 +74,7 @@ class CrossValidateTransform : public MetaTransform
74 74 // Remove template that was repeated to make the testOnly template
75 75 if (subjectIndices.size() > 1 && subjectIndices.size() <= i) {
76 76 removed.append(subjectIndices[i%subjectIndices.size()]);
77   - }
78   - // For the time being, we don't support addition training data added to every fold in the case of leaveOneImageOut
79   - else if (partitionsBuffer[j] == i) {
  77 + } else if (partitionsBuffer[j] == i) {
80 78 removed.append(j);
81 79 }
82 80  
... ... @@ -88,10 +86,6 @@ class CrossValidateTransform : public MetaTransform
88 86 } else {
89 87 j--;
90 88 }
91   - } else if (partitions[j] == -1) {
92   - // Keep data for training, but modify the partition so we project into the correct space
93   - partitionedData[j].file.set("Partition",i);
94   - j--;
95 89 } else if (partitions[j] == i) {
96 90 // Remove data, it's designated for testing
97 91 partitionedData.removeAt(j);
... ... @@ -106,8 +100,8 @@ class CrossValidateTransform : public MetaTransform
106 100  
107 101 void project(const Template &src, Template &dst) const
108 102 {
109   - qDebug() << src.file.get<int>("Partition", 0);
110   - transforms[src.file.get<int>("Partition", 0)]->project(src, dst);
  103 + if (src.file.getBool("Train", true)) dst = src;
  104 + else transforms[src.file.get<int>("Partition", 0)]->project(src, dst);
111 105 }
112 106  
113 107 void store(QDataStream &stream) const
... ... @@ -146,6 +140,9 @@ class CrossValidateDistance : public Distance
146 140 const int partitionB = b.file.get<int>(key, 0);
147 141 return (partitionA != partitionB) ? -std::numeric_limits<float>::max() : 0;
148 142 }
  143 +
  144 +public:
  145 + CrossValidateDistance() : Distance(false) {}
149 146 };
150 147  
151 148 BR_REGISTER(Distance, CrossValidateDistance)
... ...