Commit 968f94f3ae6c8be351e1fdc6e2887d991ef7772f
1 parent
1699213d
tweaks to meta transforms to skip training when possible
Showing
2 changed files
with
9 additions
and
5 deletions
openbr/openbr_plugin.h
| ... | ... | @@ -1213,11 +1213,10 @@ public: |
| 1213 | 1213 | void init() |
| 1214 | 1214 | { |
| 1215 | 1215 | isTimeVarying = false; |
| 1216 | + trainable = false; | |
| 1216 | 1217 | foreach (const br::Transform *transform, transforms) { |
| 1217 | - if (transform->timeVarying()) { | |
| 1218 | - isTimeVarying = true; | |
| 1219 | - break; | |
| 1220 | - } | |
| 1218 | + isTimeVarying = isTimeVarying || transform->timeVarying(); | |
| 1219 | + trainable = trainable || transform->trainable; | |
| 1221 | 1220 | } |
| 1222 | 1221 | } |
| 1223 | 1222 | ... | ... |
openbr/plugins/meta.cpp
| ... | ... | @@ -73,7 +73,7 @@ static void _train(Transform *transform, const TemplateList *data) |
| 73 | 73 | */ |
| 74 | 74 | class PipeTransform : public CompositeTransform |
| 75 | 75 | { |
| 76 | - Q_OBJECT | |
| 76 | + Q_OBJECT | |
| 77 | 77 | |
| 78 | 78 | void _projectPartial(Template *srcdst, int startIndex, int stopIndex) |
| 79 | 79 | { |
| ... | ... | @@ -83,6 +83,8 @@ class PipeTransform : public CompositeTransform |
| 83 | 83 | |
| 84 | 84 | void train(const TemplateList &data) |
| 85 | 85 | { |
| 86 | + if (!trainable) return; | |
| 87 | + | |
| 86 | 88 | TemplateList copy(data); |
| 87 | 89 | int i = 0; |
| 88 | 90 | while (i < transforms.size()) { |
| ... | ... | @@ -285,6 +287,7 @@ class ForkTransform : public CompositeTransform |
| 285 | 287 | |
| 286 | 288 | void train(const TemplateList &data) |
| 287 | 289 | { |
| 290 | + if (!trainable) return; | |
| 288 | 291 | QFutureSynchronizer<void> futures; |
| 289 | 292 | for (int i=0; i<transforms.size(); i++) { |
| 290 | 293 | if (Globals->parallelism) futures.addFuture(QtConcurrent::run(_train, transforms[i], &data)); |
| ... | ... | @@ -415,6 +418,7 @@ public: |
| 415 | 418 | private: |
| 416 | 419 | void init() |
| 417 | 420 | { |
| 421 | + trainable = transform->trainable; | |
| 418 | 422 | if (!cache.isEmpty()) return; |
| 419 | 423 | |
| 420 | 424 | // Read from cache |
| ... | ... | @@ -476,6 +480,7 @@ private: |
| 476 | 480 | if (transform != NULL) return; |
| 477 | 481 | baseName = QRegExp("^[a-zA-Z0-9]+$").exactMatch(description) ? description : QtUtils::shortTextHash(description); |
| 478 | 482 | if (!tryLoad()) transform = make(description); |
| 483 | + else trainable = false; | |
| 479 | 484 | } |
| 480 | 485 | |
| 481 | 486 | void train(const TemplateList &data) | ... | ... |