Commit b52a58b69dcf13583f9482a32e41feb30eaa9a7a

Authored by Josh Klontz
1 parent 830083ad

optimized pipe training

openbr/openbr_plugin.cpp
@@ -1162,6 +1162,7 @@ public: @@ -1162,6 +1162,7 @@ public:
1162 transform->setParent(this); 1162 transform->setParent(this);
1163 transforms.append(transform); 1163 transforms.append(transform);
1164 file = transform->file; 1164 file = transform->file;
  1165 + trainable = transform->trainable;
1165 setObjectName(transforms.first()->objectName()); 1166 setObjectName(transforms.first()->objectName());
1166 } 1167 }
1167 1168
@@ -1178,9 +1179,8 @@ private: @@ -1178,9 +1179,8 @@ private:
1178 1179
1179 void train(const TemplateList &data) 1180 void train(const TemplateList &data)
1180 { 1181 {
1181 - // Don't bother constructing datasets if the transform is untrainable  
1182 - if (dynamic_cast<UntrainableTransform*>(transforms.first()))  
1183 - return; 1182 + // Don't bother if the transform is untrainable
  1183 + if (!trainable) return;
1184 1184
1185 QList<TemplateList> templatesList; 1185 QList<TemplateList> templatesList;
1186 foreach (const Template &t, data) { 1186 foreach (const Template &t, data) {
openbr/plugins/meta.cpp
@@ -75,14 +75,42 @@ class PipeTransform : public CompositeTransform @@ -75,14 +75,42 @@ class PipeTransform : public CompositeTransform
75 { 75 {
76 Q_OBJECT 76 Q_OBJECT
77 77
  78 + void _projectPartial(Template *srcdst, int startIndex, int stopIndex)
  79 + {
  80 + for (int i=startIndex; i<stopIndex; i++)
  81 + *srcdst >> *transforms[i];
  82 + }
  83 +
78 void train(const TemplateList &data) 84 void train(const TemplateList &data)
79 { 85 {
80 TemplateList copy(data); 86 TemplateList copy(data);
81 - for (int i=0; i<transforms.size(); i++) {  
82 - fprintf(stderr, "%s training... ", qPrintable(transforms[i]->objectName()));  
83 - transforms[i]->train(copy);  
84 - fprintf(stderr, "projecting...\n");  
85 - copy >> *transforms[i]; 87 + int i = 0;
  88 + while (i < transforms.size()) {
  89 + fprintf(stderr, "%s", qPrintable(transforms[i]->objectName()));
  90 +
  91 + // Conditional statement covers likely case that first transform is untrainable
  92 + if (transforms[i]->trainable) {
  93 + fprintf(stderr, " training...");
  94 + transforms[i]->train(copy);
  95 + }
  96 +
  97 + // We project through any subsequent untrainable transforms at once
  98 + // as a memory optimization in case any of these intermediate
  99 + // transforms allocate a lot of memory (like OpenTransform)
  100 + // then we don't want all the training templates to be processed
  101 + // by that transform at once if we can avoid it.
  102 + int nextTrainableTransform = i+1;
  103 + while ((nextTrainableTransform < transforms.size()) &&
  104 + !transforms[nextTrainableTransform]->trainable)
  105 + nextTrainableTransform++;
  106 +
  107 + fprintf(stderr, " projecting...\n");
  108 + QFutureSynchronizer<void> futures;
  109 + for (int j=0; j<copy.size(); j++)
  110 + if (Globals->parallelism) futures.addFuture(QtConcurrent::run(this, &PipeTransform::_projectPartial, &copy[j], i, nextTrainableTransform));
  111 + else _projectPartial( &copy[j], i, nextTrainableTransform);
  112 + futures.waitForFinished();
  113 + i = nextTrainableTransform;
86 } 114 }
87 } 115 }
88 116