Commit a00c2d72c9427d4477e409e844948dd6d318d2dc
1 parent
28362462
added ROC to evalEER
Showing
5 changed files
with
90 additions
and
62 deletions
app/br/br.cpp
| @@ -172,8 +172,8 @@ public: | @@ -172,8 +172,8 @@ public: | ||
| 172 | check(parc >= 2 && parc <= 3, "Incorrect parameter count for 'evalKNN'."); | 172 | check(parc >= 2 && parc <= 3, "Incorrect parameter count for 'evalKNN'."); |
| 173 | br_eval_knn(parv[0], parv[1], parc > 2 ? parv[2] : ""); | 173 | br_eval_knn(parv[0], parv[1], parc > 2 ? parv[2] : ""); |
| 174 | } else if (!strcmp(fun, "evalEER")) { | 174 | } else if (!strcmp(fun, "evalEER")) { |
| 175 | - check(parc >=1 && parc <=3 , "Incorrect parameter count for 'evalEER'."); | ||
| 176 | - br_eval_eer(parv[0], parc > 1 ? parv[1] : "", parc > 2 ? parv[2] : ""); | 175 | + check(parc >=1 && parc <=4 , "Incorrect parameter count for 'evalEER'."); |
| 176 | + br_eval_eer(parv[0], parc > 1 ? parv[1] : "", parc > 2 ? parv[2] : "", parc > 3 ? parv[3] : ""); | ||
| 177 | } else if (!strcmp(fun, "pairwiseCompare")) { | 177 | } else if (!strcmp(fun, "pairwiseCompare")) { |
| 178 | check((parc >= 2) && (parc <= 3), "Incorrect parameter count for 'pairwiseCompare'."); | 178 | check((parc >= 2) && (parc <= 3), "Incorrect parameter count for 'pairwiseCompare'."); |
| 179 | br_pairwise_compare(parv[0], parv[1], parc == 3 ? parv[2] : ""); | 179 | br_pairwise_compare(parv[0], parv[1], parc == 3 ? parv[2] : ""); |
openbr/core/eval.cpp
| @@ -1242,69 +1242,97 @@ void EvalKNN(const QString &knnGraph, const QString &knnTruth, const QString &cs | @@ -1242,69 +1242,97 @@ void EvalKNN(const QString &knnGraph, const QString &knnTruth, const QString &cs | ||
| 1242 | qDebug("FNIR @ FPIR = 0.01: %.3f", 1-getOperatingPointGivenFAR(operatingPoints, 0.01).TAR); | 1242 | qDebug("FNIR @ FPIR = 0.01: %.3f", 1-getOperatingPointGivenFAR(operatingPoints, 0.01).TAR); |
| 1243 | } | 1243 | } |
| 1244 | 1244 | ||
| 1245 | -void EvalEER(const QString &predictedXML, QString gt_property, QString distribution_property){ | 1245 | +void EvalEER(const QString &predictedXML, QString gt_property, QString distribution_property, const QString &pdf) { |
| 1246 | if (gt_property.isEmpty()) | 1246 | if (gt_property.isEmpty()) |
| 1247 | - gt_property = "LivenessGT"; | 1247 | + gt_property = "LivenessGT"; |
| 1248 | if (distribution_property.isEmpty()) | 1248 | if (distribution_property.isEmpty()) |
| 1249 | - distribution_property = "LivenessDistribution"; | ||
| 1250 | - double classOneTemplateCount = 0; | 1249 | + distribution_property = "LivenessClassScores"; |
| 1250 | + int classOneTemplateCount = 0; | ||
| 1251 | const TemplateList templateList(TemplateList::fromGallery(predictedXML)); | 1251 | const TemplateList templateList(TemplateList::fromGallery(predictedXML)); |
| 1252 | 1252 | ||
| 1253 | QHash<QString, int> gtLabels; | 1253 | QHash<QString, int> gtLabels; |
| 1254 | QHash<QString, QList<float> > scores; | 1254 | QHash<QString, QList<float> > scores; |
| 1255 | - for (double i=0; i<templateList.size(); i++) { | ||
| 1256 | - if (!templateList[i].file.contains(distribution_property) || !templateList[i].file.contains(gt_property)) | ||
| 1257 | - continue; | ||
| 1258 | - QString templateKey = templateList[i].file.path() + templateList[i].file.baseName(); | ||
| 1259 | - int gtLabel = templateList[i].file.get<int>(gt_property); | ||
| 1260 | - if (gtLabel == 1) | ||
| 1261 | - classOneTemplateCount++; | ||
| 1262 | - QList<float> templateScores = templateList[i].file.getList<float>(distribution_property); | ||
| 1263 | - gtLabels[templateKey] = gtLabel; | ||
| 1264 | - scores[templateKey] = templateScores; | ||
| 1265 | - } | ||
| 1266 | - | ||
| 1267 | - const int numPoints = 200; | ||
| 1268 | - const float stepSize = 100.0/numPoints; | ||
| 1269 | - const double numTemplates = scores.size(); | ||
| 1270 | - float thres = 0.0; //Between [0,100] | ||
| 1271 | - float thresNorm = 0.0; //Between [0,1] | ||
| 1272 | - double FA = 0, FR = 0; | ||
| 1273 | - float minDiff = 100; | ||
| 1274 | - float EER = 100; | ||
| 1275 | - float EERThres = 0; | ||
| 1276 | - | ||
| 1277 | - for(int i = 0; i <= numPoints; i++){ | ||
| 1278 | - FA = 0, FR = 0; | ||
| 1279 | - thresNorm = thres/100.0; | ||
| 1280 | - foreach(const QString &key, scores.keys()){ | ||
| 1281 | - int gtLabel = gtLabels[key]; | ||
| 1282 | - //> thresNorm = class 0 (spoof) : < thresNorm = class 1 (genuine) | ||
| 1283 | - if (scores[key][0] >= thresNorm && gtLabel == 0) | ||
| 1284 | - continue; | ||
| 1285 | - else if (scores[key][0] < thresNorm && gtLabel == 1) | ||
| 1286 | - continue; | ||
| 1287 | - else if (scores[key][0] >= thresNorm && gtLabel == 1) | ||
| 1288 | - FR +=1; | ||
| 1289 | - else if (scores[key][0] < thresNorm && gtLabel == 0) | ||
| 1290 | - FA +=1; | ||
| 1291 | - } | ||
| 1292 | - float FAR = FA / fabs(numTemplates - classOneTemplateCount); | ||
| 1293 | - float FRR = FR / float(classOneTemplateCount); | ||
| 1294 | - | ||
| 1295 | - float diff = std::abs(FAR-FRR); | ||
| 1296 | - if (diff < minDiff){ | ||
| 1297 | - minDiff = diff; | ||
| 1298 | - EER = (FAR+FRR)/2.0; | ||
| 1299 | - EERThres = thresNorm; | ||
| 1300 | - } | ||
| 1301 | - thres += stepSize; | ||
| 1302 | - } | 1255 | + for (int i=0; i<templateList.size(); i++) { |
| 1256 | + if (!templateList[i].file.contains(distribution_property) || !templateList[i].file.contains(gt_property)) | ||
| 1257 | + continue; | ||
| 1258 | + QString templateKey = templateList[i].file.path() + templateList[i].file.baseName(); | ||
| 1259 | + const int gtLabel = templateList[i].file.get<int>(gt_property); | ||
| 1260 | + if (gtLabel == 1) | ||
| 1261 | + classOneTemplateCount++; | ||
| 1262 | + const QList<float> templateScores = templateList[i].file.getList<float>(distribution_property); | ||
| 1263 | + gtLabels[templateKey] = gtLabel; | ||
| 1264 | + scores[templateKey] = templateScores; | ||
| 1265 | + } | ||
| 1303 | 1266 | ||
| 1304 | - qDebug() <<"Class 0 Templates:" << fabs(numTemplates - classOneTemplateCount) << "Class 1 Templates:" | ||
| 1305 | - << classOneTemplateCount << "Total Templates:" << numTemplates; | ||
| 1306 | - qDebug("EER: %.3f @ Threshold %.3f", EER*100, EERThres); | ||
| 1307 | -} | 1267 | + const int numPoints = 200; |
| 1268 | + const float stepSize = 100.0/numPoints; | ||
| 1269 | + const int numTemplates = scores.size(); | ||
| 1270 | + float thres = 0.0; //Between [0,100] | ||
| 1271 | + float thresNorm = 0.0; //Between [0,1] | ||
| 1272 | + float minDiff = 100, EER = 100, EERThres = 0; | ||
| 1273 | + QList<OperatingPoint> operatingPoints; | ||
| 1274 | + | ||
| 1275 | + for(int i = 0; i <= numPoints; i++) { | ||
| 1276 | + int FA = 0, FR = 0; | ||
| 1277 | + thresNorm = thres/100.0; | ||
| 1278 | + foreach(const QString &key, scores.keys()) { | ||
| 1279 | + int gtLabel = gtLabels[key]; | ||
| 1280 | + //> thresNorm = class 0 (spoof) : < thresNorm = class 1 (genuine) | ||
| 1281 | + if (scores[key][0] >= thresNorm && gtLabel == 0) | ||
| 1282 | + continue; | ||
| 1283 | + else if (scores[key][0] < thresNorm && gtLabel == 1) | ||
| 1284 | + continue; | ||
| 1285 | + else if (scores[key][0] >= thresNorm && gtLabel == 1) | ||
| 1286 | + FR +=1; | ||
| 1287 | + else if (scores[key][0] < thresNorm && gtLabel == 0) | ||
| 1288 | + FA +=1; | ||
| 1289 | + } | ||
| 1290 | + const float FAR = FA / float(numTemplates - classOneTemplateCount); | ||
| 1291 | + const float FRR = FR / float(classOneTemplateCount); | ||
| 1292 | + operatingPoints.append(OperatingPoint(thresNorm, FAR, 1-FRR)); | ||
| 1293 | + | ||
| 1294 | + const float diff = std::abs(FAR-FRR); | ||
| 1295 | + if (diff < minDiff) { | ||
| 1296 | + minDiff = diff; | ||
| 1297 | + EER = (FAR+FRR)/2.0; | ||
| 1298 | + EERThres = thresNorm; | ||
| 1299 | + } | ||
| 1300 | + thres += stepSize; | ||
| 1301 | + } | ||
| 1302 | + | ||
| 1303 | + printf("Class 0 Templates: %d\tClass 1 Templates: %d\tTotal Templates: %d\n", | ||
| 1304 | + numTemplates-classOneTemplateCount, classOneTemplateCount, numTemplates); | ||
| 1305 | + foreach (float FAR, QList<float>() << 0.1 << 0.01 << 0.001 << 0.0001) { | ||
| 1306 | + const OperatingPoint op = getOperatingPointGivenFAR(operatingPoints, FAR); | ||
| 1307 | + printf("TAR & Score @ FAR = %.0e: %.3f %.3f\n", FAR, op.TAR, op.score); | ||
| 1308 | + } | ||
| 1309 | + printf("EER: %.3f @ Threshold %.3f\n", EER*100, EERThres); | ||
| 1310 | + | ||
| 1311 | + // Optionally write ROC curve | ||
| 1312 | + if (!pdf.isEmpty()) { | ||
| 1313 | + QStringList farValues, tarValues; | ||
| 1314 | + float expFAR = std::max(ceil(log10(numTemplates - classOneTemplateCount)), 1.0); | ||
| 1315 | + float FARstep = expFAR / (float)(Max_Points - 1); | ||
| 1316 | + for (int i=0; i<Max_Points; i++) { | ||
| 1317 | + float FAR = pow(10, -expFAR + i*FARstep); | ||
| 1318 | + OperatingPoint op = getOperatingPointGivenFAR(operatingPoints, FAR); | ||
| 1319 | + farValues.append(QString::number(FAR)); | ||
| 1320 | + tarValues.append(QString::number(op.TAR)); | ||
| 1321 | + } | ||
| 1308 | 1322 | ||
| 1323 | + QStringList rSource; | ||
| 1324 | + rSource << "# Load libraries" << "library(ggplot2)" << "" << "# Set Data" | ||
| 1325 | + << "FAR <- c(" + farValues.join(",") + ")" | ||
| 1326 | + << "TAR <- c(" + tarValues.join(",") + ")" | ||
| 1327 | + << "data <- data.frame(FAR, TAR)" | ||
| 1328 | + << "" << "# Construct Plot" << "pdf(\"" + pdf + "\")" | ||
| 1329 | + << "print(qplot(FAR, TAR, data=data, geom=\"line\") + scale_x_log10() + theme_minimal())" | ||
| 1330 | + << "dev.off()"; | ||
| 1331 | + | ||
| 1332 | + QString rFile = "EvalEER.R"; | ||
| 1333 | + QtUtils::writeFile(rFile, rSource); | ||
| 1334 | + QtUtils::runRScript(rFile); | ||
| 1335 | + } | ||
| 1336 | +} | ||
| 1309 | 1337 | ||
| 1310 | } // namespace br | 1338 | } // namespace br |
openbr/core/eval.h
| @@ -34,7 +34,7 @@ namespace br | @@ -34,7 +34,7 @@ namespace br | ||
| 34 | float EvalLandmarking(const QString &predictedGallery, const QString &truthGallery, const QString &csv = "", int normalizationIndexA = 0, int normalizationIndexB = 1, int sampleIndex = 0, int totalExamples = 5); // Return average error | 34 | float EvalLandmarking(const QString &predictedGallery, const QString &truthGallery, const QString &csv = "", int normalizationIndexA = 0, int normalizationIndexB = 1, int sampleIndex = 0, int totalExamples = 5); // Return average error |
| 35 | void EvalRegression(const QString &predictedGallery, const QString &truthGallery, QString predictedProperty = "", QString truthProperty = ""); | 35 | void EvalRegression(const QString &predictedGallery, const QString &truthGallery, QString predictedProperty = "", QString truthProperty = ""); |
| 36 | void EvalKNN(const QString &knnGraph, const QString &knnTruth, const QString &csv = ""); | 36 | void EvalKNN(const QString &knnGraph, const QString &knnTruth, const QString &csv = ""); |
| 37 | - void EvalEER(const QString &predictedXML, const QString gt_property = "", const QString distribution_property = ""); | 37 | + void EvalEER(const QString &predictedXML, const QString gt_property = "", const QString distribution_property = "", const QString &pdf = ""); |
| 38 | struct Candidate | 38 | struct Candidate |
| 39 | { | 39 | { |
| 40 | size_t index; | 40 | size_t index; |
openbr/openbr.cpp
| @@ -146,9 +146,9 @@ void br_eval_knn(const char *knnGraph, const char *knnTruth, const char *csv) | @@ -146,9 +146,9 @@ void br_eval_knn(const char *knnGraph, const char *knnTruth, const char *csv) | ||
| 146 | EvalKNN(knnGraph, knnTruth, csv); | 146 | EvalKNN(knnGraph, knnTruth, csv); |
| 147 | } | 147 | } |
| 148 | 148 | ||
| 149 | -void br_eval_eer(const char *predicted_xml, const char *gt_property, const char *distribution_property ) | 149 | +void br_eval_eer(const char *predicted_xml, const char *gt_property, const char *distribution_property, const char *pdf) |
| 150 | { | 150 | { |
| 151 | - EvalEER(predicted_xml, gt_property, distribution_property); | 151 | + EvalEER(predicted_xml, gt_property, distribution_property, pdf); |
| 152 | } | 152 | } |
| 153 | 153 | ||
| 154 | void br_finalize() | 154 | void br_finalize() |
openbr/openbr.h
| @@ -66,7 +66,7 @@ BR_EXPORT void br_eval_regression(const char *predicted_gallery, const char *tru | @@ -66,7 +66,7 @@ BR_EXPORT void br_eval_regression(const char *predicted_gallery, const char *tru | ||
| 66 | 66 | ||
| 67 | BR_EXPORT void br_eval_knn(const char *knnGraph, const char *knnTruth, const char *csv = ""); | 67 | BR_EXPORT void br_eval_knn(const char *knnGraph, const char *knnTruth, const char *csv = ""); |
| 68 | 68 | ||
| 69 | -BR_EXPORT void br_eval_eer(const char *predicted_xml, const char *gt_property = "", const char *distribution_property = ""); | 69 | +BR_EXPORT void br_eval_eer(const char *predicted_xml, const char *gt_property = "", const char *distribution_property = "", const char *pdf = ""); |
| 70 | 70 | ||
| 71 | BR_EXPORT void br_finalize(); | 71 | BR_EXPORT void br_finalize(); |
| 72 | 72 |