diff --git a/openbr/core/core.cpp b/openbr/core/core.cpp index 1378302..f6e8939 100644 --- a/openbr/core/core.cpp +++ b/openbr/core/core.cpp @@ -48,7 +48,9 @@ struct AlgorithmCore QTime time; time.start(); qDebug("Training Enrollment"); - transform->train(data); + QList uniform; + uniform.append(data); + transform->train(uniform); if (!distance.isNull()) { qDebug("Projecting Enrollment"); diff --git a/openbr/openbr_plugin.h b/openbr/openbr_plugin.h index 0f0d800..b173839 100644 --- a/openbr/openbr_plugin.h +++ b/openbr/openbr_plugin.h @@ -1058,7 +1058,19 @@ public: static QSharedPointer fromAlgorithm(const QString &algorithm); /*!< \brief Retrieve an algorithm's transform. */ virtual Transform *clone() const; /*!< \brief Copy the transform. */ + /*!< \brief Train the transform, separate list items represent the way calls to project would be broken up*/ + virtual void train(const QList &data) + { + TemplateList combined; + foreach(const TemplateList & set, data) { + combined.append(set); + } + train(combined); + } + + // Terminal train, call with a complete training set when no further structure is needed virtual void train(const TemplateList &data) = 0; /*!< \brief Train the transform. */ + virtual void project(const Template &src, Template &dst) const = 0; /*!< \brief Apply the transform. */ virtual void project(const TemplateList &src, TemplateList &dst) const; /*!< \brief Apply the transform. */ diff --git a/openbr/plugins/meta.cpp b/openbr/plugins/meta.cpp index b4cc865..2f16c27 100644 --- a/openbr/plugins/meta.cpp +++ b/openbr/plugins/meta.cpp @@ -57,7 +57,7 @@ static TemplateList Expanded(const TemplateList &templates) return expanded; } -static void _train(Transform *transform, const TemplateList *data) +static void _train(Transform *transform, const QList *data) { transform->train(*data); } @@ -82,18 +82,17 @@ class PipeTransform : public CompositeTransform *srcdst >> *transforms[i]; } - void train(const TemplateList &data) + void train(const TemplateList & data) + { + (void) data; + qFatal("Terminal train called on interior node"); + } + + void train(const QList &data) { if (!trainable) return; - TemplateList copy(data); - QList singleItemLists; - for (int i=0; i < copy.size(); i++) - { - TemplateList temp; - temp.append(copy[i]); - singleItemLists.append(temp); - } + QList dataLines(data); int i = 0; while (i < transforms.size()) { @@ -102,23 +101,18 @@ class PipeTransform : public CompositeTransform // Conditional statement covers likely case that first transform is untrainable if (transforms[i]->trainable) { fprintf(stderr, " training..."); - transforms[i]->train(copy); + transforms[i]->train(dataLines); } // if the transform is time varying, we can't project it in parallel if (transforms[i]->timeVarying()) { fprintf(stderr, "\n%s projecting...", qPrintable(transforms[i]->objectName())); - for (int j=0; j < singleItemLists.size();j++) - transforms[i]->projectUpdate(singleItemLists[j], singleItemLists[j]); + for (int j=0; j < dataLines.size();j++) + transforms[i]->projectUpdate(dataLines[j], dataLines[j]); // advance i since we already projected for this stage. i++; - // set up copy again - copy.clear(); - for (int j=0; j < singleItemLists.size(); j++) - copy.append(singleItemLists[j]); - // the next stage might be trainable, so continue to evaluate it. continue; } @@ -136,14 +130,10 @@ class PipeTransform : public CompositeTransform fprintf(stderr, " projecting..."); QFutureSynchronizer futures; - for (int j=0; j < singleItemLists.size(); j++) - futures.addFuture(QtConcurrent::run(this, &PipeTransform::_projectPartial, &singleItemLists[j], i, nextTrainableTransform)); + for (int j=0; j < dataLines.size(); j++) + futures.addFuture(QtConcurrent::run(this, &PipeTransform::_projectPartial, &dataLines[j], i, nextTrainableTransform)); futures.waitForFinished(); - copy.clear(); - for (int j=0; j < singleItemLists.size(); j++) - copy.append(singleItemLists[j]); - i = nextTrainableTransform; } } @@ -296,7 +286,13 @@ class ForkTransform : public CompositeTransform { Q_OBJECT - void train(const TemplateList &data) + void train(const TemplateList & data) + { + (void) data; + qFatal("Terminal train called on interior node."); + } + + void train(const QList &data) { if (!trainable) return; QFutureSynchronizer futures; @@ -612,9 +608,23 @@ public: return output; } - void train(const TemplateList &data) + void train(const TemplateList & data) { - transform->train(data); + (void) data; + qFatal("terminal train called on non-leaf transform"); + } + + void train(const QList &data) + { + QList separated; + foreach (const TemplateList & list, data) { + foreach(const Template & t, list) { + separated.append(TemplateList()); + separated.last().append(t); + } + } + + transform->train(separated); } void project(const Template &src, Template &dst) const diff --git a/openbr/plugins/openbr_internal.h b/openbr/plugins/openbr_internal.h index 27136da..a4001e7 100644 --- a/openbr/plugins/openbr_internal.h +++ b/openbr/plugins/openbr_internal.h @@ -140,7 +140,13 @@ public: transformSource.release(aTransform); } - void train(const TemplateList &data) + void train(const TemplateList & data) + { + (void) data; + qFatal("terminal train called on non-leaf node"); + } + + void train(const QList &data) { baseTransform->train(data); } diff --git a/openbr/plugins/validate.cpp b/openbr/plugins/validate.cpp index 8b3a174..dd5bbd8 100644 --- a/openbr/plugins/validate.cpp +++ b/openbr/plugins/validate.cpp @@ -6,6 +6,11 @@ namespace br { +static void _train(Transform * transform, TemplateList data) // think data has to be a copy -cao +{ + transform->train(data); +} + /*! * \ingroup transforms * \brief Cross validate a trainable transform. @@ -46,7 +51,8 @@ class CrossValidateTransform : public MetaTransform if (partitions[j] == i) partitionedData.removeAt(j); // Train on the remaining templates - futures.addFuture(QtConcurrent::run(transforms[i], &Transform::train, partitionedData)); + //futures.addFuture(QtConcurrent::run(transforms[i], &Transform::train, partitionedData)); + futures.addFuture(QtConcurrent::run(_train, transforms[i], partitionedData)); } futures.waitForFinished(); }