Commit f914aca9cb48e2e2d088bcf9214b49c7a4e88658

Authored by Scott Klum
1 parent d97f68bd

Refactored CrossValidateTransform

openbr/plugins/core/crossvalidate.cpp
@@ -32,31 +32,41 @@ static void _train(Transform *transform, TemplateList data) // think data has to @@ -32,31 +32,41 @@ static void _train(Transform *transform, TemplateList data) // think data has to
32 * \brief Cross validate a trainable transform. 32 * \brief Cross validate a trainable transform.
33 * \author Josh Klontz \cite jklontz 33 * \author Josh Klontz \cite jklontz
34 * \author Scott Klum \cite sklum 34 * \author Scott Klum \cite sklum
35 - * \note To use an extended gallery, add an allPartitions="true" flag to the gallery sigset for those images that should be compared  
36 - * against for all testing partitions. 35 + * \note Two flags can be put in File metadata that are related to cross-validation and are used to
  36 + * extend a testing gallery:
  37 + * (i) allPartitions - This flag is intended to be used when comparing the
  38 + * performance of an untrainable algorithm (e.g. a COTS
  39 + * algorithm) against a trainable algorithm that was trained
  40 + * using cross-validation. All templates with the allPartitions
  41 + * flag will be compared against for every partition. As
  42 + * untrainable algorithms will have no use for the
  43 + * CrossValidateTransform, this flag is only meaningful at comparison
  44 + * time (but care has been taken so that one can train and enroll
  45 + * without issue if these Files are present in the used Gallery).
  46 + * (ii) duplicatePartitions - This flag is similar to allPartitions in that it causes
  47 + * the same template to be used during comparison for every partition.
  48 + * The difference is that duplicatePartitions will duplicate each
  49 + * marked template and project it into the model space constituded
  50 + * by the child transforms of CrossValidateTransform. Again, care
  51 + * has been take such that one can train with these templates in the
  52 + * used Gallery successfully (they will simply be omitted).
37 */ 53 */
38 class CrossValidateTransform : public MetaTransform 54 class CrossValidateTransform : public MetaTransform
39 { 55 {
40 Q_OBJECT 56 Q_OBJECT
41 Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false) 57 Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false)
42 - Q_PROPERTY(bool leaveOneImageOut READ get_leaveOneImageOut WRITE set_leaveOneImageOut RESET reset_leaveOneImageOut STORED false)  
43 BR_PROPERTY(QString, description, "Identity") 58 BR_PROPERTY(QString, description, "Identity")
44 - BR_PROPERTY(bool, leaveOneImageOut, false)  
45 59
46 // numPartitions copies of transform specified by description. 60 // numPartitions copies of transform specified by description.
47 QList<br::Transform*> transforms; 61 QList<br::Transform*> transforms;
48 62
49 - // Treating this transform as a leaf (in terms of update training scheme), the child transform 63 + // Treating this transform as a leaf (in terms of updated training scheme), the child transform
50 // of this transform will lose any structure present in the training QList<TemplateList>, which 64 // of this transform will lose any structure present in the training QList<TemplateList>, which
51 // is generally incorrect behavior. 65 // is generally incorrect behavior.
52 void train(const TemplateList &data) 66 void train(const TemplateList &data)
53 { 67 {
54 - int numPartitions = 0;  
55 - QList<int> partitions; partitions.reserve(data.size());  
56 - foreach (const File &file, data.files()) {  
57 - partitions.append(file.get<int>("Partition", 0));  
58 - numPartitions = std::max(numPartitions, partitions.last()+1);  
59 - } 68 + QList<int> partitions = data.files().crossValidationPartitions();
  69 + const int numPartitions = Common::Max(partitions)+1;
60 70
61 while (transforms.size() < numPartitions) 71 while (transforms.size() < numPartitions)
62 transforms.append(make(description)); 72 transforms.append(make(description));
@@ -68,45 +78,12 @@ class CrossValidateTransform : public MetaTransform @@ -68,45 +78,12 @@ class CrossValidateTransform : public MetaTransform
68 78
69 QFutureSynchronizer<void> futures; 79 QFutureSynchronizer<void> futures;
70 for (int i=0; i<numPartitions; i++) { 80 for (int i=0; i<numPartitions; i++) {
71 - QList<int> partitionsBuffer = partitions;  
72 TemplateList partitionedData = data; 81 TemplateList partitionedData = data;
73 - int j = partitionedData.size()-1;  
74 - while (j>=0) {  
75 - // Remove all templates belonging to partition i  
76 - // if leaveOneImageOut is true,  
77 - // and i is greater than the number of images for a particular subject  
78 - // even if the partitions are different  
79 - if (leaveOneImageOut) {  
80 - const QString label = partitionedData.at(j).file.get<QString>("Label");  
81 - QList<int> subjectIndices = partitionedData.find("Label",label);  
82 - QList<int> removed;  
83 - // Remove target only data  
84 - for (int k=subjectIndices.size()-1; k>=0; k--)  
85 - if (partitionedData[subjectIndices[k]].file.getBool("targetOnly")) {  
86 - removed.append(subjectIndices[k]);  
87 - subjectIndices.removeAt(k);  
88 - }  
89 - // Remove template that was repeated to make the testOnly template  
90 - if (subjectIndices.size() > 1 && subjectIndices.size() <= i) {  
91 - removed.append(subjectIndices[i%subjectIndices.size()]);  
92 - } else if (partitionsBuffer[j] == i) {  
93 - removed.append(j);  
94 - }  
95 -  
96 - if (!removed.empty()) {  
97 - typedef QPair<int,int> Pair;  
98 - foreach (Pair pair, Common::Sort(removed,true)) {  
99 - partitionedData.removeAt(pair.first); partitionsBuffer.removeAt(pair.first); j--;  
100 - }  
101 - } else {  
102 - j--;  
103 - }  
104 - } else if (partitions[j] == i) { 82 + for (int j=partitionedData.size()-1; j>=0; j--)
  83 + if (partitions[j] == i) {
105 // Remove data, it's designated for testing 84 // Remove data, it's designated for testing
106 partitionedData.removeAt(j); 85 partitionedData.removeAt(j);
107 - j--;  
108 - } else j--;  
109 - } 86 + }
110 // Train on the remaining templates 87 // Train on the remaining templates
111 futures.addFuture(QtConcurrent::run(_train, transforms[i], partitionedData)); 88 futures.addFuture(QtConcurrent::run(_train, transforms[i], partitionedData));
112 } 89 }
@@ -115,11 +92,6 @@ class CrossValidateTransform : public MetaTransform @@ -115,11 +92,6 @@ class CrossValidateTransform : public MetaTransform
115 92
116 void project(const Template &src, Template &dst) const 93 void project(const Template &src, Template &dst) const
117 { 94 {
118 - // Remember, the partition should never be -1  
119 - // since it is assumed that the allPartitions  
120 - // flag is only used during comparison  
121 - // (i.e. only used when making a mask)  
122 -  
123 // If we want to duplicate templates but use the same training data 95 // If we want to duplicate templates but use the same training data
124 // for all partitions (i.e. transforms.size() == 1), we need to 96 // for all partitions (i.e. transforms.size() == 1), we need to
125 // restrict the partition 97 // restrict the partition