Commit d0aa1fd8eedc2fe9136b9af396045ba328495839
1 parent
f0eb0cfb
Model separate project calls as seaprate list items in train
For non-leaf transforms, add a separate train method that takes QList<TemplateList> as arguments, where the separate list items correspond to the separate calls to project that would be made by parent transforms during enrollment. This is basically to deal with the effects of distribute on untrainable meta transforms such as flatten, we have to maintain a consistent grouping with project, otherwise the intermediate projection results used during training won't be consistent with what they would be during enrollment.
Showing
5 changed files
with
66 additions
and
30 deletions
openbr/core/core.cpp
| ... | ... | @@ -48,7 +48,9 @@ struct AlgorithmCore |
| 48 | 48 | |
| 49 | 49 | QTime time; time.start(); |
| 50 | 50 | qDebug("Training Enrollment"); |
| 51 | - transform->train(data); | |
| 51 | + QList<TemplateList> uniform; | |
| 52 | + uniform.append(data); | |
| 53 | + transform->train(uniform); | |
| 52 | 54 | |
| 53 | 55 | if (!distance.isNull()) { |
| 54 | 56 | qDebug("Projecting Enrollment"); | ... | ... |
openbr/openbr_plugin.h
| ... | ... | @@ -1058,7 +1058,19 @@ public: |
| 1058 | 1058 | static QSharedPointer<Transform> fromAlgorithm(const QString &algorithm); /*!< \brief Retrieve an algorithm's transform. */ |
| 1059 | 1059 | |
| 1060 | 1060 | virtual Transform *clone() const; /*!< \brief Copy the transform. */ |
| 1061 | + /*!< \brief Train the transform, separate list items represent the way calls to project would be broken up*/ | |
| 1062 | + virtual void train(const QList<TemplateList> &data) | |
| 1063 | + { | |
| 1064 | + TemplateList combined; | |
| 1065 | + foreach(const TemplateList & set, data) { | |
| 1066 | + combined.append(set); | |
| 1067 | + } | |
| 1068 | + train(combined); | |
| 1069 | + } | |
| 1070 | + | |
| 1071 | + // Terminal train, call with a complete training set when no further structure is needed | |
| 1061 | 1072 | virtual void train(const TemplateList &data) = 0; /*!< \brief Train the transform. */ |
| 1073 | + | |
| 1062 | 1074 | virtual void project(const Template &src, Template &dst) const = 0; /*!< \brief Apply the transform. */ |
| 1063 | 1075 | virtual void project(const TemplateList &src, TemplateList &dst) const; /*!< \brief Apply the transform. */ |
| 1064 | 1076 | ... | ... |
openbr/plugins/meta.cpp
| ... | ... | @@ -57,7 +57,7 @@ static TemplateList Expanded(const TemplateList &templates) |
| 57 | 57 | return expanded; |
| 58 | 58 | } |
| 59 | 59 | |
| 60 | -static void _train(Transform *transform, const TemplateList *data) | |
| 60 | +static void _train(Transform *transform, const QList<TemplateList> *data) | |
| 61 | 61 | { |
| 62 | 62 | transform->train(*data); |
| 63 | 63 | } |
| ... | ... | @@ -82,18 +82,17 @@ class PipeTransform : public CompositeTransform |
| 82 | 82 | *srcdst >> *transforms[i]; |
| 83 | 83 | } |
| 84 | 84 | |
| 85 | - void train(const TemplateList &data) | |
| 85 | + void train(const TemplateList & data) | |
| 86 | + { | |
| 87 | + (void) data; | |
| 88 | + qFatal("Terminal train called on interior node"); | |
| 89 | + } | |
| 90 | + | |
| 91 | + void train(const QList<TemplateList> &data) | |
| 86 | 92 | { |
| 87 | 93 | if (!trainable) return; |
| 88 | 94 | |
| 89 | - TemplateList copy(data); | |
| 90 | - QList<TemplateList> singleItemLists; | |
| 91 | - for (int i=0; i < copy.size(); i++) | |
| 92 | - { | |
| 93 | - TemplateList temp; | |
| 94 | - temp.append(copy[i]); | |
| 95 | - singleItemLists.append(temp); | |
| 96 | - } | |
| 95 | + QList<TemplateList> dataLines(data); | |
| 97 | 96 | |
| 98 | 97 | int i = 0; |
| 99 | 98 | while (i < transforms.size()) { |
| ... | ... | @@ -102,23 +101,18 @@ class PipeTransform : public CompositeTransform |
| 102 | 101 | // Conditional statement covers likely case that first transform is untrainable |
| 103 | 102 | if (transforms[i]->trainable) { |
| 104 | 103 | fprintf(stderr, " training..."); |
| 105 | - transforms[i]->train(copy); | |
| 104 | + transforms[i]->train(dataLines); | |
| 106 | 105 | } |
| 107 | 106 | |
| 108 | 107 | // if the transform is time varying, we can't project it in parallel |
| 109 | 108 | if (transforms[i]->timeVarying()) { |
| 110 | 109 | fprintf(stderr, "\n%s projecting...", qPrintable(transforms[i]->objectName())); |
| 111 | - for (int j=0; j < singleItemLists.size();j++) | |
| 112 | - transforms[i]->projectUpdate(singleItemLists[j], singleItemLists[j]); | |
| 110 | + for (int j=0; j < dataLines.size();j++) | |
| 111 | + transforms[i]->projectUpdate(dataLines[j], dataLines[j]); | |
| 113 | 112 | |
| 114 | 113 | // advance i since we already projected for this stage. |
| 115 | 114 | i++; |
| 116 | 115 | |
| 117 | - // set up copy again | |
| 118 | - copy.clear(); | |
| 119 | - for (int j=0; j < singleItemLists.size(); j++) | |
| 120 | - copy.append(singleItemLists[j]); | |
| 121 | - | |
| 122 | 116 | // the next stage might be trainable, so continue to evaluate it. |
| 123 | 117 | continue; |
| 124 | 118 | } |
| ... | ... | @@ -136,14 +130,10 @@ class PipeTransform : public CompositeTransform |
| 136 | 130 | |
| 137 | 131 | fprintf(stderr, " projecting..."); |
| 138 | 132 | QFutureSynchronizer<void> futures; |
| 139 | - for (int j=0; j < singleItemLists.size(); j++) | |
| 140 | - futures.addFuture(QtConcurrent::run(this, &PipeTransform::_projectPartial, &singleItemLists[j], i, nextTrainableTransform)); | |
| 133 | + for (int j=0; j < dataLines.size(); j++) | |
| 134 | + futures.addFuture(QtConcurrent::run(this, &PipeTransform::_projectPartial, &dataLines[j], i, nextTrainableTransform)); | |
| 141 | 135 | futures.waitForFinished(); |
| 142 | 136 | |
| 143 | - copy.clear(); | |
| 144 | - for (int j=0; j < singleItemLists.size(); j++) | |
| 145 | - copy.append(singleItemLists[j]); | |
| 146 | - | |
| 147 | 137 | i = nextTrainableTransform; |
| 148 | 138 | } |
| 149 | 139 | } |
| ... | ... | @@ -296,7 +286,13 @@ class ForkTransform : public CompositeTransform |
| 296 | 286 | { |
| 297 | 287 | Q_OBJECT |
| 298 | 288 | |
| 299 | - void train(const TemplateList &data) | |
| 289 | + void train(const TemplateList & data) | |
| 290 | + { | |
| 291 | + (void) data; | |
| 292 | + qFatal("Terminal train called on interior node."); | |
| 293 | + } | |
| 294 | + | |
| 295 | + void train(const QList<TemplateList> &data) | |
| 300 | 296 | { |
| 301 | 297 | if (!trainable) return; |
| 302 | 298 | QFutureSynchronizer<void> futures; |
| ... | ... | @@ -612,9 +608,23 @@ public: |
| 612 | 608 | return output; |
| 613 | 609 | } |
| 614 | 610 | |
| 615 | - void train(const TemplateList &data) | |
| 611 | + void train(const TemplateList & data) | |
| 616 | 612 | { |
| 617 | - transform->train(data); | |
| 613 | + (void) data; | |
| 614 | + qFatal("terminal train called on non-leaf transform"); | |
| 615 | + } | |
| 616 | + | |
| 617 | + void train(const QList<TemplateList> &data) | |
| 618 | + { | |
| 619 | + QList<TemplateList> separated; | |
| 620 | + foreach (const TemplateList & list, data) { | |
| 621 | + foreach(const Template & t, list) { | |
| 622 | + separated.append(TemplateList()); | |
| 623 | + separated.last().append(t); | |
| 624 | + } | |
| 625 | + } | |
| 626 | + | |
| 627 | + transform->train(separated); | |
| 618 | 628 | } |
| 619 | 629 | |
| 620 | 630 | void project(const Template &src, Template &dst) const | ... | ... |
openbr/plugins/openbr_internal.h
| ... | ... | @@ -140,7 +140,13 @@ public: |
| 140 | 140 | transformSource.release(aTransform); |
| 141 | 141 | } |
| 142 | 142 | |
| 143 | - void train(const TemplateList &data) | |
| 143 | + void train(const TemplateList & data) | |
| 144 | + { | |
| 145 | + (void) data; | |
| 146 | + qFatal("terminal train called on non-leaf node"); | |
| 147 | + } | |
| 148 | + | |
| 149 | + void train(const QList<TemplateList> &data) | |
| 144 | 150 | { |
| 145 | 151 | baseTransform->train(data); |
| 146 | 152 | } | ... | ... |
openbr/plugins/validate.cpp
| ... | ... | @@ -6,6 +6,11 @@ |
| 6 | 6 | namespace br |
| 7 | 7 | { |
| 8 | 8 | |
| 9 | +static void _train(Transform * transform, TemplateList data) // think data has to be a copy -cao | |
| 10 | +{ | |
| 11 | + transform->train(data); | |
| 12 | +} | |
| 13 | + | |
| 9 | 14 | /*! |
| 10 | 15 | * \ingroup transforms |
| 11 | 16 | * \brief Cross validate a trainable transform. |
| ... | ... | @@ -46,7 +51,8 @@ class CrossValidateTransform : public MetaTransform |
| 46 | 51 | if (partitions[j] == i) |
| 47 | 52 | partitionedData.removeAt(j); |
| 48 | 53 | // Train on the remaining templates |
| 49 | - futures.addFuture(QtConcurrent::run(transforms[i], &Transform::train, partitionedData)); | |
| 54 | + //futures.addFuture(QtConcurrent::run(transforms[i], &Transform::train, partitionedData)); | |
| 55 | + futures.addFuture(QtConcurrent::run(_train, transforms[i], partitionedData)); | |
| 50 | 56 | } |
| 51 | 57 | futures.waitForFinished(); |
| 52 | 58 | } | ... | ... |