Commit 50e50fe695c71c8f8fd5a903697ac6982a2a4777
1 parent
3b11241d
Allow crossValidation to train on a single dataset but project for multiple partitions
Showing
1 changed file
with
9 additions
and
5 deletions
openbr/plugins/validate.cpp
| ... | ... | @@ -100,8 +100,15 @@ class CrossValidateTransform : public MetaTransform |
| 100 | 100 | |
| 101 | 101 | void project(const Template &src, Template &dst) const |
| 102 | 102 | { |
| 103 | - if (src.file.getBool("Train", true)) dst = src; | |
| 104 | - else transforms[src.file.get<int>("Partition", 0)]->project(src, dst); | |
| 103 | + if (src.file.getBool("Train", false)) dst = src; | |
| 104 | + else { | |
| 105 | + // If we want to duplicate templates but use the same training data | |
| 106 | + // for all partitions (i.e. transforms.size() == 1), we need to | |
| 107 | + // restrict the partition | |
| 108 | + int partition = src.file.get<int>("Partition", 0); | |
| 109 | + partition = (partition >= transforms.size()) ? 0 : partition; | |
| 110 | + transforms[partition]->project(src, dst); | |
| 111 | + } | |
| 105 | 112 | } |
| 106 | 113 | |
| 107 | 114 | void store(QDataStream &stream) const |
| ... | ... | @@ -140,9 +147,6 @@ class CrossValidateDistance : public Distance |
| 140 | 147 | const int partitionB = b.file.get<int>(key, 0); |
| 141 | 148 | return (partitionA != partitionB) ? -std::numeric_limits<float>::max() : 0; |
| 142 | 149 | } |
| 143 | - | |
| 144 | -public: | |
| 145 | - CrossValidateDistance() : Distance(false) {} | |
| 146 | 150 | }; |
| 147 | 151 | |
| 148 | 152 | BR_REGISTER(Distance, CrossValidateDistance) | ... | ... |