Commit f914aca9cb48e2e2d088bcf9214b49c7a4e88658
1 parent
d97f68bd
Refactored CrossValidateTransform
Showing
1 changed file
with
24 additions
and
52 deletions
openbr/plugins/core/crossvalidate.cpp
| ... | ... | @@ -32,31 +32,41 @@ static void _train(Transform *transform, TemplateList data) // think data has to |
| 32 | 32 | * \brief Cross validate a trainable transform. |
| 33 | 33 | * \author Josh Klontz \cite jklontz |
| 34 | 34 | * \author Scott Klum \cite sklum |
| 35 | - * \note To use an extended gallery, add an allPartitions="true" flag to the gallery sigset for those images that should be compared | |
| 36 | - * against for all testing partitions. | |
| 35 | + * \note Two flags can be put in File metadata that are related to cross-validation and are used to | |
| 36 | + * extend a testing gallery: | |
| 37 | + * (i) allPartitions - This flag is intended to be used when comparing the | |
| 38 | + * performance of an untrainable algorithm (e.g. a COTS | |
| 39 | + * algorithm) against a trainable algorithm that was trained | |
| 40 | + * using cross-validation. All templates with the allPartitions | |
| 41 | + * flag will be compared against for every partition. As | |
| 42 | + * untrainable algorithms will have no use for the | |
| 43 | + * CrossValidateTransform, this flag is only meaningful at comparison | |
| 44 | + * time (but care has been taken so that one can train and enroll | |
| 45 | + * without issue if these Files are present in the used Gallery). | |
| 46 | + * (ii) duplicatePartitions - This flag is similar to allPartitions in that it causes | |
| 47 | + * the same template to be used during comparison for every partition. | |
| 48 | + * The difference is that duplicatePartitions will duplicate each | |
| 49 | + * marked template and project it into the model space constituted | |
| 50 | + * by the child transforms of CrossValidateTransform. Again, care | |
| 51 | + * has been taken such that one can train with these templates in the | |
| 52 | + * used Gallery successfully (they will simply be omitted). | |
| 37 | 53 | */ |
| 38 | 54 | class CrossValidateTransform : public MetaTransform |
| 39 | 55 | { |
| 40 | 56 | Q_OBJECT |
| 41 | 57 | Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false) |
| 42 | - Q_PROPERTY(bool leaveOneImageOut READ get_leaveOneImageOut WRITE set_leaveOneImageOut RESET reset_leaveOneImageOut STORED false) | |
| 43 | 58 | BR_PROPERTY(QString, description, "Identity") |
| 44 | - BR_PROPERTY(bool, leaveOneImageOut, false) | |
| 45 | 59 | |
| 46 | 60 | // numPartitions copies of transform specified by description. |
| 47 | 61 | QList<br::Transform*> transforms; |
| 48 | 62 | |
| 49 | - // Treating this transform as a leaf (in terms of update training scheme), the child transform | |
| 63 | + // Treating this transform as a leaf (in terms of updated training scheme), the child transform | |
| 50 | 64 | // of this transform will lose any structure present in the training QList<TemplateList>, which |
| 51 | 65 | // is generally incorrect behavior. |
| 52 | 66 | void train(const TemplateList &data) |
| 53 | 67 | { |
| 54 | - int numPartitions = 0; | |
| 55 | - QList<int> partitions; partitions.reserve(data.size()); | |
| 56 | - foreach (const File &file, data.files()) { | |
| 57 | - partitions.append(file.get<int>("Partition", 0)); | |
| 58 | - numPartitions = std::max(numPartitions, partitions.last()+1); | |
| 59 | - } | |
| 68 | + QList<int> partitions = data.files().crossValidationPartitions(); | |
| 69 | + const int numPartitions = Common::Max(partitions)+1; | |
| 60 | 70 | |
| 61 | 71 | while (transforms.size() < numPartitions) |
| 62 | 72 | transforms.append(make(description)); |
| ... | ... | @@ -68,45 +78,12 @@ class CrossValidateTransform : public MetaTransform |
| 68 | 78 | |
| 69 | 79 | QFutureSynchronizer<void> futures; |
| 70 | 80 | for (int i=0; i<numPartitions; i++) { |
| 71 | - QList<int> partitionsBuffer = partitions; | |
| 72 | 81 | TemplateList partitionedData = data; |
| 73 | - int j = partitionedData.size()-1; | |
| 74 | - while (j>=0) { | |
| 75 | - // Remove all templates belonging to partition i | |
| 76 | - // if leaveOneImageOut is true, | |
| 77 | - // and i is greater than the number of images for a particular subject | |
| 78 | - // even if the partitions are different | |
| 79 | - if (leaveOneImageOut) { | |
| 80 | - const QString label = partitionedData.at(j).file.get<QString>("Label"); | |
| 81 | - QList<int> subjectIndices = partitionedData.find("Label",label); | |
| 82 | - QList<int> removed; | |
| 83 | - // Remove target only data | |
| 84 | - for (int k=subjectIndices.size()-1; k>=0; k--) | |
| 85 | - if (partitionedData[subjectIndices[k]].file.getBool("targetOnly")) { | |
| 86 | - removed.append(subjectIndices[k]); | |
| 87 | - subjectIndices.removeAt(k); | |
| 88 | - } | |
| 89 | - // Remove template that was repeated to make the testOnly template | |
| 90 | - if (subjectIndices.size() > 1 && subjectIndices.size() <= i) { | |
| 91 | - removed.append(subjectIndices[i%subjectIndices.size()]); | |
| 92 | - } else if (partitionsBuffer[j] == i) { | |
| 93 | - removed.append(j); | |
| 94 | - } | |
| 95 | - | |
| 96 | - if (!removed.empty()) { | |
| 97 | - typedef QPair<int,int> Pair; | |
| 98 | - foreach (Pair pair, Common::Sort(removed,true)) { | |
| 99 | - partitionedData.removeAt(pair.first); partitionsBuffer.removeAt(pair.first); j--; | |
| 100 | - } | |
| 101 | - } else { | |
| 102 | - j--; | |
| 103 | - } | |
| 104 | - } else if (partitions[j] == i) { | |
| 82 | + for (int j=partitionedData.size()-1; j>=0; j--) | |
| 83 | + if (partitions[j] == i) { | |
| 105 | 84 | // Remove data, it's designated for testing |
| 106 | 85 | partitionedData.removeAt(j); |
| 107 | - j--; | |
| 108 | - } else j--; | |
| 109 | - } | |
| 86 | + } | |
| 110 | 87 | // Train on the remaining templates |
| 111 | 88 | futures.addFuture(QtConcurrent::run(_train, transforms[i], partitionedData)); |
| 112 | 89 | } |
| ... | ... | @@ -115,11 +92,6 @@ class CrossValidateTransform : public MetaTransform |
| 115 | 92 | |
| 116 | 93 | void project(const Template &src, Template &dst) const |
| 117 | 94 | { |
| 118 | - // Remember, the partition should never be -1 | |
| 119 | - // since it is assumed that the allPartitions | |
| 120 | - // flag is only used during comparison | |
| 121 | - // (i.e. only used when making a mask) | |
| 122 | - | |
| 123 | 95 | // If we want to duplicate templates but use the same training data |
| 124 | 96 | // for all partitions (i.e. transforms.size() == 1), we need to |
| 125 | 97 | // restrict the partition | ... | ... |