Commit c5d6e2da86f14fea2036d3e885218f58745afc4e

Authored by Josh Klontz
1 parent 6fdf9e0b

Revert "removed crossvalidate transform"

This reverts commit a856c11e5d7c79a0919d05284c38ed9da1efa885.
openbr/plugins/core/crossvalidate.cpp 0 → 100644
  1 +/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
  2 + * Copyright 2012 The MITRE Corporation *
  3 + * *
  4 + * Licensed under the Apache License, Version 2.0 (the "License"); *
  5 + * you may not use this file except in compliance with the License. *
  6 + * You may obtain a copy of the License at *
  7 + * *
  8 + * http://www.apache.org/licenses/LICENSE-2.0 *
  9 + * *
  10 + * Unless required by applicable law or agreed to in writing, software *
  11 + * distributed under the License is distributed on an "AS IS" BASIS, *
  12 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
  13 + * See the License for the specific language governing permissions and *
  14 + * limitations under the License. *
  15 + * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
  16 +
  17 +#include <QtConcurrent>
  18 +
  19 +#include <openbr/plugins/openbr_internal.h>
  20 +#include <openbr/core/common.h>
  21 +
  22 +namespace br
  23 +{
  24 +
  25 +static void _train(Transform *transform, TemplateList data) // think data has to be a copy -cao
  26 +{
  27 + transform->train(data);
  28 +}
  29 +
  30 +/*!
  31 + * \ingroup transforms
  32 + * \brief Cross validate a trainable Transform.
  33 + *
  34 + * Two flags can be put in File metadata that are related to cross-validation and are used to
  35 + * extend a testing gallery:
  36 + *
  37 + * flag | description
  38 + * --- | ---
  39 + * allPartitions | This flag is intended to be used when comparing the performance of an untrainable algorithm (e.g. a COTS algorithm) against a trainable algorithm that was trained using cross-validation. All templates with the allPartitions flag will be compared against for every partition. As untrainable algorithms will have no use for the CrossValidateTransform, this flag is only meaningful at comparison time (but care has been taken so that one can train and enroll without issue if these Files are present in the used Gallery).
  40 + * duplicatePartitions | This flag is similar to allPartitions in that it causes the same template to be used during comparison for every partition. The difference is that duplicatePartitions will duplicate each marked template and project it into the model space constituded by the child transforms of CrossValidateTransform. Again, care has been take such that one can train with these templates in the used Gallery successfully (they will simply be omitted).
  41 + *
  42 + * To use an extended Gallery, add an allPartitions="true" flag to the gallery sigset for those images that should be compared
  43 + * against for all testing partitions.
  44 + *
  45 + * \author Josh Klontz \cite jklontz
  46 + * \author Scott Klum \cite sklum
  47 + */
  48 +class CrossValidateTransform : public MetaTransform
  49 +{
  50 + Q_OBJECT
  51 + Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false)
  52 + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
  53 + Q_PROPERTY(unsigned int randomSeed READ get_randomSeed WRITE set_randomSeed RESET reset_randomSeed STORED false)
  54 + BR_PROPERTY(QString, description, "Identity")
  55 + BR_PROPERTY(QString, inputVariable, "Label")
  56 + BR_PROPERTY(unsigned int, randomSeed, 0)
  57 +
  58 + // numPartitions copies of transform specified by description.
  59 + QList<br::Transform*> transforms;
  60 +
  61 + // Treating this transform as a leaf (in terms of updated training scheme), the child transform
  62 + // of this transform will lose any structure present in the training QList<TemplateList>, which
  63 + // is generally incorrect behavior.
  64 + void train(const TemplateList &data)
  65 + {
  66 + TemplateList partitionedData = data.partition(inputVariable, randomSeed, true);
  67 + QList<int> partitions = partitionedData.files().crossValidationPartitions();
  68 +
  69 + const int crossValidate = Globals->crossValidate;
  70 + // Only train once based on the 0th partition if crossValidate is negative.
  71 + const int numPartitions = (crossValidate < 0) ? 1 : Common::Max(partitions)+1;
  72 + while (transforms.size() < numPartitions)
  73 + transforms.append(make(description));
  74 +
  75 + if (std::abs(crossValidate) < 2) {
  76 + transforms.first()->train(data);
  77 + return;
  78 + }
  79 +
  80 + QFutureSynchronizer<void> futures;
  81 + for (int i=0; i<numPartitions; i++) {
  82 + TemplateList partition = partitionedData;
  83 + for (int j=partition.size()-1; j>=0; j--) {
  84 + if (partitions[j] == i)
  85 + // Remove data, it's designated for testing
  86 + partition.removeAt(j);
  87 + }
  88 + if (Globals->verbose)
  89 + qDebug() << QString("Training partition %1 on %2 templates.").arg(QString::number(i),QString::number(partition.size()));
  90 +
  91 + // Train on the remaining templates
  92 + futures.addFuture(QtConcurrent::run(_train, transforms[i], partition));
  93 + }
  94 + futures.waitForFinished();
  95 + }
  96 +
  97 + void project(const Template &src, Template &dst) const
  98 + {
  99 + Q_UNUSED(src);
  100 + Q_UNUSED(dst);
  101 +
  102 + qFatal("CrossValidateTransform::project(const Template &src, Template &dst) should not be called.");
  103 + }
  104 +
  105 + void project(const TemplateList &src, TemplateList &dst) const
  106 + {
  107 + TemplateList partitioned = src.partition(inputVariable, randomSeed, true);
  108 + const int crossValidate = Globals->crossValidate;
  109 +
  110 + if (crossValidate < 0) {
  111 + transforms[0]->project(partitioned, dst);
  112 + return;
  113 + }
  114 + for (int i=0; i<partitioned.size(); i++) {
  115 + int partition = partitioned[i].file.get<int>("Partition", 0);
  116 + transforms[partition]->project(partitioned, dst);
  117 + }
  118 + }
  119 +
  120 + void store(QDataStream &stream) const
  121 + {
  122 + stream << transforms.size();
  123 + foreach (Transform *transform, transforms)
  124 + transform->store(stream);
  125 + }
  126 +
  127 + void load(QDataStream &stream)
  128 + {
  129 + int numTransforms;
  130 + stream >> numTransforms;
  131 + while (transforms.size() < numTransforms)
  132 + transforms.append(make(description));
  133 + foreach (Transform *transform, transforms)
  134 + transform->load(stream);
  135 + }
  136 +};
  137 +
  138 +BR_REGISTER(Transform, CrossValidateTransform)
  139 +
  140 +} // namespace br
  141 +
  142 +#include "core/crossvalidate.moc"