diff --git a/openbr/plugins/stream.cpp b/openbr/plugins/stream.cpp index 573169a..22ee4d8 100644 --- a/openbr/plugins/stream.cpp +++ b/openbr/plugins/stream.cpp @@ -1027,14 +1027,46 @@ public: void train(const TemplateList & data) { + (void) data; + qFatal("terminal train called on interior node."); + } + + void subProject(QList & data, int end_idx) + { + if (end_idx == 0) + return; + + // Set transforms to the start, up to end_idx + QList backup = this->transforms; + transforms = backup.mid(0,end_idx); + + // Reinitialize, we now act as a shorter stream. + init(); + + for (int i=0; i < data.size(); i++) { + projectUpdate(data[i], data[i]); + } + transforms = backup; + } + + void train(const QList & data) + { if (!trainable) { qWarning("Attempted to train untrainable transform, nothing will happen."); return; } - qFatal("Stream train is currently not implemented."); - foreach(Transform * transform, transforms) { - transform->train(data); + + for (int i=0; i < transforms.size(); i++) { + // OK we have a trainable transform, we need to get input data for it. + if (transforms[i]->trainable) { + QList copy = data; + // Project from the start to the trainable stage. + subProject(copy,i); + transforms[i]->train(copy); + } } + // Re-initialize because subProject probably messed us up. + init(); } bool timeVarying() const { return true; } @@ -1285,6 +1317,12 @@ public: void train(const TemplateList & data) { + (void) data; + qFatal("terminal train called on interior node."); + } + + void train(const QList & data) + { basis.train(data); }