Commit f914aca9cb48e2e2d088bcf9214b49c7a4e88658
1 parent
d97f68bd
Refactored CrossValidateTransform
Showing
1 changed file
with
24 additions
and
52 deletions
openbr/plugins/core/crossvalidate.cpp
| @@ -32,31 +32,41 @@ static void _train(Transform *transform, TemplateList data) // think data has to | @@ -32,31 +32,41 @@ static void _train(Transform *transform, TemplateList data) // think data has to | ||
| 32 | * \brief Cross validate a trainable transform. | 32 | * \brief Cross validate a trainable transform. |
| 33 | * \author Josh Klontz \cite jklontz | 33 | * \author Josh Klontz \cite jklontz |
| 34 | * \author Scott Klum \cite sklum | 34 | * \author Scott Klum \cite sklum |
| 35 | - * \note To use an extended gallery, add an allPartitions="true" flag to the gallery sigset for those images that should be compared | ||
| 36 | - * against for all testing partitions. | 35 | + * \note Two flags can be put in File metadata that are related to cross-validation and are used to |
| 36 | + * extend a testing gallery: | ||
| 37 | + * (i) allPartitions - This flag is intended to be used when comparing the | ||
| 38 | + * performance of an untrainable algorithm (e.g. a COTS | ||
| 39 | + * algorithm) against a trainable algorithm that was trained | ||
| 40 | + * using cross-validation. All templates with the allPartitions | ||
| 41 | + * flag will be compared against for every partition. As | ||
| 42 | + * untrainable algorithms will have no use for the | ||
| 43 | + * CrossValidateTransform, this flag is only meaningful at comparison | ||
| 44 | + * time (but care has been taken so that one can train and enroll | ||
| 45 | + * without issue if these Files are present in the used Gallery). | ||
| 46 | + * (ii) duplicatePartitions - This flag is similar to allPartitions in that it causes | ||
| 47 | + * the same template to be used during comparison for every partition. | ||
| 48 | + * The difference is that duplicatePartitions will duplicate each | ||
| 49 | + * marked template and project it into the model space constituted | ||
| 50 | + * by the child transforms of CrossValidateTransform. Again, care | ||
| 51 | + * has been taken such that one can train with these templates in the | ||
| 52 | + * used Gallery successfully (they will simply be omitted). | ||
| 37 | */ | 53 | */ |
| 38 | class CrossValidateTransform : public MetaTransform | 54 | class CrossValidateTransform : public MetaTransform |
| 39 | { | 55 | { |
| 40 | Q_OBJECT | 56 | Q_OBJECT |
| 41 | Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false) | 57 | Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false) |
| 42 | - Q_PROPERTY(bool leaveOneImageOut READ get_leaveOneImageOut WRITE set_leaveOneImageOut RESET reset_leaveOneImageOut STORED false) | ||
| 43 | BR_PROPERTY(QString, description, "Identity") | 58 | BR_PROPERTY(QString, description, "Identity") |
| 44 | - BR_PROPERTY(bool, leaveOneImageOut, false) | ||
| 45 | 59 | ||
| 46 | // numPartitions copies of transform specified by description. | 60 | // numPartitions copies of transform specified by description. |
| 47 | QList<br::Transform*> transforms; | 61 | QList<br::Transform*> transforms; |
| 48 | 62 | ||
| 49 | - // Treating this transform as a leaf (in terms of update training scheme), the child transform | 63 | + // Treating this transform as a leaf (in terms of updated training scheme), the child transform |
| 50 | // of this transform will lose any structure present in the training QList<TemplateList>, which | 64 | // of this transform will lose any structure present in the training QList<TemplateList>, which |
| 51 | // is generally incorrect behavior. | 65 | // is generally incorrect behavior. |
| 52 | void train(const TemplateList &data) | 66 | void train(const TemplateList &data) |
| 53 | { | 67 | { |
| 54 | - int numPartitions = 0; | ||
| 55 | - QList<int> partitions; partitions.reserve(data.size()); | ||
| 56 | - foreach (const File &file, data.files()) { | ||
| 57 | - partitions.append(file.get<int>("Partition", 0)); | ||
| 58 | - numPartitions = std::max(numPartitions, partitions.last()+1); | ||
| 59 | - } | 68 | + QList<int> partitions = data.files().crossValidationPartitions(); |
| 69 | + const int numPartitions = Common::Max(partitions)+1; | ||
| 60 | 70 | ||
| 61 | while (transforms.size() < numPartitions) | 71 | while (transforms.size() < numPartitions) |
| 62 | transforms.append(make(description)); | 72 | transforms.append(make(description)); |
| @@ -68,45 +78,12 @@ class CrossValidateTransform : public MetaTransform | @@ -68,45 +78,12 @@ class CrossValidateTransform : public MetaTransform | ||
| 68 | 78 | ||
| 69 | QFutureSynchronizer<void> futures; | 79 | QFutureSynchronizer<void> futures; |
| 70 | for (int i=0; i<numPartitions; i++) { | 80 | for (int i=0; i<numPartitions; i++) { |
| 71 | - QList<int> partitionsBuffer = partitions; | ||
| 72 | TemplateList partitionedData = data; | 81 | TemplateList partitionedData = data; |
| 73 | - int j = partitionedData.size()-1; | ||
| 74 | - while (j>=0) { | ||
| 75 | - // Remove all templates belonging to partition i | ||
| 76 | - // if leaveOneImageOut is true, | ||
| 77 | - // and i is greater than the number of images for a particular subject | ||
| 78 | - // even if the partitions are different | ||
| 79 | - if (leaveOneImageOut) { | ||
| 80 | - const QString label = partitionedData.at(j).file.get<QString>("Label"); | ||
| 81 | - QList<int> subjectIndices = partitionedData.find("Label",label); | ||
| 82 | - QList<int> removed; | ||
| 83 | - // Remove target only data | ||
| 84 | - for (int k=subjectIndices.size()-1; k>=0; k--) | ||
| 85 | - if (partitionedData[subjectIndices[k]].file.getBool("targetOnly")) { | ||
| 86 | - removed.append(subjectIndices[k]); | ||
| 87 | - subjectIndices.removeAt(k); | ||
| 88 | - } | ||
| 89 | - // Remove template that was repeated to make the testOnly template | ||
| 90 | - if (subjectIndices.size() > 1 && subjectIndices.size() <= i) { | ||
| 91 | - removed.append(subjectIndices[i%subjectIndices.size()]); | ||
| 92 | - } else if (partitionsBuffer[j] == i) { | ||
| 93 | - removed.append(j); | ||
| 94 | - } | ||
| 95 | - | ||
| 96 | - if (!removed.empty()) { | ||
| 97 | - typedef QPair<int,int> Pair; | ||
| 98 | - foreach (Pair pair, Common::Sort(removed,true)) { | ||
| 99 | - partitionedData.removeAt(pair.first); partitionsBuffer.removeAt(pair.first); j--; | ||
| 100 | - } | ||
| 101 | - } else { | ||
| 102 | - j--; | ||
| 103 | - } | ||
| 104 | - } else if (partitions[j] == i) { | 82 | + for (int j=partitionedData.size()-1; j>=0; j--) |
| 83 | + if (partitions[j] == i) { | ||
| 105 | // Remove data, it's designated for testing | 84 | // Remove data, it's designated for testing |
| 106 | partitionedData.removeAt(j); | 85 | partitionedData.removeAt(j); |
| 107 | - j--; | ||
| 108 | - } else j--; | ||
| 109 | - } | 86 | + } |
| 110 | // Train on the remaining templates | 87 | // Train on the remaining templates |
| 111 | futures.addFuture(QtConcurrent::run(_train, transforms[i], partitionedData)); | 88 | futures.addFuture(QtConcurrent::run(_train, transforms[i], partitionedData)); |
| 112 | } | 89 | } |
| @@ -115,11 +92,6 @@ class CrossValidateTransform : public MetaTransform | @@ -115,11 +92,6 @@ class CrossValidateTransform : public MetaTransform | ||
| 115 | 92 | ||
| 116 | void project(const Template &src, Template &dst) const | 93 | void project(const Template &src, Template &dst) const |
| 117 | { | 94 | { |
| 118 | - // Remember, the partition should never be -1 | ||
| 119 | - // since it is assumed that the allPartitions | ||
| 120 | - // flag is only used during comparison | ||
| 121 | - // (i.e. only used when making a mask) | ||
| 122 | - | ||
| 123 | // If we want to duplicate templates but use the same training data | 95 | // If we want to duplicate templates but use the same training data |
| 124 | // for all partitions (i.e. transforms.size() == 1), we need to | 96 | // for all partitions (i.e. transforms.size() == 1), we need to |
| 125 | // restrict the partition | 97 | // restrict the partition |