Commit d0aa1fd8eedc2fe9136b9af396045ba328495839

Authored by Charles Otto
1 parent f0eb0cfb

Model separate project calls as seaprate list items in train

For non-leaf transforms, add a separate train method that takes
QList<TemplateList> as arguments, where the separate list items correspond to
the separate calls to project that would be made by parent transforms during
enrollment. This is basically to deal with the effects of distribute on
untrainable meta transforms such as flatten, we have to maintain a consistent
grouping with project, otherwise the intermediate projection results used
during training won't be consistent with what they would be during enrollment.
openbr/core/core.cpp
... ... @@ -48,7 +48,9 @@ struct AlgorithmCore
48 48  
49 49 QTime time; time.start();
50 50 qDebug("Training Enrollment");
51   - transform->train(data);
  51 + QList<TemplateList> uniform;
  52 + uniform.append(data);
  53 + transform->train(uniform);
52 54  
53 55 if (!distance.isNull()) {
54 56 qDebug("Projecting Enrollment");
... ...
openbr/openbr_plugin.h
... ... @@ -1058,7 +1058,19 @@ public:
1058 1058 static QSharedPointer<Transform> fromAlgorithm(const QString &algorithm); /*!< \brief Retrieve an algorithm's transform. */
1059 1059  
1060 1060 virtual Transform *clone() const; /*!< \brief Copy the transform. */
  1061 + /*!< \brief Train the transform, separate list items represent the way calls to project would be broken up*/
  1062 + virtual void train(const QList<TemplateList> &data)
  1063 + {
  1064 + TemplateList combined;
  1065 + foreach(const TemplateList & set, data) {
  1066 + combined.append(set);
  1067 + }
  1068 + train(combined);
  1069 + }
  1070 +
  1071 + // Terminal train, call with a complete training set when no further structure is needed
1061 1072 virtual void train(const TemplateList &data) = 0; /*!< \brief Train the transform. */
  1073 +
1062 1074 virtual void project(const Template &src, Template &dst) const = 0; /*!< \brief Apply the transform. */
1063 1075 virtual void project(const TemplateList &src, TemplateList &dst) const; /*!< \brief Apply the transform. */
1064 1076  
... ...
openbr/plugins/meta.cpp
... ... @@ -57,7 +57,7 @@ static TemplateList Expanded(const TemplateList &amp;templates)
57 57 return expanded;
58 58 }
59 59  
60   -static void _train(Transform *transform, const TemplateList *data)
  60 +static void _train(Transform *transform, const QList<TemplateList> *data)
61 61 {
62 62 transform->train(*data);
63 63 }
... ... @@ -82,18 +82,17 @@ class PipeTransform : public CompositeTransform
82 82 *srcdst >> *transforms[i];
83 83 }
84 84  
85   - void train(const TemplateList &data)
  85 + void train(const TemplateList & data)
  86 + {
  87 + (void) data;
  88 + qFatal("Terminal train called on interior node");
  89 + }
  90 +
  91 + void train(const QList<TemplateList> &data)
86 92 {
87 93 if (!trainable) return;
88 94  
89   - TemplateList copy(data);
90   - QList<TemplateList> singleItemLists;
91   - for (int i=0; i < copy.size(); i++)
92   - {
93   - TemplateList temp;
94   - temp.append(copy[i]);
95   - singleItemLists.append(temp);
96   - }
  95 + QList<TemplateList> dataLines(data);
97 96  
98 97 int i = 0;
99 98 while (i < transforms.size()) {
... ... @@ -102,23 +101,18 @@ class PipeTransform : public CompositeTransform
102 101 // Conditional statement covers likely case that first transform is untrainable
103 102 if (transforms[i]->trainable) {
104 103 fprintf(stderr, " training...");
105   - transforms[i]->train(copy);
  104 + transforms[i]->train(dataLines);
106 105 }
107 106  
108 107 // if the transform is time varying, we can't project it in parallel
109 108 if (transforms[i]->timeVarying()) {
110 109 fprintf(stderr, "\n%s projecting...", qPrintable(transforms[i]->objectName()));
111   - for (int j=0; j < singleItemLists.size();j++)
112   - transforms[i]->projectUpdate(singleItemLists[j], singleItemLists[j]);
  110 + for (int j=0; j < dataLines.size();j++)
  111 + transforms[i]->projectUpdate(dataLines[j], dataLines[j]);
113 112  
114 113 // advance i since we already projected for this stage.
115 114 i++;
116 115  
117   - // set up copy again
118   - copy.clear();
119   - for (int j=0; j < singleItemLists.size(); j++)
120   - copy.append(singleItemLists[j]);
121   -
122 116 // the next stage might be trainable, so continue to evaluate it.
123 117 continue;
124 118 }
... ... @@ -136,14 +130,10 @@ class PipeTransform : public CompositeTransform
136 130  
137 131 fprintf(stderr, " projecting...");
138 132 QFutureSynchronizer<void> futures;
139   - for (int j=0; j < singleItemLists.size(); j++)
140   - futures.addFuture(QtConcurrent::run(this, &PipeTransform::_projectPartial, &singleItemLists[j], i, nextTrainableTransform));
  133 + for (int j=0; j < dataLines.size(); j++)
  134 + futures.addFuture(QtConcurrent::run(this, &PipeTransform::_projectPartial, &dataLines[j], i, nextTrainableTransform));
141 135 futures.waitForFinished();
142 136  
143   - copy.clear();
144   - for (int j=0; j < singleItemLists.size(); j++)
145   - copy.append(singleItemLists[j]);
146   -
147 137 i = nextTrainableTransform;
148 138 }
149 139 }
... ... @@ -296,7 +286,13 @@ class ForkTransform : public CompositeTransform
296 286 {
297 287 Q_OBJECT
298 288  
299   - void train(const TemplateList &data)
  289 + void train(const TemplateList & data)
  290 + {
  291 + (void) data;
  292 + qFatal("Terminal train called on interior node.");
  293 + }
  294 +
  295 + void train(const QList<TemplateList> &data)
300 296 {
301 297 if (!trainable) return;
302 298 QFutureSynchronizer<void> futures;
... ... @@ -612,9 +608,23 @@ public:
612 608 return output;
613 609 }
614 610  
615   - void train(const TemplateList &data)
  611 + void train(const TemplateList & data)
616 612 {
617   - transform->train(data);
  613 + (void) data;
  614 + qFatal("terminal train called on non-leaf transform");
  615 + }
  616 +
  617 + void train(const QList<TemplateList> &data)
  618 + {
  619 + QList<TemplateList> separated;
  620 + foreach (const TemplateList & list, data) {
  621 + foreach(const Template & t, list) {
  622 + separated.append(TemplateList());
  623 + separated.last().append(t);
  624 + }
  625 + }
  626 +
  627 + transform->train(separated);
618 628 }
619 629  
620 630 void project(const Template &src, Template &dst) const
... ...
openbr/plugins/openbr_internal.h
... ... @@ -140,7 +140,13 @@ public:
140 140 transformSource.release(aTransform);
141 141 }
142 142  
143   - void train(const TemplateList &data)
  143 + void train(const TemplateList & data)
  144 + {
  145 + (void) data;
  146 + qFatal("terminal train called on non-leaf node");
  147 + }
  148 +
  149 + void train(const QList<TemplateList> &data)
144 150 {
145 151 baseTransform->train(data);
146 152 }
... ...
openbr/plugins/validate.cpp
... ... @@ -6,6 +6,11 @@
6 6 namespace br
7 7 {
8 8  
  9 +static void _train(Transform * transform, TemplateList data) // think data has to be a copy -cao
  10 +{
  11 + transform->train(data);
  12 +}
  13 +
9 14 /*!
10 15 * \ingroup transforms
11 16 * \brief Cross validate a trainable transform.
... ... @@ -46,7 +51,8 @@ class CrossValidateTransform : public MetaTransform
46 51 if (partitions[j] == i)
47 52 partitionedData.removeAt(j);
48 53 // Train on the remaining templates
49   - futures.addFuture(QtConcurrent::run(transforms[i], &Transform::train, partitionedData));
  54 + //futures.addFuture(QtConcurrent::run(transforms[i], &Transform::train, partitionedData));
  55 + futures.addFuture(QtConcurrent::run(_train, transforms[i], partitionedData));
50 56 }
51 57 futures.waitForFinished();
52 58 }
... ...