From f914aca9cb48e2e2d088bcf9214b49c7a4e88658 Mon Sep 17 00:00:00 2001 From: Scott Klum Date: Mon, 13 Apr 2015 16:23:34 -0400 Subject: [PATCH] Refactored CrossValidateTransform --- openbr/plugins/core/crossvalidate.cpp | 76 ++++++++++++++++++++++++---------------------------------------------------- 1 file changed, 24 insertions(+), 52 deletions(-) diff --git a/openbr/plugins/core/crossvalidate.cpp b/openbr/plugins/core/crossvalidate.cpp index eb86143..8f37e9f 100644 --- a/openbr/plugins/core/crossvalidate.cpp +++ b/openbr/plugins/core/crossvalidate.cpp @@ -32,31 +32,41 @@ static void _train(Transform *transform, TemplateList data) // think data has to * \brief Cross validate a trainable transform. * \author Josh Klontz \cite jklontz * \author Scott Klum \cite sklum - * \note To use an extended gallery, add an allPartitions="true" flag to the gallery sigset for those images that should be compared - * against for all testing partitions. + * \note Two flags can be put in File metadata that are related to cross-validation and are used to + * extend a testing gallery: + * (i) allPartitions - This flag is intended to be used when comparing the + * performance of an untrainable algorithm (e.g. a COTS + * algorithm) against a trainable algorithm that was trained + * using cross-validation. All templates with the allPartitions + * flag will be compared against for every partition. As + * untrainable algorithms will have no use for the + * CrossValidateTransform, this flag is only meaningful at comparison + * time (but care has been taken so that one can train and enroll + * without issue if these Files are present in the used Gallery). + * (ii) duplicatePartitions - This flag is similar to allPartitions in that it causes + * the same template to be used during comparison for every partition. + * The difference is that duplicatePartitions will duplicate each + * marked template and project it into the model space constituded + * by the child transforms of CrossValidateTransform. Again, care + * has been take such that one can train with these templates in the + * used Gallery successfully (they will simply be omitted). */ class CrossValidateTransform : public MetaTransform { Q_OBJECT Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false) - Q_PROPERTY(bool leaveOneImageOut READ get_leaveOneImageOut WRITE set_leaveOneImageOut RESET reset_leaveOneImageOut STORED false) BR_PROPERTY(QString, description, "Identity") - BR_PROPERTY(bool, leaveOneImageOut, false) // numPartitions copies of transform specified by description. QList transforms; - // Treating this transform as a leaf (in terms of update training scheme), the child transform + // Treating this transform as a leaf (in terms of updated training scheme), the child transform // of this transform will lose any structure present in the training QList, which // is generally incorrect behavior. void train(const TemplateList &data) { - int numPartitions = 0; - QList partitions; partitions.reserve(data.size()); - foreach (const File &file, data.files()) { - partitions.append(file.get("Partition", 0)); - numPartitions = std::max(numPartitions, partitions.last()+1); - } + QList partitions = data.files().crossValidationPartitions(); + const int numPartitions = Common::Max(partitions)+1; while (transforms.size() < numPartitions) transforms.append(make(description)); @@ -68,45 +78,12 @@ class CrossValidateTransform : public MetaTransform QFutureSynchronizer futures; for (int i=0; i partitionsBuffer = partitions; TemplateList partitionedData = data; - int j = partitionedData.size()-1; - while (j>=0) { - // Remove all templates belonging to partition i - // if leaveOneImageOut is true, - // and i is greater than the number of images for a particular subject - // even if the partitions are different - if (leaveOneImageOut) { - const QString label = partitionedData.at(j).file.get("Label"); - QList subjectIndices = partitionedData.find("Label",label); - QList removed; - // Remove target only data - for (int k=subjectIndices.size()-1; k>=0; k--) - if (partitionedData[subjectIndices[k]].file.getBool("targetOnly")) { - removed.append(subjectIndices[k]); - subjectIndices.removeAt(k); - } - // Remove template that was repeated to make the testOnly template - if (subjectIndices.size() > 1 && subjectIndices.size() <= i) { - removed.append(subjectIndices[i%subjectIndices.size()]); - } else if (partitionsBuffer[j] == i) { - removed.append(j); - } - - if (!removed.empty()) { - typedef QPair Pair; - foreach (Pair pair, Common::Sort(removed,true)) { - partitionedData.removeAt(pair.first); partitionsBuffer.removeAt(pair.first); j--; - } - } else { - j--; - } - } else if (partitions[j] == i) { + for (int j=partitionedData.size()-1; j>=0; j--) + if (partitions[j] == i) { // Remove data, it's designated for testing partitionedData.removeAt(j); - j--; - } else j--; - } + } // Train on the remaining templates futures.addFuture(QtConcurrent::run(_train, transforms[i], partitionedData)); } @@ -115,11 +92,6 @@ class CrossValidateTransform : public MetaTransform void project(const Template &src, Template &dst) const { - // Remember, the partition should never be -1 - // since it is assumed that the allPartitions - // flag is only used during comparison - // (i.e. only used when making a mask) - // If we want to duplicate templates but use the same training data // for all partitions (i.e. transforms.size() == 1), we need to // restrict the partition -- libgit2 0.21.4