Commit d0aa1fd8eedc2fe9136b9af396045ba328495839

Authored by Charles Otto
1 parent f0eb0cfb

Model separate project calls as seaprate list items in train

For non-leaf transforms, add a separate train method that takes
QList<TemplateList> as arguments, where the separate list items correspond to
the separate calls to project that would be made by parent transforms during
enrollment. This is basically to deal with the effects of distribute on
untrainable meta transforms such as flatten, we have to maintain a consistent
grouping with project, otherwise the intermediate projection results used
during training won't be consistent with what they would be during enrollment.
openbr/core/core.cpp
@@ -48,7 +48,9 @@ struct AlgorithmCore @@ -48,7 +48,9 @@ struct AlgorithmCore
48 48
49 QTime time; time.start(); 49 QTime time; time.start();
50 qDebug("Training Enrollment"); 50 qDebug("Training Enrollment");
51 - transform->train(data); 51 + QList<TemplateList> uniform;
  52 + uniform.append(data);
  53 + transform->train(uniform);
52 54
53 if (!distance.isNull()) { 55 if (!distance.isNull()) {
54 qDebug("Projecting Enrollment"); 56 qDebug("Projecting Enrollment");
openbr/openbr_plugin.h
@@ -1058,7 +1058,19 @@ public: @@ -1058,7 +1058,19 @@ public:
1058 static QSharedPointer<Transform> fromAlgorithm(const QString &algorithm); /*!< \brief Retrieve an algorithm's transform. */ 1058 static QSharedPointer<Transform> fromAlgorithm(const QString &algorithm); /*!< \brief Retrieve an algorithm's transform. */
1059 1059
1060 virtual Transform *clone() const; /*!< \brief Copy the transform. */ 1060 virtual Transform *clone() const; /*!< \brief Copy the transform. */
  1061 + /*!< \brief Train the transform, separate list items represent the way calls to project would be broken up*/
  1062 + virtual void train(const QList<TemplateList> &data)
  1063 + {
  1064 + TemplateList combined;
  1065 + foreach(const TemplateList & set, data) {
  1066 + combined.append(set);
  1067 + }
  1068 + train(combined);
  1069 + }
  1070 +
  1071 + // Terminal train, call with a complete training set when no further structure is needed
1061 virtual void train(const TemplateList &data) = 0; /*!< \brief Train the transform. */ 1072 virtual void train(const TemplateList &data) = 0; /*!< \brief Train the transform. */
  1073 +
1062 virtual void project(const Template &src, Template &dst) const = 0; /*!< \brief Apply the transform. */ 1074 virtual void project(const Template &src, Template &dst) const = 0; /*!< \brief Apply the transform. */
1063 virtual void project(const TemplateList &src, TemplateList &dst) const; /*!< \brief Apply the transform. */ 1075 virtual void project(const TemplateList &src, TemplateList &dst) const; /*!< \brief Apply the transform. */
1064 1076
openbr/plugins/meta.cpp
@@ -57,7 +57,7 @@ static TemplateList Expanded(const TemplateList &amp;templates) @@ -57,7 +57,7 @@ static TemplateList Expanded(const TemplateList &amp;templates)
57 return expanded; 57 return expanded;
58 } 58 }
59 59
60 -static void _train(Transform *transform, const TemplateList *data) 60 +static void _train(Transform *transform, const QList<TemplateList> *data)
61 { 61 {
62 transform->train(*data); 62 transform->train(*data);
63 } 63 }
@@ -82,18 +82,17 @@ class PipeTransform : public CompositeTransform @@ -82,18 +82,17 @@ class PipeTransform : public CompositeTransform
82 *srcdst >> *transforms[i]; 82 *srcdst >> *transforms[i];
83 } 83 }
84 84
85 - void train(const TemplateList &data) 85 + void train(const TemplateList & data)
  86 + {
  87 + (void) data;
  88 + qFatal("Terminal train called on interior node");
  89 + }
  90 +
  91 + void train(const QList<TemplateList> &data)
86 { 92 {
87 if (!trainable) return; 93 if (!trainable) return;
88 94
89 - TemplateList copy(data);  
90 - QList<TemplateList> singleItemLists;  
91 - for (int i=0; i < copy.size(); i++)  
92 - {  
93 - TemplateList temp;  
94 - temp.append(copy[i]);  
95 - singleItemLists.append(temp);  
96 - } 95 + QList<TemplateList> dataLines(data);
97 96
98 int i = 0; 97 int i = 0;
99 while (i < transforms.size()) { 98 while (i < transforms.size()) {
@@ -102,23 +101,18 @@ class PipeTransform : public CompositeTransform @@ -102,23 +101,18 @@ class PipeTransform : public CompositeTransform
102 // Conditional statement covers likely case that first transform is untrainable 101 // Conditional statement covers likely case that first transform is untrainable
103 if (transforms[i]->trainable) { 102 if (transforms[i]->trainable) {
104 fprintf(stderr, " training..."); 103 fprintf(stderr, " training...");
105 - transforms[i]->train(copy); 104 + transforms[i]->train(dataLines);
106 } 105 }
107 106
108 // if the transform is time varying, we can't project it in parallel 107 // if the transform is time varying, we can't project it in parallel
109 if (transforms[i]->timeVarying()) { 108 if (transforms[i]->timeVarying()) {
110 fprintf(stderr, "\n%s projecting...", qPrintable(transforms[i]->objectName())); 109 fprintf(stderr, "\n%s projecting...", qPrintable(transforms[i]->objectName()));
111 - for (int j=0; j < singleItemLists.size();j++)  
112 - transforms[i]->projectUpdate(singleItemLists[j], singleItemLists[j]); 110 + for (int j=0; j < dataLines.size();j++)
  111 + transforms[i]->projectUpdate(dataLines[j], dataLines[j]);
113 112
114 // advance i since we already projected for this stage. 113 // advance i since we already projected for this stage.
115 i++; 114 i++;
116 115
117 - // set up copy again  
118 - copy.clear();  
119 - for (int j=0; j < singleItemLists.size(); j++)  
120 - copy.append(singleItemLists[j]);  
121 -  
122 // the next stage might be trainable, so continue to evaluate it. 116 // the next stage might be trainable, so continue to evaluate it.
123 continue; 117 continue;
124 } 118 }
@@ -136,14 +130,10 @@ class PipeTransform : public CompositeTransform @@ -136,14 +130,10 @@ class PipeTransform : public CompositeTransform
136 130
137 fprintf(stderr, " projecting..."); 131 fprintf(stderr, " projecting...");
138 QFutureSynchronizer<void> futures; 132 QFutureSynchronizer<void> futures;
139 - for (int j=0; j < singleItemLists.size(); j++)  
140 - futures.addFuture(QtConcurrent::run(this, &PipeTransform::_projectPartial, &singleItemLists[j], i, nextTrainableTransform)); 133 + for (int j=0; j < dataLines.size(); j++)
  134 + futures.addFuture(QtConcurrent::run(this, &PipeTransform::_projectPartial, &dataLines[j], i, nextTrainableTransform));
141 futures.waitForFinished(); 135 futures.waitForFinished();
142 136
143 - copy.clear();  
144 - for (int j=0; j < singleItemLists.size(); j++)  
145 - copy.append(singleItemLists[j]);  
146 -  
147 i = nextTrainableTransform; 137 i = nextTrainableTransform;
148 } 138 }
149 } 139 }
@@ -296,7 +286,13 @@ class ForkTransform : public CompositeTransform @@ -296,7 +286,13 @@ class ForkTransform : public CompositeTransform
296 { 286 {
297 Q_OBJECT 287 Q_OBJECT
298 288
299 - void train(const TemplateList &data) 289 + void train(const TemplateList & data)
  290 + {
  291 + (void) data;
  292 + qFatal("Terminal train called on interior node.");
  293 + }
  294 +
  295 + void train(const QList<TemplateList> &data)
300 { 296 {
301 if (!trainable) return; 297 if (!trainable) return;
302 QFutureSynchronizer<void> futures; 298 QFutureSynchronizer<void> futures;
@@ -612,9 +608,23 @@ public: @@ -612,9 +608,23 @@ public:
612 return output; 608 return output;
613 } 609 }
614 610
615 - void train(const TemplateList &data) 611 + void train(const TemplateList & data)
616 { 612 {
617 - transform->train(data); 613 + (void) data;
  614 + qFatal("terminal train called on non-leaf transform");
  615 + }
  616 +
  617 + void train(const QList<TemplateList> &data)
  618 + {
  619 + QList<TemplateList> separated;
  620 + foreach (const TemplateList & list, data) {
  621 + foreach(const Template & t, list) {
  622 + separated.append(TemplateList());
  623 + separated.last().append(t);
  624 + }
  625 + }
  626 +
  627 + transform->train(separated);
618 } 628 }
619 629
620 void project(const Template &src, Template &dst) const 630 void project(const Template &src, Template &dst) const
openbr/plugins/openbr_internal.h
@@ -140,7 +140,13 @@ public: @@ -140,7 +140,13 @@ public:
140 transformSource.release(aTransform); 140 transformSource.release(aTransform);
141 } 141 }
142 142
143 - void train(const TemplateList &data) 143 + void train(const TemplateList & data)
  144 + {
  145 + (void) data;
  146 + qFatal("terminal train called on non-leaf node");
  147 + }
  148 +
  149 + void train(const QList<TemplateList> &data)
144 { 150 {
145 baseTransform->train(data); 151 baseTransform->train(data);
146 } 152 }
openbr/plugins/validate.cpp
@@ -6,6 +6,11 @@ @@ -6,6 +6,11 @@
6 namespace br 6 namespace br
7 { 7 {
8 8
  9 +static void _train(Transform * transform, TemplateList data) // think data has to be a copy -cao
  10 +{
  11 + transform->train(data);
  12 +}
  13 +
9 /*! 14 /*!
10 * \ingroup transforms 15 * \ingroup transforms
11 * \brief Cross validate a trainable transform. 16 * \brief Cross validate a trainable transform.
@@ -46,7 +51,8 @@ class CrossValidateTransform : public MetaTransform @@ -46,7 +51,8 @@ class CrossValidateTransform : public MetaTransform
46 if (partitions[j] == i) 51 if (partitions[j] == i)
47 partitionedData.removeAt(j); 52 partitionedData.removeAt(j);
48 // Train on the remaining templates 53 // Train on the remaining templates
49 - futures.addFuture(QtConcurrent::run(transforms[i], &Transform::train, partitionedData)); 54 + //futures.addFuture(QtConcurrent::run(transforms[i], &Transform::train, partitionedData));
  55 + futures.addFuture(QtConcurrent::run(_train, transforms[i], partitionedData));
50 } 56 }
51 futures.waitForFinished(); 57 futures.waitForFinished();
52 } 58 }