Commit a00c2d72c9427d4477e409e844948dd6d318d2dc

Authored by bhklein
1 parent 28362462

added ROC to evalEER

app/br/br.cpp
... ... @@ -172,8 +172,8 @@ public:
172 172 check(parc >= 2 && parc <= 3, "Incorrect parameter count for 'evalKNN'.");
173 173 br_eval_knn(parv[0], parv[1], parc > 2 ? parv[2] : "");
174 174 } else if (!strcmp(fun, "evalEER")) {
175   - check(parc >=1 && parc <=3 , "Incorrect parameter count for 'evalEER'.");
176   - br_eval_eer(parv[0], parc > 1 ? parv[1] : "", parc > 2 ? parv[2] : "");
  175 + check(parc >=1 && parc <=4 , "Incorrect parameter count for 'evalEER'.");
  176 + br_eval_eer(parv[0], parc > 1 ? parv[1] : "", parc > 2 ? parv[2] : "", parc > 3 ? parv[3] : "");
177 177 } else if (!strcmp(fun, "pairwiseCompare")) {
178 178 check((parc >= 2) && (parc <= 3), "Incorrect parameter count for 'pairwiseCompare'.");
179 179 br_pairwise_compare(parv[0], parv[1], parc == 3 ? parv[2] : "");
... ...
openbr/core/eval.cpp
... ... @@ -1242,69 +1242,97 @@ void EvalKNN(const QString &amp;knnGraph, const QString &amp;knnTruth, const QString &amp;cs
1242 1242 qDebug("FNIR @ FPIR = 0.01: %.3f", 1-getOperatingPointGivenFAR(operatingPoints, 0.01).TAR);
1243 1243 }
1244 1244  
1245   -void EvalEER(const QString &predictedXML, QString gt_property, QString distribution_property){
  1245 +void EvalEER(const QString &predictedXML, QString gt_property, QString distribution_property, const QString &pdf) {
1246 1246 if (gt_property.isEmpty())
1247   - gt_property = "LivenessGT";
  1247 + gt_property = "LivenessGT";
1248 1248 if (distribution_property.isEmpty())
1249   - distribution_property = "LivenessDistribution";
1250   - double classOneTemplateCount = 0;
  1249 + distribution_property = "LivenessClassScores";
  1250 + int classOneTemplateCount = 0;
1251 1251 const TemplateList templateList(TemplateList::fromGallery(predictedXML));
1252 1252  
1253 1253 QHash<QString, int> gtLabels;
1254 1254 QHash<QString, QList<float> > scores;
1255   - for (double i=0; i<templateList.size(); i++) {
1256   - if (!templateList[i].file.contains(distribution_property) || !templateList[i].file.contains(gt_property))
1257   - continue;
1258   - QString templateKey = templateList[i].file.path() + templateList[i].file.baseName();
1259   - int gtLabel = templateList[i].file.get<int>(gt_property);
1260   - if (gtLabel == 1)
1261   - classOneTemplateCount++;
1262   - QList<float> templateScores = templateList[i].file.getList<float>(distribution_property);
1263   - gtLabels[templateKey] = gtLabel;
1264   - scores[templateKey] = templateScores;
1265   - }
1266   -
1267   - const int numPoints = 200;
1268   - const float stepSize = 100.0/numPoints;
1269   - const double numTemplates = scores.size();
1270   - float thres = 0.0; //Between [0,100]
1271   - float thresNorm = 0.0; //Between [0,1]
1272   - double FA = 0, FR = 0;
1273   - float minDiff = 100;
1274   - float EER = 100;
1275   - float EERThres = 0;
1276   -
1277   - for(int i = 0; i <= numPoints; i++){
1278   - FA = 0, FR = 0;
1279   - thresNorm = thres/100.0;
1280   - foreach(const QString &key, scores.keys()){
1281   - int gtLabel = gtLabels[key];
1282   - //> thresNorm = class 0 (spoof) : < thresNorm = class 1 (genuine)
1283   - if (scores[key][0] >= thresNorm && gtLabel == 0)
1284   - continue;
1285   - else if (scores[key][0] < thresNorm && gtLabel == 1)
1286   - continue;
1287   - else if (scores[key][0] >= thresNorm && gtLabel == 1)
1288   - FR +=1;
1289   - else if (scores[key][0] < thresNorm && gtLabel == 0)
1290   - FA +=1;
1291   - }
1292   - float FAR = FA / fabs(numTemplates - classOneTemplateCount);
1293   - float FRR = FR / float(classOneTemplateCount);
1294   -
1295   - float diff = std::abs(FAR-FRR);
1296   - if (diff < minDiff){
1297   - minDiff = diff;
1298   - EER = (FAR+FRR)/2.0;
1299   - EERThres = thresNorm;
1300   - }
1301   - thres += stepSize;
1302   - }
  1255 + for (int i=0; i<templateList.size(); i++) {
  1256 + if (!templateList[i].file.contains(distribution_property) || !templateList[i].file.contains(gt_property))
  1257 + continue;
  1258 + QString templateKey = templateList[i].file.path() + templateList[i].file.baseName();
  1259 + const int gtLabel = templateList[i].file.get<int>(gt_property);
  1260 + if (gtLabel == 1)
  1261 + classOneTemplateCount++;
  1262 + const QList<float> templateScores = templateList[i].file.getList<float>(distribution_property);
  1263 + gtLabels[templateKey] = gtLabel;
  1264 + scores[templateKey] = templateScores;
  1265 + }
1303 1266  
1304   - qDebug() <<"Class 0 Templates:" << fabs(numTemplates - classOneTemplateCount) << "Class 1 Templates:"
1305   - << classOneTemplateCount << "Total Templates:" << numTemplates;
1306   - qDebug("EER: %.3f @ Threshold %.3f", EER*100, EERThres);
1307   -}
  1267 + const int numPoints = 200;
  1268 + const float stepSize = 100.0/numPoints;
  1269 + const int numTemplates = scores.size();
  1270 + float thres = 0.0; //Between [0,100]
  1271 + float thresNorm = 0.0; //Between [0,1]
  1272 + float minDiff = 100, EER = 100, EERThres = 0;
  1273 + QList<OperatingPoint> operatingPoints;
  1274 +
  1275 + for(int i = 0; i <= numPoints; i++) {
  1276 + int FA = 0, FR = 0;
  1277 + thresNorm = thres/100.0;
  1278 + foreach(const QString &key, scores.keys()) {
  1279 + int gtLabel = gtLabels[key];
  1280 + //> thresNorm = class 0 (spoof) : < thresNorm = class 1 (genuine)
  1281 + if (scores[key][0] >= thresNorm && gtLabel == 0)
  1282 + continue;
  1283 + else if (scores[key][0] < thresNorm && gtLabel == 1)
  1284 + continue;
  1285 + else if (scores[key][0] >= thresNorm && gtLabel == 1)
  1286 + FR +=1;
  1287 + else if (scores[key][0] < thresNorm && gtLabel == 0)
  1288 + FA +=1;
  1289 + }
  1290 + const float FAR = FA / float(numTemplates - classOneTemplateCount);
  1291 + const float FRR = FR / float(classOneTemplateCount);
  1292 + operatingPoints.append(OperatingPoint(thresNorm, FAR, 1-FRR));
  1293 +
  1294 + const float diff = std::abs(FAR-FRR);
  1295 + if (diff < minDiff) {
  1296 + minDiff = diff;
  1297 + EER = (FAR+FRR)/2.0;
  1298 + EERThres = thresNorm;
  1299 + }
  1300 + thres += stepSize;
  1301 + }
  1302 +
  1303 + printf("Class 0 Templates: %d\tClass 1 Templates: %d\tTotal Templates: %d\n",
  1304 + numTemplates-classOneTemplateCount, classOneTemplateCount, numTemplates);
  1305 + foreach (float FAR, QList<float>() << 0.1 << 0.01 << 0.001 << 0.0001) {
  1306 + const OperatingPoint op = getOperatingPointGivenFAR(operatingPoints, FAR);
  1307 + printf("TAR & Score @ FAR = %.0e: %.3f %.3f\n", FAR, op.TAR, op.score);
  1308 + }
  1309 + printf("EER: %.3f @ Threshold %.3f\n", EER*100, EERThres);
  1310 +
  1311 + // Optionally write ROC curve
  1312 + if (!pdf.isEmpty()) {
  1313 + QStringList farValues, tarValues;
  1314 + float expFAR = std::max(ceil(log10(numTemplates - classOneTemplateCount)), 1.0);
  1315 + float FARstep = expFAR / (float)(Max_Points - 1);
  1316 + for (int i=0; i<Max_Points; i++) {
  1317 + float FAR = pow(10, -expFAR + i*FARstep);
  1318 + OperatingPoint op = getOperatingPointGivenFAR(operatingPoints, FAR);
  1319 + farValues.append(QString::number(FAR));
  1320 + tarValues.append(QString::number(op.TAR));
  1321 + }
1308 1322  
  1323 + QStringList rSource;
  1324 + rSource << "# Load libraries" << "library(ggplot2)" << "" << "# Set Data"
  1325 + << "FAR <- c(" + farValues.join(",") + ")"
  1326 + << "TAR <- c(" + tarValues.join(",") + ")"
  1327 + << "data <- data.frame(FAR, TAR)"
  1328 + << "" << "# Construct Plot" << "pdf(\"" + pdf + "\")"
  1329 + << "print(qplot(FAR, TAR, data=data, geom=\"line\") + scale_x_log10() + theme_minimal())"
  1330 + << "dev.off()";
  1331 +
  1332 + QString rFile = "EvalEER.R";
  1333 + QtUtils::writeFile(rFile, rSource);
  1334 + QtUtils::runRScript(rFile);
  1335 + }
  1336 +}
1309 1337  
1310 1338 } // namespace br
... ...
openbr/core/eval.h
... ... @@ -34,7 +34,7 @@ namespace br
34 34 float EvalLandmarking(const QString &predictedGallery, const QString &truthGallery, const QString &csv = "", int normalizationIndexA = 0, int normalizationIndexB = 1, int sampleIndex = 0, int totalExamples = 5); // Return average error
35 35 void EvalRegression(const QString &predictedGallery, const QString &truthGallery, QString predictedProperty = "", QString truthProperty = "");
36 36 void EvalKNN(const QString &knnGraph, const QString &knnTruth, const QString &csv = "");
37   - void EvalEER(const QString &predictedXML, const QString gt_property = "", const QString distribution_property = "");
  37 + void EvalEER(const QString &predictedXML, const QString gt_property = "", const QString distribution_property = "", const QString &pdf = "");
38 38 struct Candidate
39 39 {
40 40 size_t index;
... ...
openbr/openbr.cpp
... ... @@ -146,9 +146,9 @@ void br_eval_knn(const char *knnGraph, const char *knnTruth, const char *csv)
146 146 EvalKNN(knnGraph, knnTruth, csv);
147 147 }
148 148  
149   -void br_eval_eer(const char *predicted_xml, const char *gt_property, const char *distribution_property )
  149 +void br_eval_eer(const char *predicted_xml, const char *gt_property, const char *distribution_property, const char *pdf)
150 150 {
151   - EvalEER(predicted_xml, gt_property, distribution_property);
  151 + EvalEER(predicted_xml, gt_property, distribution_property, pdf);
152 152 }
153 153  
154 154 void br_finalize()
... ...
openbr/openbr.h
... ... @@ -66,7 +66,7 @@ BR_EXPORT void br_eval_regression(const char *predicted_gallery, const char *tru
66 66  
67 67 BR_EXPORT void br_eval_knn(const char *knnGraph, const char *knnTruth, const char *csv = "");
68 68  
69   -BR_EXPORT void br_eval_eer(const char *predicted_xml, const char *gt_property = "", const char *distribution_property = "");
  69 +BR_EXPORT void br_eval_eer(const char *predicted_xml, const char *gt_property = "", const char *distribution_property = "", const char *pdf = "");
70 70  
71 71 BR_EXPORT void br_finalize();
72 72  
... ...