From be9fd90a2828e7cd85a37e1f667e8df0e5c38a9d Mon Sep 17 00:00:00 2001 From: Jordan Cheney Date: Mon, 25 Jan 2021 16:18:05 +0000 Subject: [PATCH] This PR makes 2 changes to evalEER and adds a new, corresponding, plotEER method. The C API and command line tool are updated to support the new function. --- app/br/br.cpp | 4 ++++ openbr/core/eval.cpp | 185 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------------------------------------------------- openbr/core/eval.h | 2 +- openbr/core/plot.cpp | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ openbr/core/plot.h | 1 + openbr/openbr.cpp | 5 +++++ openbr/openbr.h | 2 ++ 7 files changed, 187 insertions(+), 60 deletions(-) diff --git a/app/br/br.cpp b/app/br/br.cpp index 56c56c6..da50621 100644 --- a/app/br/br.cpp +++ b/app/br/br.cpp @@ -192,6 +192,9 @@ public: } else if (!strcmp(fun, "plotKNN")) { check(parc >=2, "Incorrect parameter count for 'plotKNN'."); br_plot_knn(parc-1, parv, parv[parc-1], true); + } else if (!strcmp(fun, "plotEER")) { + check(parc >= 2, "Incorrect parameter count for 'plotEER'."); + br_plot_eer(parc-1, parv, parv[parc-1], true); } else if (!strcmp(fun, "project")) { check(parc == 2, "Insufficient parameter count for 'project'."); br_project(parv[0], parv[1]); @@ -298,6 +301,7 @@ private: "-plotLandmarking ... {destination}\n" "-plotMetadata ... \n" "-plotKNN ... {destination}\n" + "-plotEER ... {destination}\n" "-project {output_gallery}\n" "-deduplicate \n" "-likely \n" diff --git a/openbr/core/eval.cpp b/openbr/core/eval.cpp index e3aa127..acf0cd4 100755 --- a/openbr/core/eval.cpp +++ b/openbr/core/eval.cpp @@ -377,8 +377,8 @@ float Evaluate(const Mat &simmat, const Mat &mask, const File &csv, const QStrin // Write TAR@FAR Table (TF) foreach (float FAR, QList() << 1e-6 << 1e-5 << 1e-4 << 1e-3 << 1e-2 << 1e-1) lines.append(qPrintable(QString("TF,%1,%2").arg( - QString::number(FAR, 'f'), - QString::number(getOperatingPoint(operatingPoints, "FAR", FAR).TAR, 'f', 3)))); + QString::number(FAR, 'f'), + QString::number(getOperatingPoint(operatingPoints, "FAR", FAR).TAR, 'f', 3)))); // Write FAR@TAR Table (FT) foreach (float TAR, QList() << 0.4 << 0.5 << 0.65 << 0.75 << 0.85 << 0.95) @@ -1255,7 +1255,7 @@ void EvalKNN(const QString &knnGraph, const QString &knnTruth, const QString &cs qDebug("FNIR @ FPIR = 0.01: %.3f", 1-getOperatingPoint(operatingPoints, "FAR", 0.01).TAR); } -void EvalEER(const QString &predictedXML, QString gt_property, QString distribution_property, const QString &pdf) { +void EvalEER(const QString &predictedXML, QString gt_property, QString distribution_property, const QString &csv) { if (gt_property.isEmpty()) gt_property = "LivenessGT"; if (distribution_property.isEmpty()) @@ -1263,59 +1263,79 @@ void EvalEER(const QString &predictedXML, QString gt_property, QString distribut int classOneTemplateCount = 0; const TemplateList templateList(TemplateList::fromGallery(predictedXML)); - QHash gtLabels; - QHash scores; + QList> scores; + QList classZeroScores, classOneScores; for (int i=0; i(gt_property); - if (gtLabel == 1) + const float templateScore = templateList[i].file.get(distribution_property); + scores.append(qMakePair(templateScore, gtLabel)); + + if (gtLabel == 1) { classOneTemplateCount++; - const float templateScores = templateList[i].file.get(distribution_property); - gtLabels[templateKey] = gtLabel; - scores[templateKey] = templateScores; + classOneScores.append(templateScore); + } else { + classZeroScores.append(templateScore); + } } - const int numPoints = 200; - const float stepSize = 100.0/numPoints; - const int numTemplates = scores.size(); - float thres = 0.0; //Between [0,100] - float thresNorm = 0.0; //Between [0,1] - float minDiff = 100, EER = 100, EERThres = 0; + std::sort(scores.begin(), scores.end()); + QList operatingPoints; - for(int i = 0; i <= numPoints; i++) { - int FA = 0, FR = 0; - thresNorm = thres/100.0; - foreach(const QString &key, scores.keys()) { - int gtLabel = gtLabels[key]; - //> thresNorm = class 0 (spoof) : < thresNorm = class 1 (genuine) - if (scores[key] >= thresNorm && gtLabel == 0) - continue; - else if (scores[key] < thresNorm && gtLabel == 1) - continue; - else if (scores[key] >= thresNorm && gtLabel == 1) - FR +=1; - else if (scores[key] < thresNorm && gtLabel == 0) - FA +=1; + const int classZeroTemplateCount = scores.size() - classOneTemplateCount; + int falsePositives = 0, previousFalsePositives = 0; + int truePositives = 0, previousTruePositives = 0; + size_t index = 0; + float minDiff = 100, EER = 100, EERThres = 0; + float minClassOneScore = std::numeric_limits::max(); + float minClassZeroScore = std::numeric_limits::max(); + + while (index < scores.size()) { + float thresh = scores[index].first; + // Compute genuine and imposter statistics at a threshold + while ((index < scores.size()) && + (scores[index].first == thresh)) { + if (scores[index].second) { + truePositives++; + if (scores[index].first != -std::numeric_limits::max() && scores[index].first < minClassOneScore) + minClassOneScore = scores[index].first; + } else { + falsePositives++; + if (scores[index].first != -std::numeric_limits::max() && scores[index].first < minClassZeroScore) + minClassZeroScore = scores[index].first; + } + index++; } - const float FAR = FA / float(numTemplates - classOneTemplateCount); - const float FRR = FR / float(classOneTemplateCount); - operatingPoints.append(OperatingPoint(thresNorm, FAR, 1-FRR)); - - const float diff = std::abs(FAR-FRR); - if (diff < minDiff) { - minDiff = diff; - EER = (FAR+FRR)/2.0; - EERThres = thresNorm; + + if ((falsePositives > previousFalsePositives) && + (truePositives > previousTruePositives)) { + const float FAR = float(falsePositives) / classZeroTemplateCount; + const float TAR = float(truePositives) / classOneTemplateCount; + const float FRR = 1 - TAR; + operatingPoints.append(OperatingPoint(thresh, FAR, TAR)); + + const float diff = std::abs(FAR-FRR); + if (diff < minDiff) { + minDiff = diff; + EER = (FAR+FRR)/2.0; + EERThres = thresh; + } + + previousFalsePositives = falsePositives; + previousTruePositives = truePositives; } - thres += stepSize; } + if (operatingPoints.size() == 0) operatingPoints.append(OperatingPoint(1, 1, 1)); + if (operatingPoints.size() == 1) operatingPoints.prepend(OperatingPoint(0, 0, 0)); + if (operatingPoints.size() > 2) operatingPoints.takeLast(); // Remove point (1,1) + printf("\n==========================================================\n"); printf("Class 0 Templates: %d\tClass 1 Templates: %d\tTotal Templates: %d\n", - numTemplates-classOneTemplateCount, classOneTemplateCount, numTemplates); + classZeroTemplateCount, classOneTemplateCount, classZeroTemplateCount + classOneTemplateCount); printf("----------------------------------------------------------\n"); foreach (float FAR, QList() << 0.2 << 0.1 << 0.05 << 0.01 << 0.001 << 0.0001) { const OperatingPoint op = getOperatingPoint(operatingPoints, "FAR", FAR); @@ -1333,29 +1353,76 @@ void EvalEER(const QString &predictedXML, QString gt_property, QString distribut printf("==========================================================\n\n"); // Optionally write ROC curve - if (!pdf.isEmpty()) { - QStringList farValues, tarValues; - float expFAR = std::max(ceil(log10(numTemplates - classOneTemplateCount)), 1.0); + if (!csv.isEmpty()) { + QStringList lines; + lines.append("Plot,X,Y"); + lines.append("Metadata,"+QString::number(classZeroTemplateCount+classOneTemplateCount)+",Total Templates"); + lines.append("Metadata,"+QString::number(classZeroTemplateCount)+",Class 0 Template Count"); + lines.append("Metadata,"+QString::number(classOneTemplateCount)+",Class 1 Template Count"); + + // Write Detection Error Tradeoff (DET), PRE, REC + float expFAR = std::max(ceil(log10(classZeroTemplateCount)), 1.0); + float expFRR = std::max(ceil(log10(classOneTemplateCount)), 1.0); + float FARstep = expFAR / (float)(Max_Points - 1); + float FRRstep = expFRR / (float)(Max_Points - 1); + for (int i=0; i() << 0.2 << 0.1 << 0.05 << 0.01 << 0.001 << 0.0001) + lines.append(qPrintable(QString("TF,%1,%2").arg( + QString::number(FAR, 'f'), + QString::number(getOperatingPoint(operatingPoints, "FAR", FAR).TAR, 'f', 3)))); + + // Write FAR@TAR Table (FT) + foreach (float TAR, QList() << 0.8 << 0.85 << 0.9 << 0.95 << 0.98) + lines.append(qPrintable(QString("FT,%1,%2").arg( + QString::number(TAR, 'f', 2), + QString::number(getOperatingPoint(operatingPoints, "TAR", TAR).FAR, 'f', 3)))); + + // Write FAR@Score Table (SF) and TAR@Score table (ST) + foreach(const float score, QList() << 0.05 << 0.1 << 0.15 << 0.2 << 0.25 << 0.3 << 0.35 << 0.4 << 0.45 << 0.5 + << 0.55 << 0.6 << 0.65 << 0.7 << 0.75 << 0.8 << 0.85 << 0.9 << 0.95) { + const OperatingPoint op = getOperatingPoint(operatingPoints, "Score", score); + lines.append(qPrintable(QString("SF,%1,%2").arg( + QString::number(score, 'f', 2), + QString::number(op.FAR)))); + lines.append(qPrintable(QString("ST,%1,%2").arg( + QString::number(score, 'f', 2), + QString::number(op.TAR)))); + } + + // Write FAR/TAR Bar Chart (BC) + lines.append(qPrintable(QString("BC,0.0001,%1").arg(QString::number(getOperatingPoint(operatingPoints, "FAR", 0.0001).TAR, 'f', 3)))); + lines.append(qPrintable(QString("BC,0.001,%1").arg(QString::number(getOperatingPoint(operatingPoints, "FAR", 0.001).TAR, 'f', 3)))); + + // Write SD & KDE + int points = qMin(qMin(Max_Points, classZeroScores.size()), classOneScores.size()); + if (points > 1) { + for (int i=0; i::max()) classZeroScore = minClassZeroScore; + if (classOneScore == -std::numeric_limits::max()) classOneScore = minClassOneScore; + lines.append(QString("SD,%1,Genuine").arg(QString::number(classOneScore))); + lines.append(QString("SD,%1,Impostor").arg(QString::number(classZeroScore))); + } } - QStringList rSource; - rSource << "# Load libraries" << "library(ggplot2)" << "" << "# Set Data" - << "FAR <- c(" + farValues.join(",") + ")" - << "TAR <- c(" + tarValues.join(",") + ")" - << "data <- data.frame(FAR, TAR)" - << "" << "# Construct Plot" << "pdf(\"" + pdf + "\")" - << "print(qplot(FAR, TAR, data=data, geom=\"line\") + scale_x_log10() + theme_minimal())" - << "dev.off()"; - - QString rFile = "EvalEER.R"; - QtUtils::writeFile(rFile, rSource); - QtUtils::runRScript(rFile); + QtUtils::writeFile(csv, lines); } } diff --git a/openbr/core/eval.h b/openbr/core/eval.h index 52fc404..927a45d 100644 --- a/openbr/core/eval.h +++ b/openbr/core/eval.h @@ -34,7 +34,7 @@ namespace br float EvalLandmarking(const QString &predictedGallery, const QString &truthGallery, const QString &csv = "", int normalizationIndexA = 0, int normalizationIndexB = 1, int sampleIndex = 0, int totalExamples = 5); // Return average error void EvalRegression(const QString &predictedGallery, const QString &truthGallery, QString predictedProperty = "", QString truthProperty = ""); void EvalKNN(const QString &knnGraph, const QString &knnTruth, const QString &csv = ""); - void EvalEER(const QString &predictedXML, const QString gt_property = "", const QString distribution_property = "", const QString &pdf = ""); + void EvalEER(const QString &predictedXML, const QString gt_property = "", const QString distribution_property = "", const QString &csv = ""); struct Candidate { size_t index; diff --git a/openbr/core/plot.cpp b/openbr/core/plot.cpp index 40fdc44..6fa5bb7 100644 --- a/openbr/core/plot.cpp +++ b/openbr/core/plot.cpp @@ -368,4 +368,52 @@ bool PlotKNN(const QStringList &files, const File &destination, bool show) return p.finalize(show); } +// Does not work if dataset folder starts with a number +bool PlotEER(const QStringList &files, const File &destination, bool show) +{ + qDebug("Plotting %d file(s) to %s", files.size(), qPrintable(destination)); + + RPlot p(files, destination); + p.file.write("\nformatData()\n\n"); + p.file.write(qPrintable(QString("algs <- %1\n").arg((p.major.size > 1 && p.minor.size > 1) && !(p.major.smooth || p.minor.smooth) ? QString("paste(TF$%1, TF$%2, sep=\"_\")").arg(p.major.header, p.minor.header) + : QString("TF$%1").arg(p.major.size > 1 ? p.major.header : (p.minor.header.isEmpty() ? p.major.header : p.minor.header))))); + p.file.write("algs <- algs[!duplicated(algs)]\n"); + if (p.major.smooth || p.minor.smooth) { + QString groupvar = p.major.size > 1 ? p.major.header : (p.minor.header.isEmpty() ? p.major.header : p.minor.header); + foreach(const QString &data, QStringList() << "DET" << "TF" << "FT") { + p.file.write(qPrintable(QString("%1 <- summarySE(%1, measurevar=\"Y\", groupvars=c(\"%2\", \"X\"), conf.interval=confidence)" + "\n").arg(data, groupvar))); + } + p.file.write(qPrintable(QString("%1 <- summarySE(%1, measurevar=\"X\", groupvars=c(\"Error\", \"%2\", \"Y\"), conf.interval=confidence)" + "\n\n").arg("ERR", groupvar))); + } + + // Use a br::file for simple storage of plot options + QMap optMap; + optMap.insert("rocOptions", File(QString("[xTitle=False Accept Rate,yTitle=True Accept Rate,xLog=true,yLog=false,xLimits=(.0000001,.1)]"))); + optMap.insert("detOptions", File(QString("[xTitle=False Accept Rate,yTitle=False Reject Rate,xLog=true,yLog=true,xLimits=(.0000001,.1),yLimits=(.0001,1)]"))); + optMap.insert("farOptions", File(QString("[xTitle=Score,yTitle=False Accept Rate,xLog=false,yLog=true,xLabels=waiver(),yLimits=(.0000001,1)]"))); + optMap.insert("frrOptions", File(QString("[xTitle=Score,yTitle=False Reject Rate,xLog=false,yLog=true,xLabels=waiver(),yLimits=(.0001,1)]"))); + + foreach (const QString &key, optMap.keys()) { + const QStringList options = destination.get(key, QStringList()); + foreach (const QString &option, options) { + QStringList words = QtUtils::parse(option, '='); + QtUtils::checkArgsSize(words[0], words, 1, 2); + optMap[key].set(words[0], words[1]); + } + } + + // Write plots + QString plot = "plot <- plotLine(lineData=%1, options=list(%2), flipY=%3)\nplot\n"; + p.file.write(qPrintable(QString(plot).arg("DET", toRList(optMap["rocOptions"]), "TRUE"))); + p.file.write(qPrintable(QString(plot).arg("DET", toRList(optMap["detOptions"]), "FALSE"))); + p.file.write("plot <- plotSD(sdData=SD)\nplot\n"); + p.file.write("plot <- plotBC(bcData=BC)\nplot\n"); + p.file.write(qPrintable(QString(plot).arg("FAR", toRList(optMap["farOptions"]), "FALSE"))); + p.file.write(qPrintable(QString(plot).arg("FRR", toRList(optMap["frrOptions"]), "FALSE"))); + + return p.finalize(show); +} + } // namespace br diff --git a/openbr/core/plot.h b/openbr/core/plot.h index 26db428..def7021 100644 --- a/openbr/core/plot.h +++ b/openbr/core/plot.h @@ -29,6 +29,7 @@ namespace br bool PlotLandmarking(const QStringList &files, const File &destination, bool show = false); bool PlotMetadata(const QStringList &files, const QString &destination, bool show = false); bool PlotKNN(const QStringList &files, const File &destination, bool show = false); + bool PlotEER(const QStringList &files, const File &destination, bool show = false); } #endif // BR_PLOT_H diff --git a/openbr/openbr.cpp b/openbr/openbr.cpp index 0de41b6..7238103 100644 --- a/openbr/openbr.cpp +++ b/openbr/openbr.cpp @@ -227,6 +227,11 @@ bool br_plot_knn(int num_files, const char *files[], const char *destination, bo return PlotKNN(QtUtils::toStringList(num_files, files), destination, show); } +bool br_plot_eer(int num_files, const char *files[], const char *destination, bool show) +{ + return PlotEER(QtUtils::toStringList(num_files, files), destination, show); +} + float br_progress() { return Globals->progress(); diff --git a/openbr/openbr.h b/openbr/openbr.h index 4db4e46..b46a880 100644 --- a/openbr/openbr.h +++ b/openbr/openbr.h @@ -97,6 +97,8 @@ BR_EXPORT bool br_plot_metadata(int num_files, const char *files[], const char * BR_EXPORT bool br_plot_knn(int num_files, const char *files[], const char *destination, bool show = false); +BR_EXPORT bool br_plot_eer(int num_files, const char *files[], const char *destination, bool show = false); + BR_EXPORT float br_progress(); BR_EXPORT void br_read_pipe(const char *pipe, int *argc, char ***argv); -- libgit2 0.21.4