Commit b6ff0861579e9e4dd3e68261d2f1bb63263675b9

Authored by Josh Klontz
1 parent 49b51c19

progress on cross validation framework

sdk/core/plot.cpp
... ... @@ -124,21 +124,23 @@ float Evaluate(const QString &simmat, const QString &mask, const QString &csv)
124 124  
125 125 // Make comparisons
126 126 QList<Comparison> comparisons; comparisons.reserve(scores.rows*scores.cols);
127   - int genuineCount = 0, impostorCount = 0;
  127 + int genuineCount = 0, impostorCount = 0, numNaNs = 0;
128 128 for (int i=0; i<scores.rows; i++) {
129 129 for (int j=0; j<scores.cols; j++) {
130 130 const BEE::Mask_t mask_val = masks.at<BEE::Mask_t>(i,j);
131 131 const BEE::Simmat_t simmat_val = scores.at<BEE::Simmat_t>(i,j);
132 132 if ((mask_val == BEE::DontCare) ||
133 133 (simmat_val == -std::numeric_limits<float>::max())) continue;
  134 + if (simmat_val != simmat_val) { numNaNs++; continue; }
134 135 comparisons.append(Comparison(simmat_val, j, i, mask_val == BEE::Match));
135 136 if (comparisons.last().genuine) genuineCount++;
136 137 else impostorCount++;
137 138 }
138 139 }
139 140  
140   - if (genuineCount == 0) qFatal("No genuine scores.");
141   - if (impostorCount == 0) qFatal("No impostor scores.");
  141 + if (numNaNs > 0) qWarning("Encountered %d NaN scores!", numNaNs);
  142 + if (genuineCount == 0) qFatal("No genuine scores!");
  143 + if (impostorCount == 0) qFatal("No impostor scores!");
142 144  
143 145 std::sort(comparisons.begin(), comparisons.end());
144 146  
... ... @@ -324,7 +326,8 @@ struct RPlot
324 326 pivotItems = QVector< QSet<QString> >(pivotHeaders.size());
325 327 foreach (const QString &fileName, files) {
326 328 QStringList pivots = getPivots(fileName, false);
327   - if (pivots.size() != pivotHeaders.size()) qFatal("plot.cpp::initializeR pivot size mismatch.");
  329 + if (pivots.size() != pivotHeaders.size())
  330 + qFatal("Pivot size mismatch: [%s] [%s]", qPrintable(pivotHeaders.join(",")), qPrintable(pivots.join(",")));
328 331 file.write(qPrintable(QString("tmp <- read.csv(\"%1\")\n").arg(fileName).replace("\\", "\\\\")));
329 332 for (int i=0; i<pivots.size(); i++) {
330 333 pivotItems[i].insert(pivots[i]);
... ...
sdk/openbr_plugin.cpp
... ... @@ -451,6 +451,19 @@ TemplateList TemplateList::fromInput(const br::File &amp;input)
451 451 return templates;
452 452 }
453 453  
  454 +TemplateList TemplateList::relabel(const TemplateList &tl)
  455 +{
  456 + QHash<int,int> labels;
  457 + foreach (int label, tl.labels<int>())
  458 + if (!labels.contains(label))
  459 + labels.insert(label, labels.size());
  460 +
  461 + TemplateList result = tl;
  462 + for (int i=0; i<result.size(); i++)
  463 + result[i].file.setLabel(labels[result[i].file.label()]);
  464 + return result;
  465 +}
  466 +
454 467 /* Object - public methods */
455 468 QStringList Object::parameters() const
456 469 {
... ... @@ -1053,10 +1066,8 @@ static TemplateList Downsample(const TemplateList &amp;templates, const Transform *t
1053 1066  
1054 1067 std::random_shuffle(indices.begin(), indices.end());
1055 1068 const int max = atLeast ? indices.size() : std::min(indices.size(), instances);
1056   - for (int j=0; j<max; j++) {
  1069 + for (int j=0; j<max; j++)
1057 1070 downsample.append(templates.value(indices[j]));
1058   - if (transform->relabel) downsample.last().file.insert("Label", i);
1059   - }
1060 1071 }
1061 1072  
1062 1073 if (transform->fraction < 1) {
... ... @@ -1168,7 +1179,6 @@ Transform::Transform(bool _independent, bool _trainable)
1168 1179 {
1169 1180 independent = _independent;
1170 1181 trainable = _trainable;
1171   - relabel = false;
1172 1182 classes = std::numeric_limits<int>::max();
1173 1183 instances = std::numeric_limits<int>::max();
1174 1184 fraction = 1;
... ... @@ -1222,7 +1232,6 @@ Transform *Transform::make(QString str, QObject *parent)
1222 1232 Transform *Transform::clone() const
1223 1233 {
1224 1234 Transform *clone = Factory<Transform>::make(file.flat());
1225   - clone->relabel = relabel;
1226 1235 clone->classes = classes;
1227 1236 clone->instances = instances;
1228 1237 clone->fraction = fraction;
... ...
sdk/openbr_plugin.h
... ... @@ -323,7 +323,7 @@ struct TemplateList : public QList&lt;Template&gt;
323 323 TemplateList() : uniform(false) {}
324 324 TemplateList(const QList<Template> &templates) : uniform(false) { append(templates); } /*!< \brief Initialize the template list from another template list. */
325 325 BR_EXPORT static TemplateList fromInput(const File &input); /*!< \brief Create a template list from a br::Input. */
326   -
  326 + BR_EXPORT static TemplateList relabel(const TemplateList &tl); /*!< \brief Ensure labels are in the range [0,numClasses-1]. */
327 327 /*!
328 328 * \brief Returns the total number of bytes in all the templates.
329 329 */
... ... @@ -932,11 +932,9 @@ class BR_EXPORT Transform : public Object
932 932 Q_OBJECT
933 933  
934 934 public:
935   - Q_PROPERTY(bool relabel READ get_relabel WRITE set_relabel RESET reset_relabel STORED false)
936 935 Q_PROPERTY(int classes READ get_classes WRITE set_classes RESET reset_classes STORED false)
937 936 Q_PROPERTY(int instances READ get_instances WRITE set_instances RESET reset_instances STORED false)
938 937 Q_PROPERTY(float fraction READ get_fraction WRITE set_fraction RESET reset_fraction STORED false)
939   - BR_PROPERTY(bool, relabel, false)
940 938 BR_PROPERTY(int, classes, std::numeric_limits<int>::max())
941 939 BR_PROPERTY(int, instances, std::numeric_limits<int>::max())
942 940 BR_PROPERTY(float, fraction, 1)
... ...
sdk/plugins/algorithms.cpp
... ... @@ -64,7 +64,7 @@ class AlgorithmsInitializer : public Initializer
64 64 Globals->abbreviations.insert("DenseSIFT", "(Grid(10,10)+SIFTDescriptor(12)+ByRow)");
65 65 Globals->abbreviations.insert("FaceRecognitionRegistration", "(ASEFEyes+Affine(88,88,0.25,0.35)+FTE(DFFS,instances=1))");
66 66 Globals->abbreviations.insert("FaceRecognitionExtraction", "(Mask+DenseSIFT/DenseLBP+PCA(0.95,instances=1)+Normalize(L2)+Cat)");
67   - Globals->abbreviations.insert("FaceRecognitionEmbedding", "(Dup(12)+RndSubspace(0.05,1)+LDA(0.98,instances=-2,relabel=true)+Cat+PCA(768,instances=1))");
  67 + Globals->abbreviations.insert("FaceRecognitionEmbedding", "(Dup(12)+RndSubspace(0.05,1)+LDA(0.98,instances=-2)+Cat+PCA(768,instances=1))");
68 68 Globals->abbreviations.insert("FaceRecognitionQuantization", "(Normalize(L1)+Quantize)");
69 69 Globals->abbreviations.insert("FaceClassificationRegistration", "(ASEFEyes+Affine(56,72,0.33,0.45)+FTE(DFFS))");
70 70 Globals->abbreviations.insert("FaceClassificationExtraction", "((Grid(7,7)+SIFTDescriptor(8)+ByRow)/DenseLBP+PCA(0.95,instances=-1)+Cat)");
... ...
sdk/plugins/eigen3.cpp
... ... @@ -264,8 +264,11 @@ class LDATransform : public Transform
264 264 Eigen::VectorXf mean;
265 265 Eigen::MatrixXf projection;
266 266  
267   - void train(const TemplateList &trainingSet)
  267 + void train(const TemplateList &_trainingSet)
268 268 {
  269 + TemplateList trainingSet = _trainingSet;
  270 + trainingSet = TemplateList::relabel(trainingSet);
  271 +
269 272 int instances = trainingSet.size();
270 273  
271 274 // Perform PCA dimensionality reduction
... ... @@ -276,6 +279,7 @@ class LDATransform : public Transform
276 279  
277 280 TemplateList ldaTrainingSet;
278 281 static_cast<Transform*>(&pca)->project(trainingSet, ldaTrainingSet);
  282 + ldaTrainingSet = TemplateList::relabel(ldaTrainingSet);
279 283  
280 284 int dimsIn = ldaTrainingSet.first().m().rows * ldaTrainingSet.first().m().cols;
281 285  
... ...
sdk/plugins/quality.cpp
... ... @@ -91,7 +91,7 @@ struct KDE
91 91 bins.append(Common::KernelDensityEstimation(scores, min + (max-min)*i/(size-1), h));
92 92 }
93 93  
94   - float operator()(float score, bool gaussian = false) const
  94 + float operator()(float score, bool gaussian = true) const
95 95 {
96 96 if (gaussian) return 1/(stddev*sqrt(2*CV_PI))*exp(-0.5*pow((score-mean)/stddev, 2));
97 97 if (score <= min) return bins.first();
... ... @@ -120,12 +120,12 @@ struct MP
120 120 MP() {}
121 121 MP(const QList<float> &genuineScores, const QList<float> &impostorScores)
122 122 : genuine(genuineScores), impostor(impostorScores) {}
123   - float operator()(float score, bool gaussian = false, bool log = false) const
  123 + float operator()(float score, bool gaussian = true) const
124 124 {
125 125 const float g = genuine(score, gaussian);
126 126 const float s = g / (impostor(score, gaussian) + g);
127   - if (log) return (std::max(std::log10(s), -10.f) + 10)/10;
128   - else return s;
  127 + if (s != s) qDebug() << "!!" << g << impostor(score, gaussian) << score << genuine.mean << genuine.stddev << impostor.mean << impostor.stddev;
  128 + return s;
129 129 }
130 130 };
131 131  
... ... @@ -150,11 +150,9 @@ class MPDistance : public Distance
150 150 Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false)
151 151 Q_PROPERTY(QString binKey READ get_binKey WRITE set_binKey RESET reset_binKey STORED false)
152 152 Q_PROPERTY(bool gaussian READ get_gaussian WRITE set_gaussian RESET reset_gaussian STORED false)
153   - Q_PROPERTY(bool log READ get_log WRITE set_log RESET reset_log STORED false)
154 153 BR_PROPERTY(br::Distance*, distance, make("Dist(L2)"))
155 154 BR_PROPERTY(QString, binKey, "")
156   - BR_PROPERTY(bool, gaussian, false)
157   - BR_PROPERTY(bool, log, false)
  155 + BR_PROPERTY(bool, gaussian, true)
158 156  
159 157 QHash<QString, MP> mps;
160 158  
... ... @@ -170,6 +168,7 @@ class MPDistance : public Distance
170 168 for (int i=0; i<src.size(); i++)
171 169 for (int j=0; j<i; j++) {
172 170 const float score = memoryOutput.data()->data.at<float>(i, j);
  171 + if (score == -std::numeric_limits<float>::max()) continue;
173 172 const QString bin = src[i].file.getString(binKey, "");
174 173 if (labels[i] == labels[j]) genuineScores[bin].append(score);
175 174 else impostorScores[bin].append(score);
... ... @@ -181,7 +180,9 @@ class MPDistance : public Distance
181 180  
182 181 float compare(const Template &target, const Template &query) const
183 182 {
184   - return mps[query.file.getString(binKey, "")](distance->compare(target, query), gaussian, log);
  183 + float rawScore = distance->compare(target, query);
  184 + if (rawScore == -std::numeric_limits<float>::max()) return rawScore;
  185 + return mps[query.file.getString(binKey, "")](rawScore, gaussian);
185 186 }
186 187  
187 188 void store(QDataStream &stream) const
... ...
sdk/plugins/validate.cpp
... ... @@ -9,37 +9,33 @@ namespace br
9 9 * \brief Cross validate a trainable transform.
10 10 * \author Josh Klontz \cite jklontz
11 11 */
12   -class CrossValidateTransform : public Transform
  12 +class CrossValidateTransform : public MetaTransform
13 13 {
14 14 Q_OBJECT
15 15 Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false)
16 16 Q_PROPERTY(QList<br::Transform*> transforms READ get_transforms WRITE set_transforms RESET reset_transforms)
17 17 BR_PROPERTY(QString, description, "Identity")
18   - BR_PROPERTY(QList<br::Transform*>, transforms, QList<br::Transform*>() << make(description))
  18 + BR_PROPERTY(QList<br::Transform*>, transforms, QList<br::Transform*>())
19 19  
20 20 void train(const TemplateList &data)
21 21 {
22   - if (!transforms.first()->trainable)
23   - return;
24   -
25 22 int numPartitions = 0;
26 23 QList<int> partitions; partitions.reserve(data.size());
27 24 foreach (const File &file, data.files()) {
28 25 partitions.append(file.getInt("Cross_Validation_Partition", 0));
29   - numPartitions = std::max(numPartitions, partitions.last());
  26 + numPartitions = std::max(numPartitions, partitions.last()+1);
30 27 }
31 28  
  29 + while (transforms.size() < numPartitions)
  30 + transforms.append(make(description));
  31 +
32 32 if (numPartitions < 2) {
33 33 transforms.first()->train(data);
34 34 return;
35 35 }
36 36  
37   - while (transforms.size() < numPartitions)
38   - transforms.append(make(description));
39   -
40 37 QList< QFuture<void> > futures;
41 38 for (int i=0; i<numPartitions; i++) {
42   - qDebug() << "!!" << transforms[i]->description();
43 39 TemplateList partitionedData = data;
44 40 for (int j=partitionedData.size()-1; j>=0; j--)
45 41 if (partitions[j] == i)
... ... @@ -47,6 +43,7 @@ class CrossValidateTransform : public Transform
47 43 if (Globals->parallelism) futures.append(QtConcurrent::run(transforms[i], &Transform::train, partitionedData));
48 44 else transforms[i]->train(partitionedData);
49 45 }
  46 + Globals->trackFutures(futures);
50 47 }
51 48  
52 49 void project(const Template &src, Template &dst) const
... ...