Commit b52a58b69dcf13583f9482a32e41feb30eaa9a7a
1 parent
830083ad
optimized pipe training
Showing
2 changed files
with
36 additions
and
8 deletions
openbr/openbr_plugin.cpp
| ... | ... | @@ -1162,6 +1162,7 @@ public: |
| 1162 | 1162 | transform->setParent(this); |
| 1163 | 1163 | transforms.append(transform); |
| 1164 | 1164 | file = transform->file; |
| 1165 | + trainable = transform->trainable; | |
| 1165 | 1166 | setObjectName(transforms.first()->objectName()); |
| 1166 | 1167 | } |
| 1167 | 1168 | |
| ... | ... | @@ -1178,9 +1179,8 @@ private: |
| 1178 | 1179 | |
| 1179 | 1180 | void train(const TemplateList &data) |
| 1180 | 1181 | { |
| 1181 | - // Don't bother constructing datasets if the transform is untrainable | |
| 1182 | - if (dynamic_cast<UntrainableTransform*>(transforms.first())) | |
| 1183 | - return; | |
| 1182 | + // Don't bother if the transform is untrainable | |
| 1183 | + if (!trainable) return; | |
| 1184 | 1184 | |
| 1185 | 1185 | QList<TemplateList> templatesList; |
| 1186 | 1186 | foreach (const Template &t, data) { | ... | ... |
openbr/plugins/meta.cpp
| ... | ... | @@ -75,14 +75,42 @@ class PipeTransform : public CompositeTransform |
| 75 | 75 | { |
| 76 | 76 | Q_OBJECT |
| 77 | 77 | |
| 78 | + void _projectPartial(Template *srcdst, int startIndex, int stopIndex) | |
| 79 | + { | |
| 80 | + for (int i=startIndex; i<stopIndex; i++) | |
| 81 | + *srcdst >> *transforms[i]; | |
| 82 | + } | |
| 83 | + | |
| 78 | 84 | void train(const TemplateList &data) |
| 79 | 85 | { |
| 80 | 86 | TemplateList copy(data); |
| 81 | - for (int i=0; i<transforms.size(); i++) { | |
| 82 | - fprintf(stderr, "%s training... ", qPrintable(transforms[i]->objectName())); | |
| 83 | - transforms[i]->train(copy); | |
| 84 | - fprintf(stderr, "projecting...\n"); | |
| 85 | - copy >> *transforms[i]; | |
| 87 | + int i = 0; | |
| 88 | + while (i < transforms.size()) { | |
| 89 | + fprintf(stderr, "%s", qPrintable(transforms[i]->objectName())); | |
| 90 | + | |
| 91 | + // Conditional statement covers likely case that first transform is untrainable | |
| 92 | + if (transforms[i]->trainable) { | |
| 93 | + fprintf(stderr, " training..."); | |
| 94 | + transforms[i]->train(copy); | |
| 95 | + } | |
| 96 | + | |
| 97 | + // We project through any subsequent untrainable transforms at once | |
| 98 | + // as a memory optimization in case any of these intermediate | |
| 99 | + // transforms allocate a lot of memory (like OpenTransform) | |
| 100 | + // then we don't want all the training templates to be processed | |
| 101 | + // by that transform at once if we can avoid it. | |
| 102 | + int nextTrainableTransform = i+1; | |
| 103 | + while ((nextTrainableTransform < transforms.size()) && | |
| 104 | + !transforms[nextTrainableTransform]->trainable) | |
| 105 | + nextTrainableTransform++; | |
| 106 | + | |
| 107 | + fprintf(stderr, " projecting...\n"); | |
| 108 | + QFutureSynchronizer<void> futures; | |
| 109 | + for (int j=0; j<copy.size(); j++) | |
| 110 | + if (Globals->parallelism) futures.addFuture(QtConcurrent::run(this, &PipeTransform::_projectPartial, ©[j], i, nextTrainableTransform)); | |
| 111 | + else _projectPartial( ©[j], i, nextTrainableTransform); | |
| 112 | + futures.waitForFinished(); | |
| 113 | + i = nextTrainableTransform; | |
| 86 | 114 | } |
| 87 | 115 | } |
| 88 | 116 | ... | ... |