Commit 50e50fe695c71c8f8fd5a903697ac6982a2a4777
1 parent
3b11241d
Allow crossValidation to train on a single dataset but project for multiple partitions
Showing
1 changed file
with
9 additions
and
5 deletions
openbr/plugins/validate.cpp
| @@ -100,8 +100,15 @@ class CrossValidateTransform : public MetaTransform | @@ -100,8 +100,15 @@ class CrossValidateTransform : public MetaTransform | ||
| 100 | 100 | ||
| 101 | void project(const Template &src, Template &dst) const | 101 | void project(const Template &src, Template &dst) const |
| 102 | { | 102 | { |
| 103 | - if (src.file.getBool("Train", true)) dst = src; | ||
| 104 | - else transforms[src.file.get<int>("Partition", 0)]->project(src, dst); | 103 | + if (src.file.getBool("Train", false)) dst = src; |
| 104 | + else { | ||
| 105 | + // If we want to duplicate templates but use the same training data | ||
| 106 | + // for all partitions (i.e. transforms.size() == 1), we need to | ||
| 107 | + // restrict the partition | ||
| 108 | + int partition = src.file.get<int>("Partition", 0); | ||
| 109 | + partition = (partition >= transforms.size()) ? 0 : partition; | ||
| 110 | + transforms[partition]->project(src, dst); | ||
| 111 | + } | ||
| 105 | } | 112 | } |
| 106 | 113 | ||
| 107 | void store(QDataStream &stream) const | 114 | void store(QDataStream &stream) const |
| @@ -140,9 +147,6 @@ class CrossValidateDistance : public Distance | @@ -140,9 +147,6 @@ class CrossValidateDistance : public Distance | ||
| 140 | const int partitionB = b.file.get<int>(key, 0); | 147 | const int partitionB = b.file.get<int>(key, 0); |
| 141 | return (partitionA != partitionB) ? -std::numeric_limits<float>::max() : 0; | 148 | return (partitionA != partitionB) ? -std::numeric_limits<float>::max() : 0; |
| 142 | } | 149 | } |
| 143 | - | ||
| 144 | -public: | ||
| 145 | - CrossValidateDistance() : Distance(false) {} | ||
| 146 | }; | 150 | }; |
| 147 | 151 | ||
| 148 | BR_REGISTER(Distance, CrossValidateDistance) | 152 | BR_REGISTER(Distance, CrossValidateDistance) |