From a00c2d72c9427d4477e409e844948dd6d318d2dc Mon Sep 17 00:00:00 2001 From: bhklein Date: Mon, 30 Apr 2018 15:47:21 -0400 Subject: [PATCH] added ROC to evalEER --- app/br/br.cpp | 4 ++-- openbr/core/eval.cpp | 140 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------------------------------------------------- openbr/core/eval.h | 2 +- openbr/openbr.cpp | 4 ++-- openbr/openbr.h | 2 +- 5 files changed, 90 insertions(+), 62 deletions(-) diff --git a/app/br/br.cpp b/app/br/br.cpp index 89cd232..89ed33e 100644 --- a/app/br/br.cpp +++ b/app/br/br.cpp @@ -172,8 +172,8 @@ public: check(parc >= 2 && parc <= 3, "Incorrect parameter count for 'evalKNN'."); br_eval_knn(parv[0], parv[1], parc > 2 ? parv[2] : ""); } else if (!strcmp(fun, "evalEER")) { - check(parc >=1 && parc <=3 , "Incorrect parameter count for 'evalEER'."); - br_eval_eer(parv[0], parc > 1 ? parv[1] : "", parc > 2 ? parv[2] : ""); + check(parc >=1 && parc <=4 , "Incorrect parameter count for 'evalEER'."); + br_eval_eer(parv[0], parc > 1 ? parv[1] : "", parc > 2 ? parv[2] : "", parc > 3 ? parv[3] : ""); } else if (!strcmp(fun, "pairwiseCompare")) { check((parc >= 2) && (parc <= 3), "Incorrect parameter count for 'pairwiseCompare'."); br_pairwise_compare(parv[0], parv[1], parc == 3 ? parv[2] : ""); diff --git a/openbr/core/eval.cpp b/openbr/core/eval.cpp index a8b7199..95de62c 100755 --- a/openbr/core/eval.cpp +++ b/openbr/core/eval.cpp @@ -1242,69 +1242,97 @@ void EvalKNN(const QString &knnGraph, const QString &knnTruth, const QString &cs qDebug("FNIR @ FPIR = 0.01: %.3f", 1-getOperatingPointGivenFAR(operatingPoints, 0.01).TAR); } -void EvalEER(const QString &predictedXML, QString gt_property, QString distribution_property){ +void EvalEER(const QString &predictedXML, QString gt_property, QString distribution_property, const QString &pdf) { if (gt_property.isEmpty()) - gt_property = "LivenessGT"; + gt_property = "LivenessGT"; if (distribution_property.isEmpty()) - distribution_property = "LivenessDistribution"; - double classOneTemplateCount = 0; + distribution_property = "LivenessClassScores"; + int classOneTemplateCount = 0; const TemplateList templateList(TemplateList::fromGallery(predictedXML)); QHash gtLabels; QHash > scores; - for (double i=0; i(gt_property); - if (gtLabel == 1) - classOneTemplateCount++; - QList templateScores = templateList[i].file.getList(distribution_property); - gtLabels[templateKey] = gtLabel; - scores[templateKey] = templateScores; - } - - const int numPoints = 200; - const float stepSize = 100.0/numPoints; - const double numTemplates = scores.size(); - float thres = 0.0; //Between [0,100] - float thresNorm = 0.0; //Between [0,1] - double FA = 0, FR = 0; - float minDiff = 100; - float EER = 100; - float EERThres = 0; - - for(int i = 0; i <= numPoints; i++){ - FA = 0, FR = 0; - thresNorm = thres/100.0; - foreach(const QString &key, scores.keys()){ - int gtLabel = gtLabels[key]; - //> thresNorm = class 0 (spoof) : < thresNorm = class 1 (genuine) - if (scores[key][0] >= thresNorm && gtLabel == 0) - continue; - else if (scores[key][0] < thresNorm && gtLabel == 1) - continue; - else if (scores[key][0] >= thresNorm && gtLabel == 1) - FR +=1; - else if (scores[key][0] < thresNorm && gtLabel == 0) - FA +=1; - } - float FAR = FA / fabs(numTemplates - classOneTemplateCount); - float FRR = FR / float(classOneTemplateCount); - - float diff = std::abs(FAR-FRR); - if (diff < minDiff){ - minDiff = diff; - EER = (FAR+FRR)/2.0; - EERThres = thresNorm; - } - thres += stepSize; - } + for (int i=0; i(gt_property); + if (gtLabel == 1) + classOneTemplateCount++; + const QList templateScores = templateList[i].file.getList(distribution_property); + gtLabels[templateKey] = gtLabel; + scores[templateKey] = templateScores; + } - qDebug() <<"Class 0 Templates:" << fabs(numTemplates - classOneTemplateCount) << "Class 1 Templates:" - << classOneTemplateCount << "Total Templates:" << numTemplates; - qDebug("EER: %.3f @ Threshold %.3f", EER*100, EERThres); -} + const int numPoints = 200; + const float stepSize = 100.0/numPoints; + const int numTemplates = scores.size(); + float thres = 0.0; //Between [0,100] + float thresNorm = 0.0; //Between [0,1] + float minDiff = 100, EER = 100, EERThres = 0; + QList operatingPoints; + + for(int i = 0; i <= numPoints; i++) { + int FA = 0, FR = 0; + thresNorm = thres/100.0; + foreach(const QString &key, scores.keys()) { + int gtLabel = gtLabels[key]; + //> thresNorm = class 0 (spoof) : < thresNorm = class 1 (genuine) + if (scores[key][0] >= thresNorm && gtLabel == 0) + continue; + else if (scores[key][0] < thresNorm && gtLabel == 1) + continue; + else if (scores[key][0] >= thresNorm && gtLabel == 1) + FR +=1; + else if (scores[key][0] < thresNorm && gtLabel == 0) + FA +=1; + } + const float FAR = FA / float(numTemplates - classOneTemplateCount); + const float FRR = FR / float(classOneTemplateCount); + operatingPoints.append(OperatingPoint(thresNorm, FAR, 1-FRR)); + + const float diff = std::abs(FAR-FRR); + if (diff < minDiff) { + minDiff = diff; + EER = (FAR+FRR)/2.0; + EERThres = thresNorm; + } + thres += stepSize; + } + + printf("Class 0 Templates: %d\tClass 1 Templates: %d\tTotal Templates: %d\n", + numTemplates-classOneTemplateCount, classOneTemplateCount, numTemplates); + foreach (float FAR, QList() << 0.1 << 0.01 << 0.001 << 0.0001) { + const OperatingPoint op = getOperatingPointGivenFAR(operatingPoints, FAR); + printf("TAR & Score @ FAR = %.0e: %.3f %.3f\n", FAR, op.TAR, op.score); + } + printf("EER: %.3f @ Threshold %.3f\n", EER*100, EERThres); + + // Optionally write ROC curve + if (!pdf.isEmpty()) { + QStringList farValues, tarValues; + float expFAR = std::max(ceil(log10(numTemplates - classOneTemplateCount)), 1.0); + float FARstep = expFAR / (float)(Max_Points - 1); + for (int i=0; i