Commit 95cd60aa9ffcbdf6cc926746f018383c0f338762
1 parent
e0f923f8
Add a preliminary version of Stream::train
This should be correct, but is fairly inefficient. Further work could be done if anyone is actually interested in using it.
Showing
1 changed file
with
41 additions
and
3 deletions
openbr/plugins/stream.cpp
| ... | ... | @@ -1027,14 +1027,46 @@ public: |
| 1027 | 1027 | |
| 1028 | 1028 | void train(const TemplateList & data) |
| 1029 | 1029 | { |
| 1030 | + (void) data; | |
| 1031 | + qFatal("terminal train called on interior node."); | |
| 1032 | + } | |
| 1033 | + | |
| 1034 | + void subProject(QList<TemplateList> & data, int end_idx) | |
| 1035 | + { | |
| 1036 | + if (end_idx == 0) | |
| 1037 | + return; | |
| 1038 | + | |
| 1039 | + // Set transforms to the start, up to end_idx | |
| 1040 | + QList<Transform *> backup = this->transforms; | |
| 1041 | + transforms = backup.mid(0,end_idx); | |
| 1042 | + | |
| 1043 | + // Reinitialize, we now act as a shorter stream. | |
| 1044 | + init(); | |
| 1045 | + | |
| 1046 | + for (int i=0; i < data.size(); i++) { | |
| 1047 | + projectUpdate(data[i], data[i]); | |
| 1048 | + } | |
| 1049 | + transforms = backup; | |
| 1050 | + } | |
| 1051 | + | |
| 1052 | + void train(const QList<TemplateList> & data) | |
| 1053 | + { | |
| 1030 | 1054 | if (!trainable) { |
| 1031 | 1055 | qWarning("Attempted to train untrainable transform, nothing will happen."); |
| 1032 | 1056 | return; |
| 1033 | 1057 | } |
| 1034 | - qFatal("Stream train is currently not implemented."); | |
| 1035 | - foreach(Transform * transform, transforms) { | |
| 1036 | - transform->train(data); | |
| 1058 | + | |
| 1059 | + for (int i=0; i < transforms.size(); i++) { | |
| 1060 | + // OK we have a trainable transform, we need to get input data for it. | |
| 1061 | + if (transforms[i]->trainable) { | |
| 1062 | + QList<TemplateList> copy = data; | |
| 1063 | + // Project from the start to the trainable stage. | |
| 1064 | + subProject(copy,i); | |
| 1065 | + transforms[i]->train(copy); | |
| 1066 | + } | |
| 1037 | 1067 | } |
| 1068 | + // Re-initialize because subProject probably messed us up. | |
| 1069 | + init(); | |
| 1038 | 1070 | } |
| 1039 | 1071 | |
| 1040 | 1072 | bool timeVarying() const { return true; } |
| ... | ... | @@ -1285,6 +1317,12 @@ public: |
| 1285 | 1317 | |
| 1286 | 1318 | void train(const TemplateList & data) |
| 1287 | 1319 | { |
| 1320 | + (void) data; | |
| 1321 | + qFatal("terminal train called on interior node."); | |
| 1322 | + } | |
| 1323 | + | |
| 1324 | + void train(const QList<TemplateList> & data) | |
| 1325 | + { | |
| 1288 | 1326 | basis.train(data); |
| 1289 | 1327 | } |
| 1290 | 1328 | ... | ... |