Commit a47027675263b1c4a02e2d42758122ee83635b32

Authored by Josh Klontz
1 parent 656823d5

removed crossvalidate transform

openbr/plugins/core/crossvalidate.cpp deleted
1 -/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *  
2 - * Copyright 2012 The MITRE Corporation *  
3 - * *  
4 - * Licensed under the Apache License, Version 2.0 (the "License"); *  
5 - * you may not use this file except in compliance with the License. *  
6 - * You may obtain a copy of the License at *  
7 - * *  
8 - * http://www.apache.org/licenses/LICENSE-2.0 *  
9 - * *  
10 - * Unless required by applicable law or agreed to in writing, software *  
11 - * distributed under the License is distributed on an "AS IS" BASIS, *  
12 - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *  
13 - * See the License for the specific language governing permissions and *  
14 - * limitations under the License. *  
15 - * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */  
16 -  
17 -#include <QtConcurrent>  
18 -  
19 -#include <openbr/plugins/openbr_internal.h>  
20 -#include <openbr/core/common.h>  
21 -  
22 -namespace br  
23 -{  
24 -  
25 -static void _train(Transform *transform, TemplateList data) // think data has to be a copy -cao  
26 -{  
27 - transform->train(data);  
28 -}  
29 -  
30 -/*!  
31 - * \ingroup transforms  
32 - * \brief Cross validate a trainable Transform.  
33 - *  
34 - * Two flags can be put in File metadata that are related to cross-validation and are used to  
35 - * extend a testing gallery:  
36 - *  
37 - * flag | description  
38 - * --- | ---  
39 - * allPartitions | This flag is intended to be used when comparing the performance of an untrainable algorithm (e.g. a COTS algorithm) against a trainable algorithm that was trained using cross-validation. All templates with the allPartitions flag will be compared against for every partition. As untrainable algorithms will have no use for the CrossValidateTransform, this flag is only meaningful at comparison time (but care has been taken so that one can train and enroll without issue if these Files are present in the used Gallery).  
40 - * duplicatePartitions | This flag is similar to allPartitions in that it causes the same template to be used during comparison for every partition. The difference is that duplicatePartitions will duplicate each marked template and project it into the model space constituded by the child transforms of CrossValidateTransform. Again, care has been take such that one can train with these templates in the used Gallery successfully (they will simply be omitted).  
41 - *  
42 - * To use an extended Gallery, add an allPartitions="true" flag to the gallery sigset for those images that should be compared  
43 - * against for all testing partitions.  
44 - *  
45 - * \author Josh Klontz \cite jklontz  
46 - * \author Scott Klum \cite sklum  
47 - */  
48 -class CrossValidateTransform : public MetaTransform  
49 -{  
50 - Q_OBJECT  
51 - Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false)  
52 - Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)  
53 - Q_PROPERTY(unsigned int randomSeed READ get_randomSeed WRITE set_randomSeed RESET reset_randomSeed STORED false)  
54 - BR_PROPERTY(QString, description, "Identity")  
55 - BR_PROPERTY(QString, inputVariable, "Label")  
56 - BR_PROPERTY(unsigned int, randomSeed, 0)  
57 -  
58 - // numPartitions copies of transform specified by description.  
59 - QList<br::Transform*> transforms;  
60 -  
61 - // Treating this transform as a leaf (in terms of updated training scheme), the child transform  
62 - // of this transform will lose any structure present in the training QList<TemplateList>, which  
63 - // is generally incorrect behavior.  
64 - void train(const TemplateList &data)  
65 - {  
66 - TemplateList partitionedData = data.partition(inputVariable, randomSeed, true);  
67 - QList<int> partitions = partitionedData.files().crossValidationPartitions();  
68 -  
69 - const int crossValidate = Globals->crossValidate;  
70 - // Only train once based on the 0th partition if crossValidate is negative.  
71 - const int numPartitions = (crossValidate < 0) ? 1 : Common::Max(partitions)+1;  
72 - while (transforms.size() < numPartitions)  
73 - transforms.append(make(description));  
74 -  
75 - if (std::abs(crossValidate) < 2) {  
76 - transforms.first()->train(data);  
77 - return;  
78 - }  
79 -  
80 - QFutureSynchronizer<void> futures;  
81 - for (int i=0; i<numPartitions; i++) {  
82 - TemplateList partition = partitionedData;  
83 - for (int j=partition.size()-1; j>=0; j--) {  
84 - if (partitions[j] == i)  
85 - // Remove data, it's designated for testing  
86 - partition.removeAt(j);  
87 - }  
88 - if (Globals->verbose)  
89 - qDebug() << QString("Training partition %1 on %2 templates.").arg(QString::number(i),QString::number(partition.size()));  
90 -  
91 - // Train on the remaining templates  
92 - futures.addFuture(QtConcurrent::run(_train, transforms[i], partition));  
93 - }  
94 - futures.waitForFinished();  
95 - }  
96 -  
97 - void project(const Template &src, Template &dst) const  
98 - {  
99 - Q_UNUSED(src);  
100 - Q_UNUSED(dst);  
101 -  
102 - qFatal("CrossValidateTransform::project(const Template &src, Template &dst) should not be called.");  
103 - }  
104 -  
105 - void project(const TemplateList &src, TemplateList &dst) const  
106 - {  
107 - TemplateList partitioned = src.partition(inputVariable, randomSeed, true);  
108 - const int crossValidate = Globals->crossValidate;  
109 -  
110 - if (crossValidate < 0) {  
111 - transforms[0]->project(partitioned, dst);  
112 - return;  
113 - }  
114 - for (int i=0; i<partitioned.size(); i++) {  
115 - int partition = partitioned[i].file.get<int>("Partition", 0);  
116 - transforms[partition]->project(partitioned, dst);  
117 - }  
118 - }  
119 -  
120 - void store(QDataStream &stream) const  
121 - {  
122 - stream << transforms.size();  
123 - foreach (Transform *transform, transforms)  
124 - transform->store(stream);  
125 - }  
126 -  
127 - void load(QDataStream &stream)  
128 - {  
129 - int numTransforms;  
130 - stream >> numTransforms;  
131 - while (transforms.size() < numTransforms)  
132 - transforms.append(make(description));  
133 - foreach (Transform *transform, transforms)  
134 - transform->load(stream);  
135 - }  
136 -};  
137 -  
138 -BR_REGISTER(Transform, CrossValidateTransform)  
139 -  
140 -} // namespace br  
141 -  
142 -#include "core/crossvalidate.moc"