Commit 8f49129679ea023faa9b97e0de31ff27f01515e2
1 parent
fca1772a
Using the Train flag to skip project in crossValidation
This should eliminate crossValidation's need for trainable distances entirely.
Showing
1 changed file
with
6 additions
and
9 deletions
openbr/plugins/validate.cpp
| ... | ... | @@ -74,9 +74,7 @@ class CrossValidateTransform : public MetaTransform |
| 74 | 74 | // Remove template that was repeated to make the testOnly template |
| 75 | 75 | if (subjectIndices.size() > 1 && subjectIndices.size() <= i) { |
| 76 | 76 | removed.append(subjectIndices[i%subjectIndices.size()]); |
| 77 | - } | |
| 78 | - // For the time being, we don't support addition training data added to every fold in the case of leaveOneImageOut | |
| 79 | - else if (partitionsBuffer[j] == i) { | |
| 77 | + } else if (partitionsBuffer[j] == i) { | |
| 80 | 78 | removed.append(j); |
| 81 | 79 | } |
| 82 | 80 | |
| ... | ... | @@ -88,10 +86,6 @@ class CrossValidateTransform : public MetaTransform |
| 88 | 86 | } else { |
| 89 | 87 | j--; |
| 90 | 88 | } |
| 91 | - } else if (partitions[j] == -1) { | |
| 92 | - // Keep data for training, but modify the partition so we project into the correct space | |
| 93 | - partitionedData[j].file.set("Partition",i); | |
| 94 | - j--; | |
| 95 | 89 | } else if (partitions[j] == i) { |
| 96 | 90 | // Remove data, it's designated for testing |
| 97 | 91 | partitionedData.removeAt(j); |
| ... | ... | @@ -106,8 +100,8 @@ class CrossValidateTransform : public MetaTransform |
| 106 | 100 | |
| 107 | 101 | void project(const Template &src, Template &dst) const |
| 108 | 102 | { |
| 109 | - qDebug() << src.file.get<int>("Partition", 0); | |
| 110 | - transforms[src.file.get<int>("Partition", 0)]->project(src, dst); | |
| 103 | + if (src.file.getBool("Train", true)) dst = src; | |
| 104 | + else transforms[src.file.get<int>("Partition", 0)]->project(src, dst); | |
| 111 | 105 | } |
| 112 | 106 | |
| 113 | 107 | void store(QDataStream &stream) const |
| ... | ... | @@ -146,6 +140,9 @@ class CrossValidateDistance : public Distance |
| 146 | 140 | const int partitionB = b.file.get<int>(key, 0); |
| 147 | 141 | return (partitionA != partitionB) ? -std::numeric_limits<float>::max() : 0; |
| 148 | 142 | } |
| 143 | + | |
| 144 | +public: | |
| 145 | + CrossValidateDistance() : Distance(false) {} | |
| 149 | 146 | }; |
| 150 | 147 | |
| 151 | 148 | BR_REGISTER(Distance, CrossValidateDistance) | ... | ... |