Commit 95cd60aa9ffcbdf6cc926746f018383c0f338762

Authored by Charles Otto
1 parent e0f923f8

Add a preliminary version of Stream::train

This should be correct, but is fairly inefficient. Further work could be done
if anyone is actually interested in using it.
Showing 1 changed file with 41 additions and 3 deletions
openbr/plugins/stream.cpp
@@ -1027,14 +1027,46 @@ public: @@ -1027,14 +1027,46 @@ public:
1027 1027
1028 void train(const TemplateList & data) 1028 void train(const TemplateList & data)
1029 { 1029 {
  1030 + (void) data;
  1031 + qFatal("terminal train called on interior node.");
  1032 + }
  1033 +
  1034 + void subProject(QList<TemplateList> & data, int end_idx)
  1035 + {
  1036 + if (end_idx == 0)
  1037 + return;
  1038 +
  1039 + // Set transforms to the start, up to end_idx
  1040 + QList<Transform *> backup = this->transforms;
  1041 + transforms = backup.mid(0,end_idx);
  1042 +
  1043 + // Reinitialize, we now act as a shorter stream.
  1044 + init();
  1045 +
  1046 + for (int i=0; i < data.size(); i++) {
  1047 + projectUpdate(data[i], data[i]);
  1048 + }
  1049 + transforms = backup;
  1050 + }
  1051 +
  1052 + void train(const QList<TemplateList> & data)
  1053 + {
1030 if (!trainable) { 1054 if (!trainable) {
1031 qWarning("Attempted to train untrainable transform, nothing will happen."); 1055 qWarning("Attempted to train untrainable transform, nothing will happen.");
1032 return; 1056 return;
1033 } 1057 }
1034 - qFatal("Stream train is currently not implemented.");  
1035 - foreach(Transform * transform, transforms) {  
1036 - transform->train(data); 1058 +
  1059 + for (int i=0; i < transforms.size(); i++) {
  1060 + // OK we have a trainable transform, we need to get input data for it.
  1061 + if (transforms[i]->trainable) {
  1062 + QList<TemplateList> copy = data;
  1063 + // Project from the start to the trainable stage.
  1064 + subProject(copy,i);
  1065 + transforms[i]->train(copy);
  1066 + }
1037 } 1067 }
  1068 + // Re-initialize because subProject probably messed us up.
  1069 + init();
1038 } 1070 }
1039 1071
1040 bool timeVarying() const { return true; } 1072 bool timeVarying() const { return true; }
@@ -1285,6 +1317,12 @@ public: @@ -1285,6 +1317,12 @@ public:
1285 1317
1286 void train(const TemplateList & data) 1318 void train(const TemplateList & data)
1287 { 1319 {
  1320 + (void) data;
  1321 + qFatal("terminal train called on interior node.");
  1322 + }
  1323 +
  1324 + void train(const QList<TemplateList> & data)
  1325 + {
1288 basis.train(data); 1326 basis.train(data);
1289 } 1327 }
1290 1328