Commit b52a58b69dcf13583f9482a32e41feb30eaa9a7a
1 parent
830083ad
optimized pipe training
Showing
2 changed files
with
36 additions
and
8 deletions
openbr/openbr_plugin.cpp
| @@ -1162,6 +1162,7 @@ public: | @@ -1162,6 +1162,7 @@ public: | ||
| 1162 | transform->setParent(this); | 1162 | transform->setParent(this); |
| 1163 | transforms.append(transform); | 1163 | transforms.append(transform); |
| 1164 | file = transform->file; | 1164 | file = transform->file; |
| 1165 | + trainable = transform->trainable; | ||
| 1165 | setObjectName(transforms.first()->objectName()); | 1166 | setObjectName(transforms.first()->objectName()); |
| 1166 | } | 1167 | } |
| 1167 | 1168 | ||
| @@ -1178,9 +1179,8 @@ private: | @@ -1178,9 +1179,8 @@ private: | ||
| 1178 | 1179 | ||
| 1179 | void train(const TemplateList &data) | 1180 | void train(const TemplateList &data) |
| 1180 | { | 1181 | { |
| 1181 | - // Don't bother constructing datasets if the transform is untrainable | ||
| 1182 | - if (dynamic_cast<UntrainableTransform*>(transforms.first())) | ||
| 1183 | - return; | 1182 | + // Don't bother if the transform is untrainable |
| 1183 | + if (!trainable) return; | ||
| 1184 | 1184 | ||
| 1185 | QList<TemplateList> templatesList; | 1185 | QList<TemplateList> templatesList; |
| 1186 | foreach (const Template &t, data) { | 1186 | foreach (const Template &t, data) { |
openbr/plugins/meta.cpp
| @@ -75,14 +75,42 @@ class PipeTransform : public CompositeTransform | @@ -75,14 +75,42 @@ class PipeTransform : public CompositeTransform | ||
| 75 | { | 75 | { |
| 76 | Q_OBJECT | 76 | Q_OBJECT |
| 77 | 77 | ||
| 78 | + void _projectPartial(Template *srcdst, int startIndex, int stopIndex) | ||
| 79 | + { | ||
| 80 | + for (int i=startIndex; i<stopIndex; i++) | ||
| 81 | + *srcdst >> *transforms[i]; | ||
| 82 | + } | ||
| 83 | + | ||
| 78 | void train(const TemplateList &data) | 84 | void train(const TemplateList &data) |
| 79 | { | 85 | { |
| 80 | TemplateList copy(data); | 86 | TemplateList copy(data); |
| 81 | - for (int i=0; i<transforms.size(); i++) { | ||
| 82 | - fprintf(stderr, "%s training... ", qPrintable(transforms[i]->objectName())); | ||
| 83 | - transforms[i]->train(copy); | ||
| 84 | - fprintf(stderr, "projecting...\n"); | ||
| 85 | - copy >> *transforms[i]; | 87 | + int i = 0; |
| 88 | + while (i < transforms.size()) { | ||
| 89 | + fprintf(stderr, "%s", qPrintable(transforms[i]->objectName())); | ||
| 90 | + | ||
| 91 | + // Conditional statement covers likely case that first transform is untrainable | ||
| 92 | + if (transforms[i]->trainable) { | ||
| 93 | + fprintf(stderr, " training..."); | ||
| 94 | + transforms[i]->train(copy); | ||
| 95 | + } | ||
| 96 | + | ||
| 97 | + // We project through any subsequent untrainable transforms at once | ||
| 98 | + // as a memory optimization in case any of these intermediate | ||
| 99 | + // transforms allocate a lot of memory (like OpenTransform) | ||
| 100 | + // then we don't want all the training templates to be processed | ||
| 101 | + // by that transform at once if we can avoid it. | ||
| 102 | + int nextTrainableTransform = i+1; | ||
| 103 | + while ((nextTrainableTransform < transforms.size()) && | ||
| 104 | + !transforms[nextTrainableTransform]->trainable) | ||
| 105 | + nextTrainableTransform++; | ||
| 106 | + | ||
| 107 | + fprintf(stderr, " projecting...\n"); | ||
| 108 | + QFutureSynchronizer<void> futures; | ||
| 109 | + for (int j=0; j<copy.size(); j++) | ||
| 110 | + if (Globals->parallelism) futures.addFuture(QtConcurrent::run(this, &PipeTransform::_projectPartial, ©[j], i, nextTrainableTransform)); | ||
| 111 | + else _projectPartial( ©[j], i, nextTrainableTransform); | ||
| 112 | + futures.waitForFinished(); | ||
| 113 | + i = nextTrainableTransform; | ||
| 86 | } | 114 | } |
| 87 | } | 115 | } |
| 88 | 116 |