Commit a47027675263b1c4a02e2d42758122ee83635b32

Authored by Josh Klontz
1 parent 656823d5

removed crossvalidate transform

openbr/plugins/core/crossvalidate.cpp deleted
1   -/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
2   - * Copyright 2012 The MITRE Corporation *
3   - * *
4   - * Licensed under the Apache License, Version 2.0 (the "License"); *
5   - * you may not use this file except in compliance with the License. *
6   - * You may obtain a copy of the License at *
7   - * *
8   - * http://www.apache.org/licenses/LICENSE-2.0 *
9   - * *
10   - * Unless required by applicable law or agreed to in writing, software *
11   - * distributed under the License is distributed on an "AS IS" BASIS, *
12   - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
13   - * See the License for the specific language governing permissions and *
14   - * limitations under the License. *
15   - * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
16   -
17   -#include <QtConcurrent>
18   -
19   -#include <openbr/plugins/openbr_internal.h>
20   -#include <openbr/core/common.h>
21   -
22   -namespace br
23   -{
24   -
25   -static void _train(Transform *transform, TemplateList data) // think data has to be a copy -cao
26   -{
27   - transform->train(data);
28   -}
29   -
30   -/*!
31   - * \ingroup transforms
32   - * \brief Cross validate a trainable Transform.
33   - *
34   - * Two flags can be put in File metadata that are related to cross-validation and are used to
35   - * extend a testing gallery:
36   - *
37   - * flag | description
38   - * --- | ---
39   - * allPartitions | This flag is intended to be used when comparing the performance of an untrainable algorithm (e.g. a COTS algorithm) against a trainable algorithm that was trained using cross-validation. All templates with the allPartitions flag will be compared against for every partition. As untrainable algorithms will have no use for the CrossValidateTransform, this flag is only meaningful at comparison time (but care has been taken so that one can train and enroll without issue if these Files are present in the used Gallery).
40   - * duplicatePartitions | This flag is similar to allPartitions in that it causes the same template to be used during comparison for every partition. The difference is that duplicatePartitions will duplicate each marked template and project it into the model space constituded by the child transforms of CrossValidateTransform. Again, care has been take such that one can train with these templates in the used Gallery successfully (they will simply be omitted).
41   - *
42   - * To use an extended Gallery, add an allPartitions="true" flag to the gallery sigset for those images that should be compared
43   - * against for all testing partitions.
44   - *
45   - * \author Josh Klontz \cite jklontz
46   - * \author Scott Klum \cite sklum
47   - */
48   -class CrossValidateTransform : public MetaTransform
49   -{
50   - Q_OBJECT
51   - Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false)
52   - Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
53   - Q_PROPERTY(unsigned int randomSeed READ get_randomSeed WRITE set_randomSeed RESET reset_randomSeed STORED false)
54   - BR_PROPERTY(QString, description, "Identity")
55   - BR_PROPERTY(QString, inputVariable, "Label")
56   - BR_PROPERTY(unsigned int, randomSeed, 0)
57   -
58   - // numPartitions copies of transform specified by description.
59   - QList<br::Transform*> transforms;
60   -
61   - // Treating this transform as a leaf (in terms of updated training scheme), the child transform
62   - // of this transform will lose any structure present in the training QList<TemplateList>, which
63   - // is generally incorrect behavior.
64   - void train(const TemplateList &data)
65   - {
66   - TemplateList partitionedData = data.partition(inputVariable, randomSeed, true);
67   - QList<int> partitions = partitionedData.files().crossValidationPartitions();
68   -
69   - const int crossValidate = Globals->crossValidate;
70   - // Only train once based on the 0th partition if crossValidate is negative.
71   - const int numPartitions = (crossValidate < 0) ? 1 : Common::Max(partitions)+1;
72   - while (transforms.size() < numPartitions)
73   - transforms.append(make(description));
74   -
75   - if (std::abs(crossValidate) < 2) {
76   - transforms.first()->train(data);
77   - return;
78   - }
79   -
80   - QFutureSynchronizer<void> futures;
81   - for (int i=0; i<numPartitions; i++) {
82   - TemplateList partition = partitionedData;
83   - for (int j=partition.size()-1; j>=0; j--) {
84   - if (partitions[j] == i)
85   - // Remove data, it's designated for testing
86   - partition.removeAt(j);
87   - }
88   - if (Globals->verbose)
89   - qDebug() << QString("Training partition %1 on %2 templates.").arg(QString::number(i),QString::number(partition.size()));
90   -
91   - // Train on the remaining templates
92   - futures.addFuture(QtConcurrent::run(_train, transforms[i], partition));
93   - }
94   - futures.waitForFinished();
95   - }
96   -
97   - void project(const Template &src, Template &dst) const
98   - {
99   - Q_UNUSED(src);
100   - Q_UNUSED(dst);
101   -
102   - qFatal("CrossValidateTransform::project(const Template &src, Template &dst) should not be called.");
103   - }
104   -
105   - void project(const TemplateList &src, TemplateList &dst) const
106   - {
107   - TemplateList partitioned = src.partition(inputVariable, randomSeed, true);
108   - const int crossValidate = Globals->crossValidate;
109   -
110   - if (crossValidate < 0) {
111   - transforms[0]->project(partitioned, dst);
112   - return;
113   - }
114   - for (int i=0; i<partitioned.size(); i++) {
115   - int partition = partitioned[i].file.get<int>("Partition", 0);
116   - transforms[partition]->project(partitioned, dst);
117   - }
118   - }
119   -
120   - void store(QDataStream &stream) const
121   - {
122   - stream << transforms.size();
123   - foreach (Transform *transform, transforms)
124   - transform->store(stream);
125   - }
126   -
127   - void load(QDataStream &stream)
128   - {
129   - int numTransforms;
130   - stream >> numTransforms;
131   - while (transforms.size() < numTransforms)
132   - transforms.append(make(description));
133   - foreach (Transform *transform, transforms)
134   - transform->load(stream);
135   - }
136   -};
137   -
138   -BR_REGISTER(Transform, CrossValidateTransform)
139   -
140   -} // namespace br
141   -
142   -#include "core/crossvalidate.moc"