Commit a00c2d72c9427d4477e409e844948dd6d318d2dc

Authored by bhklein
1 parent 28362462

added ROC to evalEER

app/br/br.cpp
@@ -172,8 +172,8 @@ public: @@ -172,8 +172,8 @@ public:
172 check(parc >= 2 && parc <= 3, "Incorrect parameter count for 'evalKNN'."); 172 check(parc >= 2 && parc <= 3, "Incorrect parameter count for 'evalKNN'.");
173 br_eval_knn(parv[0], parv[1], parc > 2 ? parv[2] : ""); 173 br_eval_knn(parv[0], parv[1], parc > 2 ? parv[2] : "");
174 } else if (!strcmp(fun, "evalEER")) { 174 } else if (!strcmp(fun, "evalEER")) {
175 - check(parc >=1 && parc <=3 , "Incorrect parameter count for 'evalEER'.");  
176 - br_eval_eer(parv[0], parc > 1 ? parv[1] : "", parc > 2 ? parv[2] : ""); 175 + check(parc >=1 && parc <=4 , "Incorrect parameter count for 'evalEER'.");
  176 + br_eval_eer(parv[0], parc > 1 ? parv[1] : "", parc > 2 ? parv[2] : "", parc > 3 ? parv[3] : "");
177 } else if (!strcmp(fun, "pairwiseCompare")) { 177 } else if (!strcmp(fun, "pairwiseCompare")) {
178 check((parc >= 2) && (parc <= 3), "Incorrect parameter count for 'pairwiseCompare'."); 178 check((parc >= 2) && (parc <= 3), "Incorrect parameter count for 'pairwiseCompare'.");
179 br_pairwise_compare(parv[0], parv[1], parc == 3 ? parv[2] : ""); 179 br_pairwise_compare(parv[0], parv[1], parc == 3 ? parv[2] : "");
openbr/core/eval.cpp
@@ -1242,69 +1242,97 @@ void EvalKNN(const QString &amp;knnGraph, const QString &amp;knnTruth, const QString &amp;cs @@ -1242,69 +1242,97 @@ void EvalKNN(const QString &amp;knnGraph, const QString &amp;knnTruth, const QString &amp;cs
1242 qDebug("FNIR @ FPIR = 0.01: %.3f", 1-getOperatingPointGivenFAR(operatingPoints, 0.01).TAR); 1242 qDebug("FNIR @ FPIR = 0.01: %.3f", 1-getOperatingPointGivenFAR(operatingPoints, 0.01).TAR);
1243 } 1243 }
1244 1244
1245 -void EvalEER(const QString &predictedXML, QString gt_property, QString distribution_property){ 1245 +void EvalEER(const QString &predictedXML, QString gt_property, QString distribution_property, const QString &pdf) {
1246 if (gt_property.isEmpty()) 1246 if (gt_property.isEmpty())
1247 - gt_property = "LivenessGT"; 1247 + gt_property = "LivenessGT";
1248 if (distribution_property.isEmpty()) 1248 if (distribution_property.isEmpty())
1249 - distribution_property = "LivenessDistribution";  
1250 - double classOneTemplateCount = 0; 1249 + distribution_property = "LivenessClassScores";
  1250 + int classOneTemplateCount = 0;
1251 const TemplateList templateList(TemplateList::fromGallery(predictedXML)); 1251 const TemplateList templateList(TemplateList::fromGallery(predictedXML));
1252 1252
1253 QHash<QString, int> gtLabels; 1253 QHash<QString, int> gtLabels;
1254 QHash<QString, QList<float> > scores; 1254 QHash<QString, QList<float> > scores;
1255 - for (double i=0; i<templateList.size(); i++) {  
1256 - if (!templateList[i].file.contains(distribution_property) || !templateList[i].file.contains(gt_property))  
1257 - continue;  
1258 - QString templateKey = templateList[i].file.path() + templateList[i].file.baseName();  
1259 - int gtLabel = templateList[i].file.get<int>(gt_property);  
1260 - if (gtLabel == 1)  
1261 - classOneTemplateCount++;  
1262 - QList<float> templateScores = templateList[i].file.getList<float>(distribution_property);  
1263 - gtLabels[templateKey] = gtLabel;  
1264 - scores[templateKey] = templateScores;  
1265 - }  
1266 -  
1267 - const int numPoints = 200;  
1268 - const float stepSize = 100.0/numPoints;  
1269 - const double numTemplates = scores.size();  
1270 - float thres = 0.0; //Between [0,100]  
1271 - float thresNorm = 0.0; //Between [0,1]  
1272 - double FA = 0, FR = 0;  
1273 - float minDiff = 100;  
1274 - float EER = 100;  
1275 - float EERThres = 0;  
1276 -  
1277 - for(int i = 0; i <= numPoints; i++){  
1278 - FA = 0, FR = 0;  
1279 - thresNorm = thres/100.0;  
1280 - foreach(const QString &key, scores.keys()){  
1281 - int gtLabel = gtLabels[key];  
1282 - //> thresNorm = class 0 (spoof) : < thresNorm = class 1 (genuine)  
1283 - if (scores[key][0] >= thresNorm && gtLabel == 0)  
1284 - continue;  
1285 - else if (scores[key][0] < thresNorm && gtLabel == 1)  
1286 - continue;  
1287 - else if (scores[key][0] >= thresNorm && gtLabel == 1)  
1288 - FR +=1;  
1289 - else if (scores[key][0] < thresNorm && gtLabel == 0)  
1290 - FA +=1;  
1291 - }  
1292 - float FAR = FA / fabs(numTemplates - classOneTemplateCount);  
1293 - float FRR = FR / float(classOneTemplateCount);  
1294 -  
1295 - float diff = std::abs(FAR-FRR);  
1296 - if (diff < minDiff){  
1297 - minDiff = diff;  
1298 - EER = (FAR+FRR)/2.0;  
1299 - EERThres = thresNorm;  
1300 - }  
1301 - thres += stepSize;  
1302 - } 1255 + for (int i=0; i<templateList.size(); i++) {
  1256 + if (!templateList[i].file.contains(distribution_property) || !templateList[i].file.contains(gt_property))
  1257 + continue;
  1258 + QString templateKey = templateList[i].file.path() + templateList[i].file.baseName();
  1259 + const int gtLabel = templateList[i].file.get<int>(gt_property);
  1260 + if (gtLabel == 1)
  1261 + classOneTemplateCount++;
  1262 + const QList<float> templateScores = templateList[i].file.getList<float>(distribution_property);
  1263 + gtLabels[templateKey] = gtLabel;
  1264 + scores[templateKey] = templateScores;
  1265 + }
1303 1266
1304 - qDebug() <<"Class 0 Templates:" << fabs(numTemplates - classOneTemplateCount) << "Class 1 Templates:"  
1305 - << classOneTemplateCount << "Total Templates:" << numTemplates;  
1306 - qDebug("EER: %.3f @ Threshold %.3f", EER*100, EERThres);  
1307 -} 1267 + const int numPoints = 200;
  1268 + const float stepSize = 100.0/numPoints;
  1269 + const int numTemplates = scores.size();
  1270 + float thres = 0.0; //Between [0,100]
  1271 + float thresNorm = 0.0; //Between [0,1]
  1272 + float minDiff = 100, EER = 100, EERThres = 0;
  1273 + QList<OperatingPoint> operatingPoints;
  1274 +
  1275 + for(int i = 0; i <= numPoints; i++) {
  1276 + int FA = 0, FR = 0;
  1277 + thresNorm = thres/100.0;
  1278 + foreach(const QString &key, scores.keys()) {
  1279 + int gtLabel = gtLabels[key];
  1280 + //> thresNorm = class 0 (spoof) : < thresNorm = class 1 (genuine)
  1281 + if (scores[key][0] >= thresNorm && gtLabel == 0)
  1282 + continue;
  1283 + else if (scores[key][0] < thresNorm && gtLabel == 1)
  1284 + continue;
  1285 + else if (scores[key][0] >= thresNorm && gtLabel == 1)
  1286 + FR +=1;
  1287 + else if (scores[key][0] < thresNorm && gtLabel == 0)
  1288 + FA +=1;
  1289 + }
  1290 + const float FAR = FA / float(numTemplates - classOneTemplateCount);
  1291 + const float FRR = FR / float(classOneTemplateCount);
  1292 + operatingPoints.append(OperatingPoint(thresNorm, FAR, 1-FRR));
  1293 +
  1294 + const float diff = std::abs(FAR-FRR);
  1295 + if (diff < minDiff) {
  1296 + minDiff = diff;
  1297 + EER = (FAR+FRR)/2.0;
  1298 + EERThres = thresNorm;
  1299 + }
  1300 + thres += stepSize;
  1301 + }
  1302 +
  1303 + printf("Class 0 Templates: %d\tClass 1 Templates: %d\tTotal Templates: %d\n",
  1304 + numTemplates-classOneTemplateCount, classOneTemplateCount, numTemplates);
  1305 + foreach (float FAR, QList<float>() << 0.1 << 0.01 << 0.001 << 0.0001) {
  1306 + const OperatingPoint op = getOperatingPointGivenFAR(operatingPoints, FAR);
  1307 + printf("TAR & Score @ FAR = %.0e: %.3f %.3f\n", FAR, op.TAR, op.score);
  1308 + }
  1309 + printf("EER: %.3f @ Threshold %.3f\n", EER*100, EERThres);
  1310 +
  1311 + // Optionally write ROC curve
  1312 + if (!pdf.isEmpty()) {
  1313 + QStringList farValues, tarValues;
  1314 + float expFAR = std::max(ceil(log10(numTemplates - classOneTemplateCount)), 1.0);
  1315 + float FARstep = expFAR / (float)(Max_Points - 1);
  1316 + for (int i=0; i<Max_Points; i++) {
  1317 + float FAR = pow(10, -expFAR + i*FARstep);
  1318 + OperatingPoint op = getOperatingPointGivenFAR(operatingPoints, FAR);
  1319 + farValues.append(QString::number(FAR));
  1320 + tarValues.append(QString::number(op.TAR));
  1321 + }
1308 1322
  1323 + QStringList rSource;
  1324 + rSource << "# Load libraries" << "library(ggplot2)" << "" << "# Set Data"
  1325 + << "FAR <- c(" + farValues.join(",") + ")"
  1326 + << "TAR <- c(" + tarValues.join(",") + ")"
  1327 + << "data <- data.frame(FAR, TAR)"
  1328 + << "" << "# Construct Plot" << "pdf(\"" + pdf + "\")"
  1329 + << "print(qplot(FAR, TAR, data=data, geom=\"line\") + scale_x_log10() + theme_minimal())"
  1330 + << "dev.off()";
  1331 +
  1332 + QString rFile = "EvalEER.R";
  1333 + QtUtils::writeFile(rFile, rSource);
  1334 + QtUtils::runRScript(rFile);
  1335 + }
  1336 +}
1309 1337
1310 } // namespace br 1338 } // namespace br
openbr/core/eval.h
@@ -34,7 +34,7 @@ namespace br @@ -34,7 +34,7 @@ namespace br
34 float EvalLandmarking(const QString &predictedGallery, const QString &truthGallery, const QString &csv = "", int normalizationIndexA = 0, int normalizationIndexB = 1, int sampleIndex = 0, int totalExamples = 5); // Return average error 34 float EvalLandmarking(const QString &predictedGallery, const QString &truthGallery, const QString &csv = "", int normalizationIndexA = 0, int normalizationIndexB = 1, int sampleIndex = 0, int totalExamples = 5); // Return average error
35 void EvalRegression(const QString &predictedGallery, const QString &truthGallery, QString predictedProperty = "", QString truthProperty = ""); 35 void EvalRegression(const QString &predictedGallery, const QString &truthGallery, QString predictedProperty = "", QString truthProperty = "");
36 void EvalKNN(const QString &knnGraph, const QString &knnTruth, const QString &csv = ""); 36 void EvalKNN(const QString &knnGraph, const QString &knnTruth, const QString &csv = "");
37 - void EvalEER(const QString &predictedXML, const QString gt_property = "", const QString distribution_property = ""); 37 + void EvalEER(const QString &predictedXML, const QString gt_property = "", const QString distribution_property = "", const QString &pdf = "");
38 struct Candidate 38 struct Candidate
39 { 39 {
40 size_t index; 40 size_t index;
openbr/openbr.cpp
@@ -146,9 +146,9 @@ void br_eval_knn(const char *knnGraph, const char *knnTruth, const char *csv) @@ -146,9 +146,9 @@ void br_eval_knn(const char *knnGraph, const char *knnTruth, const char *csv)
146 EvalKNN(knnGraph, knnTruth, csv); 146 EvalKNN(knnGraph, knnTruth, csv);
147 } 147 }
148 148
149 -void br_eval_eer(const char *predicted_xml, const char *gt_property, const char *distribution_property ) 149 +void br_eval_eer(const char *predicted_xml, const char *gt_property, const char *distribution_property, const char *pdf)
150 { 150 {
151 - EvalEER(predicted_xml, gt_property, distribution_property); 151 + EvalEER(predicted_xml, gt_property, distribution_property, pdf);
152 } 152 }
153 153
154 void br_finalize() 154 void br_finalize()
openbr/openbr.h
@@ -66,7 +66,7 @@ BR_EXPORT void br_eval_regression(const char *predicted_gallery, const char *tru @@ -66,7 +66,7 @@ BR_EXPORT void br_eval_regression(const char *predicted_gallery, const char *tru
66 66
67 BR_EXPORT void br_eval_knn(const char *knnGraph, const char *knnTruth, const char *csv = ""); 67 BR_EXPORT void br_eval_knn(const char *knnGraph, const char *knnTruth, const char *csv = "");
68 68
69 -BR_EXPORT void br_eval_eer(const char *predicted_xml, const char *gt_property = "", const char *distribution_property = ""); 69 +BR_EXPORT void br_eval_eer(const char *predicted_xml, const char *gt_property = "", const char *distribution_property = "", const char *pdf = "");
70 70
71 BR_EXPORT void br_finalize(); 71 BR_EXPORT void br_finalize();
72 72