Commit 50e50fe695c71c8f8fd5a903697ac6982a2a4777

Authored by Scott Klum
1 parent 3b11241d

Allow crossValidation to train on a single dataset but project for multiple partitions

Showing 1 changed file with 9 additions and 5 deletions
openbr/plugins/validate.cpp
@@ -100,8 +100,15 @@ class CrossValidateTransform : public MetaTransform @@ -100,8 +100,15 @@ class CrossValidateTransform : public MetaTransform
100 100
101 void project(const Template &src, Template &dst) const 101 void project(const Template &src, Template &dst) const
102 { 102 {
103 - if (src.file.getBool("Train", true)) dst = src;  
104 - else transforms[src.file.get<int>("Partition", 0)]->project(src, dst); 103 + if (src.file.getBool("Train", false)) dst = src;
  104 + else {
  105 + // If we want to duplicate templates but use the same training data
  106 + // for all partitions (i.e. transforms.size() == 1), we need to
  107 + // restrict the partition
  108 + int partition = src.file.get<int>("Partition", 0);
  109 + partition = (partition >= transforms.size()) ? 0 : partition;
  110 + transforms[partition]->project(src, dst);
  111 + }
105 } 112 }
106 113
107 void store(QDataStream &stream) const 114 void store(QDataStream &stream) const
@@ -140,9 +147,6 @@ class CrossValidateDistance : public Distance @@ -140,9 +147,6 @@ class CrossValidateDistance : public Distance
140 const int partitionB = b.file.get<int>(key, 0); 147 const int partitionB = b.file.get<int>(key, 0);
141 return (partitionA != partitionB) ? -std::numeric_limits<float>::max() : 0; 148 return (partitionA != partitionB) ? -std::numeric_limits<float>::max() : 0;
142 } 149 }
143 -  
144 -public:  
145 - CrossValidateDistance() : Distance(false) {}  
146 }; 150 };
147 151
148 BR_REGISTER(Distance, CrossValidateDistance) 152 BR_REGISTER(Distance, CrossValidateDistance)