Commit be9fd90a2828e7cd85a37e1f667e8df0e5c38a9d

Authored by Jordan Cheney
1 parent b085d2aa

This PR makes 2 changes to evalEER and adds a new, corresponding plotEER method. The C API and command line tool
are updated to support the new function.

Changes are:

1. Change how operating points are computed to mirror eval (i.e. an operating point is added only if both the true positive AND false positive counts change). This fixes a bug where successive operating points could have the same FAR, which caused NaN results in the output.

2. Change the optional pdf argument to a CSV argument (again mirroring eval) so that results are written as CSVs and can be combined to show multiple results on the same plot.
app/br/br.cpp
... ... @@ -192,6 +192,9 @@ public:
192 192 } else if (!strcmp(fun, "plotKNN")) {
193 193 check(parc >=2, "Incorrect parameter count for 'plotKNN'.");
194 194 br_plot_knn(parc-1, parv, parv[parc-1], true);
  195 + } else if (!strcmp(fun, "plotEER")) {
  196 + check(parc >= 2, "Incorrect parameter count for 'plotEER'.");
  197 + br_plot_eer(parc-1, parv, parv[parc-1], true);
195 198 } else if (!strcmp(fun, "project")) {
196 199 check(parc == 2, "Insufficient parameter count for 'project'.");
197 200 br_project(parv[0], parv[1]);
... ... @@ -298,6 +301,7 @@ private:
298 301 "-plotLandmarking <file> ... <file> {destination}\n"
299 302 "-plotMetadata <file> ... <file> <columns>\n"
300 303 "-plotKNN <file> ... <file> {destination}\n"
  304 + "-plotEER <file> ... <file> {destination}\n"
301 305 "-project <input_gallery> {output_gallery}\n"
302 306 "-deduplicate <input_gallery> <output_gallery> <threshold>\n"
303 307 "-likely <input_type> <output_type> <output_likely_source>\n"
... ...
openbr/core/eval.cpp
... ... @@ -377,8 +377,8 @@ float Evaluate(const Mat &simmat, const Mat &mask, const File &csv, const QStrin
377 377 // Write TAR@FAR Table (TF)
378 378 foreach (float FAR, QList<float>() << 1e-6 << 1e-5 << 1e-4 << 1e-3 << 1e-2 << 1e-1)
379 379 lines.append(qPrintable(QString("TF,%1,%2").arg(
380   - QString::number(FAR, 'f'),
381   - QString::number(getOperatingPoint(operatingPoints, "FAR", FAR).TAR, 'f', 3))));
  380 + QString::number(FAR, 'f'),
  381 + QString::number(getOperatingPoint(operatingPoints, "FAR", FAR).TAR, 'f', 3))));
382 382  
383 383 // Write FAR@TAR Table (FT)
384 384 foreach (float TAR, QList<float>() << 0.4 << 0.5 << 0.65 << 0.75 << 0.85 << 0.95)
... ... @@ -1255,7 +1255,7 @@ void EvalKNN(const QString &knnGraph, const QString &knnTruth, const QString &cs
1255 1255 qDebug("FNIR @ FPIR = 0.01: %.3f", 1-getOperatingPoint(operatingPoints, "FAR", 0.01).TAR);
1256 1256 }
1257 1257  
1258   -void EvalEER(const QString &predictedXML, QString gt_property, QString distribution_property, const QString &pdf) {
  1258 +void EvalEER(const QString &predictedXML, QString gt_property, QString distribution_property, const QString &csv) {
1259 1259 if (gt_property.isEmpty())
1260 1260 gt_property = "LivenessGT";
1261 1261 if (distribution_property.isEmpty())
... ... @@ -1263,59 +1263,79 @@ void EvalEER(const QString &predictedXML, QString gt_property, QString distribut
1263 1263 int classOneTemplateCount = 0;
1264 1264 const TemplateList templateList(TemplateList::fromGallery(predictedXML));
1265 1265  
1266   - QHash<QString, int> gtLabels;
1267   - QHash<QString, float > scores;
  1266 + QList<QPair<float, int>> scores;
  1267 + QList<float> classZeroScores, classOneScores;
1268 1268 for (int i=0; i<templateList.size(); i++) {
1269 1269 if (!templateList[i].file.contains(distribution_property) || !templateList[i].file.contains(gt_property))
1270 1270 continue;
1271   - QString templateKey = templateList[i].file.path() + templateList[i].file.baseName();
  1271 +
1272 1272 const int gtLabel = templateList[i].file.get<int>(gt_property);
1273   - if (gtLabel == 1)
  1273 + const float templateScore = templateList[i].file.get<float>(distribution_property);
  1274 + scores.append(qMakePair(templateScore, gtLabel));
  1275 +
  1276 + if (gtLabel == 1) {
1274 1277 classOneTemplateCount++;
1275   - const float templateScores = templateList[i].file.get<float>(distribution_property);
1276   - gtLabels[templateKey] = gtLabel;
1277   - scores[templateKey] = templateScores;
  1278 + classOneScores.append(templateScore);
  1279 + } else {
  1280 + classZeroScores.append(templateScore);
  1281 + }
1278 1282 }
1279 1283  
1280   - const int numPoints = 200;
1281   - const float stepSize = 100.0/numPoints;
1282   - const int numTemplates = scores.size();
1283   - float thres = 0.0; //Between [0,100]
1284   - float thresNorm = 0.0; //Between [0,1]
1285   - float minDiff = 100, EER = 100, EERThres = 0;
  1284 + std::sort(scores.begin(), scores.end());
  1285 +
1286 1286 QList<OperatingPoint> operatingPoints;
1287 1287  
1288   - for(int i = 0; i <= numPoints; i++) {
1289   - int FA = 0, FR = 0;
1290   - thresNorm = thres/100.0;
1291   - foreach(const QString &key, scores.keys()) {
1292   - int gtLabel = gtLabels[key];
1293   - //> thresNorm = class 0 (spoof) : < thresNorm = class 1 (genuine)
1294   - if (scores[key] >= thresNorm && gtLabel == 0)
1295   - continue;
1296   - else if (scores[key] < thresNorm && gtLabel == 1)
1297   - continue;
1298   - else if (scores[key] >= thresNorm && gtLabel == 1)
1299   - FR +=1;
1300   - else if (scores[key] < thresNorm && gtLabel == 0)
1301   - FA +=1;
  1288 + const int classZeroTemplateCount = scores.size() - classOneTemplateCount;
  1289 + int falsePositives = 0, previousFalsePositives = 0;
  1290 + int truePositives = 0, previousTruePositives = 0;
  1291 + size_t index = 0;
  1292 + float minDiff = 100, EER = 100, EERThres = 0;
  1293 + float minClassOneScore = std::numeric_limits<float>::max();
  1294 + float minClassZeroScore = std::numeric_limits<float>::max();
  1295 +
  1296 + while (index < scores.size()) {
  1297 + float thresh = scores[index].first;
  1298 + // Compute genuine and imposter statistics at a threshold
  1299 + while ((index < scores.size()) &&
  1300 + (scores[index].first == thresh)) {
  1301 + if (scores[index].second) {
  1302 + truePositives++;
  1303 + if (scores[index].first != -std::numeric_limits<float>::max() && scores[index].first < minClassOneScore)
  1304 + minClassOneScore = scores[index].first;
  1305 + } else {
  1306 + falsePositives++;
  1307 + if (scores[index].first != -std::numeric_limits<float>::max() && scores[index].first < minClassZeroScore)
  1308 + minClassZeroScore = scores[index].first;
  1309 + }
  1310 + index++;
1302 1311 }
1303   - const float FAR = FA / float(numTemplates - classOneTemplateCount);
1304   - const float FRR = FR / float(classOneTemplateCount);
1305   - operatingPoints.append(OperatingPoint(thresNorm, FAR, 1-FRR));
1306   -
1307   - const float diff = std::abs(FAR-FRR);
1308   - if (diff < minDiff) {
1309   - minDiff = diff;
1310   - EER = (FAR+FRR)/2.0;
1311   - EERThres = thresNorm;
  1312 +
  1313 + if ((falsePositives > previousFalsePositives) &&
  1314 + (truePositives > previousTruePositives)) {
  1315 + const float FAR = float(falsePositives) / classZeroTemplateCount;
  1316 + const float TAR = float(truePositives) / classOneTemplateCount;
  1317 + const float FRR = 1 - TAR;
  1318 + operatingPoints.append(OperatingPoint(thresh, FAR, TAR));
  1319 +
  1320 + const float diff = std::abs(FAR-FRR);
  1321 + if (diff < minDiff) {
  1322 + minDiff = diff;
  1323 + EER = (FAR+FRR)/2.0;
  1324 + EERThres = thresh;
  1325 + }
  1326 +
  1327 + previousFalsePositives = falsePositives;
  1328 + previousTruePositives = truePositives;
1312 1329 }
1313   - thres += stepSize;
1314 1330 }
1315 1331  
  1332 + if (operatingPoints.size() == 0) operatingPoints.append(OperatingPoint(1, 1, 1));
  1333 + if (operatingPoints.size() == 1) operatingPoints.prepend(OperatingPoint(0, 0, 0));
  1334 + if (operatingPoints.size() > 2) operatingPoints.takeLast(); // Remove point (1,1)
  1335 +
1316 1336 printf("\n==========================================================\n");
1317 1337 printf("Class 0 Templates: %d\tClass 1 Templates: %d\tTotal Templates: %d\n",
1318   - numTemplates-classOneTemplateCount, classOneTemplateCount, numTemplates);
  1338 + classZeroTemplateCount, classOneTemplateCount, classZeroTemplateCount + classOneTemplateCount);
1319 1339 printf("----------------------------------------------------------\n");
1320 1340 foreach (float FAR, QList<float>() << 0.2 << 0.1 << 0.05 << 0.01 << 0.001 << 0.0001) {
1321 1341 const OperatingPoint op = getOperatingPoint(operatingPoints, "FAR", FAR);
... ... @@ -1333,29 +1353,76 @@ void EvalEER(const QString &predictedXML, QString gt_property, QString distribut
1333 1353 printf("==========================================================\n\n");
1334 1354  
1335 1355 // Optionally write ROC curve
1336   - if (!pdf.isEmpty()) {
1337   - QStringList farValues, tarValues;
1338   - float expFAR = std::max(ceil(log10(numTemplates - classOneTemplateCount)), 1.0);
  1356 + if (!csv.isEmpty()) {
  1357 + QStringList lines;
  1358 + lines.append("Plot,X,Y");
  1359 + lines.append("Metadata,"+QString::number(classZeroTemplateCount+classOneTemplateCount)+",Total Templates");
  1360 + lines.append("Metadata,"+QString::number(classZeroTemplateCount)+",Class 0 Template Count");
  1361 + lines.append("Metadata,"+QString::number(classOneTemplateCount)+",Class 1 Template Count");
  1362 +
  1363 + // Write Detection Error Tradeoff (DET), PRE, REC
  1364 + float expFAR = std::max(ceil(log10(classZeroTemplateCount)), 1.0);
  1365 + float expFRR = std::max(ceil(log10(classOneTemplateCount)), 1.0);
  1366 +
1339 1367 float FARstep = expFAR / (float)(Max_Points - 1);
  1368 + float FRRstep = expFRR / (float)(Max_Points - 1);
  1369 +
1340 1370 for (int i=0; i<Max_Points; i++) {
1341 1371 float FAR = pow(10, -expFAR + i*FARstep);
1342   - OperatingPoint op = getOperatingPoint(operatingPoints, "FAR", FAR);
1343   - farValues.append(QString::number(FAR));
1344   - tarValues.append(QString::number(op.TAR));
  1372 + float FRR = pow(10, -expFRR + i*FRRstep);
  1373 +
  1374 + OperatingPoint operatingPointFAR = getOperatingPoint(operatingPoints, "FAR", FAR);
  1375 + OperatingPoint operatingPointTAR = getOperatingPoint(operatingPoints, "TAR", 1-FRR);
  1376 + lines.append(QString("DET,%1,%2").arg(QString::number(FAR),
  1377 + QString::number(1-operatingPointFAR.TAR)));
  1378 + lines.append(QString("FAR,%1,%2").arg(QString::number(operatingPointFAR.score),
  1379 + QString::number(FAR)));
  1380 + lines.append(QString("FRR,%1,%2").arg(QString::number(operatingPointTAR.score),
  1381 + QString::number(FRR)));
  1382 + }
  1383 +
  1384 + // Write TAR@FAR Table (TF)
  1385 + foreach (float FAR, QList<float>() << 0.2 << 0.1 << 0.05 << 0.01 << 0.001 << 0.0001)
  1386 + lines.append(qPrintable(QString("TF,%1,%2").arg(
  1387 + QString::number(FAR, 'f'),
  1388 + QString::number(getOperatingPoint(operatingPoints, "FAR", FAR).TAR, 'f', 3))));
  1389 +
  1390 + // Write FAR@TAR Table (FT)
  1391 + foreach (float TAR, QList<float>() << 0.8 << 0.85 << 0.9 << 0.95 << 0.98)
  1392 + lines.append(qPrintable(QString("FT,%1,%2").arg(
  1393 + QString::number(TAR, 'f', 2),
  1394 + QString::number(getOperatingPoint(operatingPoints, "TAR", TAR).FAR, 'f', 3))));
  1395 +
  1396 + // Write FAR@Score Table (SF) and TAR@Score table (ST)
  1397 + foreach(const float score, QList<float>() << 0.05 << 0.1 << 0.15 << 0.2 << 0.25 << 0.3 << 0.35 << 0.4 << 0.45 << 0.5
  1398 + << 0.55 << 0.6 << 0.65 << 0.7 << 0.75 << 0.8 << 0.85 << 0.9 << 0.95) {
  1399 + const OperatingPoint op = getOperatingPoint(operatingPoints, "Score", score);
  1400 + lines.append(qPrintable(QString("SF,%1,%2").arg(
  1401 + QString::number(score, 'f', 2),
  1402 + QString::number(op.FAR))));
  1403 + lines.append(qPrintable(QString("ST,%1,%2").arg(
  1404 + QString::number(score, 'f', 2),
  1405 + QString::number(op.TAR))));
  1406 + }
  1407 +
  1408 + // Write FAR/TAR Bar Chart (BC)
  1409 + lines.append(qPrintable(QString("BC,0.0001,%1").arg(QString::number(getOperatingPoint(operatingPoints, "FAR", 0.0001).TAR, 'f', 3))));
  1410 + lines.append(qPrintable(QString("BC,0.001,%1").arg(QString::number(getOperatingPoint(operatingPoints, "FAR", 0.001).TAR, 'f', 3))));
  1411 +
  1412 + // Write SD & KDE
  1413 + int points = qMin(qMin(Max_Points, classZeroScores.size()), classOneScores.size());
  1414 + if (points > 1) {
  1415 + for (int i=0; i<points; i++) {
  1416 + float classZeroScore = classZeroScores[double(i) / double(points-1) * double(classZeroScores.size()-1)];
  1417 + float classOneScore = classOneScores[double(i) / double(points-1) * double(classOneScores.size()-1)];
  1418 + if (classZeroScore == -std::numeric_limits<float>::max()) classZeroScore = minClassZeroScore;
  1419 + if (classOneScore == -std::numeric_limits<float>::max()) classOneScore = minClassOneScore;
  1420 + lines.append(QString("SD,%1,Genuine").arg(QString::number(classOneScore)));
  1421 + lines.append(QString("SD,%1,Impostor").arg(QString::number(classZeroScore)));
  1422 + }
1345 1423 }
1346 1424  
1347   - QStringList rSource;
1348   - rSource << "# Load libraries" << "library(ggplot2)" << "" << "# Set Data"
1349   - << "FAR <- c(" + farValues.join(",") + ")"
1350   - << "TAR <- c(" + tarValues.join(",") + ")"
1351   - << "data <- data.frame(FAR, TAR)"
1352   - << "" << "# Construct Plot" << "pdf(\"" + pdf + "\")"
1353   - << "print(qplot(FAR, TAR, data=data, geom=\"line\") + scale_x_log10() + theme_minimal())"
1354   - << "dev.off()";
1355   -
1356   - QString rFile = "EvalEER.R";
1357   - QtUtils::writeFile(rFile, rSource);
1358   - QtUtils::runRScript(rFile);
  1425 + QtUtils::writeFile(csv, lines);
1359 1426 }
1360 1427 }
1361 1428  
... ...
openbr/core/eval.h
... ... @@ -34,7 +34,7 @@ namespace br
34 34 float EvalLandmarking(const QString &predictedGallery, const QString &truthGallery, const QString &csv = "", int normalizationIndexA = 0, int normalizationIndexB = 1, int sampleIndex = 0, int totalExamples = 5); // Return average error
35 35 void EvalRegression(const QString &predictedGallery, const QString &truthGallery, QString predictedProperty = "", QString truthProperty = "");
36 36 void EvalKNN(const QString &knnGraph, const QString &knnTruth, const QString &csv = "");
37   - void EvalEER(const QString &predictedXML, const QString gt_property = "", const QString distribution_property = "", const QString &pdf = "");
  37 + void EvalEER(const QString &predictedXML, const QString gt_property = "", const QString distribution_property = "", const QString &csv = "");
38 38 struct Candidate
39 39 {
40 40 size_t index;
... ...
openbr/core/plot.cpp
... ... @@ -368,4 +368,52 @@ bool PlotKNN(const QStringList &files, const File &destination, bool show)
368 368 return p.finalize(show);
369 369 }
370 370  
  371 +// Does not work if dataset folder starts with a number
  372 +bool PlotEER(const QStringList &files, const File &destination, bool show)
  373 +{
  374 + qDebug("Plotting %d file(s) to %s", files.size(), qPrintable(destination));
  375 +
  376 + RPlot p(files, destination);
  377 + p.file.write("\nformatData()\n\n");
  378 + p.file.write(qPrintable(QString("algs <- %1\n").arg((p.major.size > 1 && p.minor.size > 1) && !(p.major.smooth || p.minor.smooth) ? QString("paste(TF$%1, TF$%2, sep=\"_\")").arg(p.major.header, p.minor.header)
  379 + : QString("TF$%1").arg(p.major.size > 1 ? p.major.header : (p.minor.header.isEmpty() ? p.major.header : p.minor.header)))));
  380 + p.file.write("algs <- algs[!duplicated(algs)]\n");
  381 + if (p.major.smooth || p.minor.smooth) {
  382 + QString groupvar = p.major.size > 1 ? p.major.header : (p.minor.header.isEmpty() ? p.major.header : p.minor.header);
  383 + foreach(const QString &data, QStringList() << "DET" << "TF" << "FT") {
  384 + p.file.write(qPrintable(QString("%1 <- summarySE(%1, measurevar=\"Y\", groupvars=c(\"%2\", \"X\"), conf.interval=confidence)"
  385 + "\n").arg(data, groupvar)));
  386 + }
  387 + p.file.write(qPrintable(QString("%1 <- summarySE(%1, measurevar=\"X\", groupvars=c(\"Error\", \"%2\", \"Y\"), conf.interval=confidence)"
  388 + "\n\n").arg("ERR", groupvar)));
  389 + }
  390 +
  391 + // Use a br::file for simple storage of plot options
  392 + QMap<QString,File> optMap;
  393 + optMap.insert("rocOptions", File(QString("[xTitle=False Accept Rate,yTitle=True Accept Rate,xLog=true,yLog=false,xLimits=(.0000001,.1)]")));
  394 + optMap.insert("detOptions", File(QString("[xTitle=False Accept Rate,yTitle=False Reject Rate,xLog=true,yLog=true,xLimits=(.0000001,.1),yLimits=(.0001,1)]")));
  395 + optMap.insert("farOptions", File(QString("[xTitle=Score,yTitle=False Accept Rate,xLog=false,yLog=true,xLabels=waiver(),yLimits=(.0000001,1)]")));
  396 + optMap.insert("frrOptions", File(QString("[xTitle=Score,yTitle=False Reject Rate,xLog=false,yLog=true,xLabels=waiver(),yLimits=(.0001,1)]")));
  397 +
  398 + foreach (const QString &key, optMap.keys()) {
  399 + const QStringList options = destination.get<QStringList>(key, QStringList());
  400 + foreach (const QString &option, options) {
  401 + QStringList words = QtUtils::parse(option, '=');
  402 + QtUtils::checkArgsSize(words[0], words, 1, 2);
  403 + optMap[key].set(words[0], words[1]);
  404 + }
  405 + }
  406 +
  407 + // Write plots
  408 + QString plot = "plot <- plotLine(lineData=%1, options=list(%2), flipY=%3)\nplot\n";
  409 + p.file.write(qPrintable(QString(plot).arg("DET", toRList(optMap["rocOptions"]), "TRUE")));
  410 + p.file.write(qPrintable(QString(plot).arg("DET", toRList(optMap["detOptions"]), "FALSE")));
  411 + p.file.write("plot <- plotSD(sdData=SD)\nplot\n");
  412 + p.file.write("plot <- plotBC(bcData=BC)\nplot\n");
  413 + p.file.write(qPrintable(QString(plot).arg("FAR", toRList(optMap["farOptions"]), "FALSE")));
  414 + p.file.write(qPrintable(QString(plot).arg("FRR", toRList(optMap["frrOptions"]), "FALSE")));
  415 +
  416 + return p.finalize(show);
  417 +}
  418 +
371 419 } // namespace br
... ...
openbr/core/plot.h
... ... @@ -29,6 +29,7 @@ namespace br
29 29 bool PlotLandmarking(const QStringList &files, const File &destination, bool show = false);
30 30 bool PlotMetadata(const QStringList &files, const QString &destination, bool show = false);
31 31 bool PlotKNN(const QStringList &files, const File &destination, bool show = false);
  32 + bool PlotEER(const QStringList &files, const File &destination, bool show = false);
32 33 }
33 34  
34 35 #endif // BR_PLOT_H
... ...
openbr/openbr.cpp
... ... @@ -227,6 +227,11 @@ bool br_plot_knn(int num_files, const char *files[], const char *destination, bo
227 227 return PlotKNN(QtUtils::toStringList(num_files, files), destination, show);
228 228 }
229 229  
  230 +bool br_plot_eer(int num_files, const char *files[], const char *destination, bool show)
  231 +{
  232 + return PlotEER(QtUtils::toStringList(num_files, files), destination, show);
  233 +}
  234 +
230 235 float br_progress()
231 236 {
232 237 return Globals->progress();
... ...
openbr/openbr.h
... ... @@ -97,6 +97,8 @@ BR_EXPORT bool br_plot_metadata(int num_files, const char *files[], const char *
97 97  
98 98 BR_EXPORT bool br_plot_knn(int num_files, const char *files[], const char *destination, bool show = false);
99 99  
  100 +BR_EXPORT bool br_plot_eer(int num_files, const char *files[], const char *destination, bool show = false);
  101 +
100 102 BR_EXPORT float br_progress();
101 103  
102 104 BR_EXPORT void br_read_pipe(const char *pipe, int *argc, char ***argv);
... ...