Commit f914aca9cb48e2e2d088bcf9214b49c7a4e88658

Authored by Scott Klum
1 parent d97f68bd

Refactored CrossValidateTransform

openbr/plugins/core/crossvalidate.cpp
... ... @@ -32,31 +32,41 @@ static void _train(Transform *transform, TemplateList data) // think data has to
32 32 * \brief Cross validate a trainable transform.
33 33 * \author Josh Klontz \cite jklontz
34 34 * \author Scott Klum \cite sklum
35   - * \note To use an extended gallery, add an allPartitions="true" flag to the gallery sigset for those images that should be compared
36   - * against for all testing partitions.
  35 + * \note Two flags can be put in File metadata that are related to cross-validation and are used to
  36 + * extend a testing gallery:
  37 + * (i) allPartitions - This flag is intended to be used when comparing the
  38 + * performance of an untrainable algorithm (e.g. a COTS
  39 + * algorithm) against a trainable algorithm that was trained
  40 + * using cross-validation. All templates with the allPartitions
  41 + * flag will be compared against for every partition. As
  42 + * untrainable algorithms will have no use for the
  43 + * CrossValidateTransform, this flag is only meaningful at comparison
  44 + * time (but care has been taken so that one can train and enroll
  45 + * without issue if these Files are present in the used Gallery).
  46 + * (ii) duplicatePartitions - This flag is similar to allPartitions in that it causes
  47 + * the same template to be used during comparison for every partition.
  48 + * The difference is that duplicatePartitions will duplicate each
  49 + * marked template and project it into the model space constituted
  50 + * by the child transforms of CrossValidateTransform. Again, care
  51 + * has been taken such that one can train with these templates in the
  52 + * used Gallery successfully (they will simply be omitted).
37 53 */
38 54 class CrossValidateTransform : public MetaTransform
39 55 {
40 56 Q_OBJECT
41 57 Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false)
42   - Q_PROPERTY(bool leaveOneImageOut READ get_leaveOneImageOut WRITE set_leaveOneImageOut RESET reset_leaveOneImageOut STORED false)
43 58 BR_PROPERTY(QString, description, "Identity")
44   - BR_PROPERTY(bool, leaveOneImageOut, false)
45 59  
46 60 // numPartitions copies of transform specified by description.
47 61 QList<br::Transform*> transforms;
48 62  
49   - // Treating this transform as a leaf (in terms of update training scheme), the child transform
  63 + // Treating this transform as a leaf (in terms of updated training scheme), the child transform
50 64 // of this transform will lose any structure present in the training QList<TemplateList>, which
51 65 // is generally incorrect behavior.
52 66 void train(const TemplateList &data)
53 67 {
54   - int numPartitions = 0;
55   - QList<int> partitions; partitions.reserve(data.size());
56   - foreach (const File &file, data.files()) {
57   - partitions.append(file.get<int>("Partition", 0));
58   - numPartitions = std::max(numPartitions, partitions.last()+1);
59   - }
  68 + QList<int> partitions = data.files().crossValidationPartitions();
  69 + const int numPartitions = Common::Max(partitions)+1;
60 70  
61 71 while (transforms.size() < numPartitions)
62 72 transforms.append(make(description));
... ... @@ -68,45 +78,12 @@ class CrossValidateTransform : public MetaTransform
68 78  
69 79 QFutureSynchronizer<void> futures;
70 80 for (int i=0; i<numPartitions; i++) {
71   - QList<int> partitionsBuffer = partitions;
72 81 TemplateList partitionedData = data;
73   - int j = partitionedData.size()-1;
74   - while (j>=0) {
75   - // Remove all templates belonging to partition i
76   - // if leaveOneImageOut is true,
77   - // and i is greater than the number of images for a particular subject
78   - // even if the partitions are different
79   - if (leaveOneImageOut) {
80   - const QString label = partitionedData.at(j).file.get<QString>("Label");
81   - QList<int> subjectIndices = partitionedData.find("Label",label);
82   - QList<int> removed;
83   - // Remove target only data
84   - for (int k=subjectIndices.size()-1; k>=0; k--)
85   - if (partitionedData[subjectIndices[k]].file.getBool("targetOnly")) {
86   - removed.append(subjectIndices[k]);
87   - subjectIndices.removeAt(k);
88   - }
89   - // Remove template that was repeated to make the testOnly template
90   - if (subjectIndices.size() > 1 && subjectIndices.size() <= i) {
91   - removed.append(subjectIndices[i%subjectIndices.size()]);
92   - } else if (partitionsBuffer[j] == i) {
93   - removed.append(j);
94   - }
95   -
96   - if (!removed.empty()) {
97   - typedef QPair<int,int> Pair;
98   - foreach (Pair pair, Common::Sort(removed,true)) {
99   - partitionedData.removeAt(pair.first); partitionsBuffer.removeAt(pair.first); j--;
100   - }
101   - } else {
102   - j--;
103   - }
104   - } else if (partitions[j] == i) {
  82 + for (int j=partitionedData.size()-1; j>=0; j--)
  83 + if (partitions[j] == i) {
105 84 // Remove data, it's designated for testing
106 85 partitionedData.removeAt(j);
107   - j--;
108   - } else j--;
109   - }
  86 + }
110 87 // Train on the remaining templates
111 88 futures.addFuture(QtConcurrent::run(_train, transforms[i], partitionedData));
112 89 }
... ... @@ -115,11 +92,6 @@ class CrossValidateTransform : public MetaTransform
115 92  
116 93 void project(const Template &src, Template &dst) const
117 94 {
118   - // Remember, the partition should never be -1
119   - // since it is assumed that the allPartitions
120   - // flag is only used during comparison
121   - // (i.e. only used when making a mask)
122   -
123 95 // If we want to duplicate templates but use the same training data
124 96 // for all partitions (i.e. transforms.size() == 1), we need to
125 97 // restrict the partition
... ...