Commit 8f49129679ea023faa9b97e0de31ff27f01515e2
1 parent
fca1772a
Using the Train flag to skip project in crossValidation
This should eliminate crossValidation's need for trainable distances entirely.
Showing
1 changed file
with
6 additions
and
9 deletions
openbr/plugins/validate.cpp
| @@ -74,9 +74,7 @@ class CrossValidateTransform : public MetaTransform | @@ -74,9 +74,7 @@ class CrossValidateTransform : public MetaTransform | ||
| 74 | // Remove template that was repeated to make the testOnly template | 74 | // Remove template that was repeated to make the testOnly template |
| 75 | if (subjectIndices.size() > 1 && subjectIndices.size() <= i) { | 75 | if (subjectIndices.size() > 1 && subjectIndices.size() <= i) { |
| 76 | removed.append(subjectIndices[i%subjectIndices.size()]); | 76 | removed.append(subjectIndices[i%subjectIndices.size()]); |
| 77 | - } | ||
| 78 | - // For the time being, we don't support addition training data added to every fold in the case of leaveOneImageOut | ||
| 79 | - else if (partitionsBuffer[j] == i) { | 77 | + } else if (partitionsBuffer[j] == i) { |
| 80 | removed.append(j); | 78 | removed.append(j); |
| 81 | } | 79 | } |
| 82 | 80 | ||
| @@ -88,10 +86,6 @@ class CrossValidateTransform : public MetaTransform | @@ -88,10 +86,6 @@ class CrossValidateTransform : public MetaTransform | ||
| 88 | } else { | 86 | } else { |
| 89 | j--; | 87 | j--; |
| 90 | } | 88 | } |
| 91 | - } else if (partitions[j] == -1) { | ||
| 92 | - // Keep data for training, but modify the partition so we project into the correct space | ||
| 93 | - partitionedData[j].file.set("Partition",i); | ||
| 94 | - j--; | ||
| 95 | } else if (partitions[j] == i) { | 89 | } else if (partitions[j] == i) { |
| 96 | // Remove data, it's designated for testing | 90 | // Remove data, it's designated for testing |
| 97 | partitionedData.removeAt(j); | 91 | partitionedData.removeAt(j); |
| @@ -106,8 +100,8 @@ class CrossValidateTransform : public MetaTransform | @@ -106,8 +100,8 @@ class CrossValidateTransform : public MetaTransform | ||
| 106 | 100 | ||
| 107 | void project(const Template &src, Template &dst) const | 101 | void project(const Template &src, Template &dst) const |
| 108 | { | 102 | { |
| 109 | - qDebug() << src.file.get<int>("Partition", 0); | ||
| 110 | - transforms[src.file.get<int>("Partition", 0)]->project(src, dst); | 103 | + if (src.file.getBool("Train", true)) dst = src; |
| 104 | + else transforms[src.file.get<int>("Partition", 0)]->project(src, dst); | ||
| 111 | } | 105 | } |
| 112 | 106 | ||
| 113 | void store(QDataStream &stream) const | 107 | void store(QDataStream &stream) const |
| @@ -146,6 +140,9 @@ class CrossValidateDistance : public Distance | @@ -146,6 +140,9 @@ class CrossValidateDistance : public Distance | ||
| 146 | const int partitionB = b.file.get<int>(key, 0); | 140 | const int partitionB = b.file.get<int>(key, 0); |
| 147 | return (partitionA != partitionB) ? -std::numeric_limits<float>::max() : 0; | 141 | return (partitionA != partitionB) ? -std::numeric_limits<float>::max() : 0; |
| 148 | } | 142 | } |
| 143 | + | ||
| 144 | +public: | ||
| 145 | + CrossValidateDistance() : Distance(false) {} | ||
| 149 | }; | 146 | }; |
| 150 | 147 | ||
| 151 | BR_REGISTER(Distance, CrossValidateDistance) | 148 | BR_REGISTER(Distance, CrossValidateDistance) |