Commit 1304c275563e69d41ae3070911b47db3ae639868

Authored by Charles Otto
1 parent ffcbc59f

In eval, construct a pairwise mask for a matrix if that seems appropriate

If we get a 1 column matrix in eval, check if the query and target sets are
of matching size, and construct a pairwise mask for that matrix.
openbr/core/core.cpp
... ... @@ -453,14 +453,23 @@ void br::Convert(const File &fileType, const File &inputFile, const File &output
453 453 const FileList targetFiles = TemplateList::fromGallery(target).files();
454 454 const FileList queryFiles = TemplateList::fromGallery(query).files();
455 455  
456   - if (targetFiles.size() != m.cols || queryFiles.size() != m.rows)
  456 + if ((targetFiles.size() != m.cols || queryFiles.size() != m.rows)
  457 + && (m.cols != 1 || targetFiles.size() != m.rows || queryFiles.size() != m.rows))
457 458 qFatal("Similarity matrix and file size mismatch.");
458 459  
459 460 QSharedPointer<Output> o(Factory<Output>::make(outputFile));
460 461 o->initialize(targetFiles, queryFiles);
461 462  
462   - for (int i=0; i<queryFiles.size(); i++)
463   - for (int j=0; j<targetFiles.size(); j++)
  463 + if (targetFiles.size() != m.cols)
  464 + {
  465 + MatrixOutput * mOut = dynamic_cast<MatrixOutput *>(o.data());
  466 + if (mOut)
  467 + mOut->data.create(queryFiles.size(), 1, CV_32FC1);
  468 + }
  469 +
  470 + o->setBlock(0,0);
  471 + for (int i=0; i < m.rows; i++)
  472 + for (int j=0; j < m.cols; j++)
464 473 o->setRelative(m.at<float>(i,j), i, j);
465 474 } else {
466 475 qFatal("Unrecognized file type %s.", qPrintable(fileType.flat()));
... ...
openbr/core/eval.cpp
... ... @@ -65,9 +65,26 @@ static float getTAR(const QList&lt;OperatingPoint&gt; &amp;operatingPoints, float FAR)
65 65 return m * FAR + b;
66 66 }
67 67  
  68 +// Decide whether to construct a normal mask matrix, or a pairwise mask by comparing the dimensions of
  69 +// scores with the size of the target and query lists
  70 +static cv::Mat constructMatchingMask(const cv::Mat & scores, const FileList & target, const FileList & query, int partition=0)
  71 +{
  72 + // If the dimensions of the score matrix match the sizes of the target and query lists, construct a normal mask matrix
  73 + if (target.size() == scores.cols && query.size() == scores.rows)
  74 + return BEE::makeMask(target, query, partition);
  75 + // If this looks like a pairwise comparison (1 column score matrix, equal length target and query sets), construct a
  76 + // mask for that
  77 + else if (scores.cols == 1 && target.size() == query.size()) {
  78 + return BEE::makePairwiseMask(target, query, partition);
  79 + }
  80 + // otherwise, we fail
  81 + else
  82 + qFatal("Unable to construct mask for %d by %d score matrix from %d element query set, and %d element target set ", scores.rows, scores.cols, query.length(), target.length());
  83 +}
  84 +
68 85 float Evaluate(const cv::Mat &scores, const FileList &target, const FileList &query, const QString &csv, int partition)
69 86 {
70   - return Evaluate(scores, BEE::makeMask(target, query, partition), csv);
  87 + return Evaluate(scores, constructMatchingMask(scores, target, query, partition), csv);
71 88 }
72 89  
73 90 float Evaluate(const QString &simmat, const QString &mask, const QString &csv)
... ... @@ -93,8 +110,9 @@ float Evaluate(const QString &amp;simmat, const QString &amp;mask, const QString &amp;csv)
93 110 // Use the galleries specified in the similarity matrix
94 111 if (target.isEmpty()) qFatal("Unspecified target gallery.");
95 112 if (query.isEmpty()) qFatal("Unspecified query gallery.");
96   - truth = BEE::makeMask(TemplateList::fromGallery(target).files(),
97   - TemplateList::fromGallery(query).files());
  113 +
  114 + truth = constructMatchingMask(scores, TemplateList::fromGallery(target).files(),
  115 + TemplateList::fromGallery(query).files());
98 116 } else {
99 117 File maskFile(mask);
100 118 maskFile.set("rows", scores.rows);
... ...
openbr/plugins/output.cpp
... ... @@ -371,7 +371,7 @@ class evalOutput : public MatrixOutput
371 371 if (data.data) {
372 372 const QString csv = QString(file.name).replace(".eval", ".csv");
373 373 if ((Globals->crossValidate == 0) || (!crossValidate)) {
374   - Evaluate(data, BEE::makeMask(targetFiles, queryFiles), csv);
  374 + Evaluate(data,targetFiles, queryFiles, csv);
375 375 } else {
376 376 QFutureSynchronizer<float> futures;
377 377 for (int i=0; i<Globals->crossValidate; i++)
... ...