From 6ebad57b201989a25d35171b914269ac7876eeab Mon Sep 17 00:00:00 2001 From: bhklein Date: Mon, 15 Dec 2014 21:32:14 -0500 Subject: [PATCH] Sample ROC points at fixed FAR values for easy vertical averaging. Display confidence intervals in tables and plots. --- openbr/core/eval.cpp | 68 +++++++++++++++++++++++++++++++++++++------------------------------- openbr/core/plot.cpp | 53 ++++++++++++++++++++++++++++++++++------------------- 2 files changed, 71 insertions(+), 50 deletions(-) diff --git a/openbr/core/eval.cpp b/openbr/core/eval.cpp index e9feae7..10781ab 100755 --- a/openbr/core/eval.cpp +++ b/openbr/core/eval.cpp @@ -48,22 +48,26 @@ struct OperatingPoint : score(_score), FAR(_FAR), TAR(_TAR) {} }; -static float getTAR(const QList &operatingPoints, float FAR) +static OperatingPoint getOperatingPoint(const QList &operatingPoints, float FAR) { int index = 0; while (operatingPoints[index].FAR < FAR) { index++; if (index == operatingPoints.size()) - return 1; + return OperatingPoint(operatingPoints.last().score, FAR, operatingPoints.last().TAR); } - const float x1 = (index == 0 ? 0 : operatingPoints[index-1].FAR); - const float y1 = (index == 0 ? 0 : operatingPoints[index-1].TAR); - const float x2 = operatingPoints[index].FAR; - const float y2 = operatingPoints[index].TAR; - const float m = (y2 - y1) / (x2 - x1); - const float b = y1 - m*x1; - return m * FAR + b; + const float FAR1 = (index == 0 ? 0 : operatingPoints[index-1].FAR); + const float TAR1 = (index == 0 ? 0 : operatingPoints[index-1].TAR); + const float score1 = (index == 0 ? operatingPoints[index].score : operatingPoints[index-1].score); + const float FAR2 = operatingPoints[index].FAR; + const float TAR2 = operatingPoints[index].TAR; + const float score2 = operatingPoints[index].score; + const float mTAR = (TAR2 - TAR1) / (FAR2 - FAR1); + const float bTAR = TAR1 - mTAR*FAR1; + const float mScore = (score2 - score1) / (FAR2 - FAR1); + const float bScore = score1 - mScore*FAR1; + return OperatingPoint(mScore * FAR + bScore,FAR, mTAR * FAR + bTAR); } static float getCMC(const QVector &firstGenuineReturns, int rank) @@ -266,23 +270,26 @@ float Evaluate(const Mat &simmat, const Mat &mask, const QString &csv, const QSt } // Write Detection Error Tradeoff (DET), PRE, REC - int points = qMin(operatingPoints.size(), Max_Points); - for (int i=0; i sampledGenuineScores; sampledGenuineScores.reserve(points); QList sampledImpostorScores; sampledImpostorScores.reserve(points); @@ -332,10 +338,10 @@ float Evaluate(const Mat &simmat, const Mat &mask, const QString &csv, const QSt QtUtils::writeFile(csv, lines); if (maxSize > 0) qDebug("Template Size: %i bytes", (int)maxSize); - qDebug("TAR @ FAR = 0.01: %.3f",getTAR(operatingPoints, 0.01)); - qDebug("TAR @ FAR = 0.001: %.3f",getTAR(operatingPoints, 0.001)); - qDebug("TAR @ FAR = 0.0001: %.3f",getTAR(operatingPoints, 0.0001)); - qDebug("TAR @ FAR = 0.00001: %.3f",getTAR(operatingPoints, 0.00001)); + qDebug("TAR @ FAR = 0.01: %.3f",getOperatingPoint(operatingPoints, 0.01).TAR); + qDebug("TAR @ FAR = 0.001: %.3f",getOperatingPoint(operatingPoints, 0.001).TAR); + qDebug("TAR @ FAR = 0.0001: %.3f",getOperatingPoint(operatingPoints, 0.0001).TAR); + qDebug("TAR @ FAR = 0.00001: %.3f",getOperatingPoint(operatingPoints, 0.00001).TAR); qDebug("\nRetrieval Rate @ Rank = %d: %.3f", Report_Retrieval, getCMC(firstGenuineReturns, Report_Retrieval)); @@ -560,8 +566,8 @@ float InplaceEval(const QString &simmat, const QString &target, const QString &q float result; // Write FAR/TAR Bar Chart (BC) - lines.append(qPrintable(QString("BC,0.001,%1").arg(QString::number(getTAR(operatingPoints, 0.001), 'f', 3)))); - lines.append(qPrintable(QString("BC,0.01,%1").arg(QString::number(result = getTAR(operatingPoints, 0.01), 'f', 3)))); + lines.append(qPrintable(QString("BC,0.001,%1").arg(QString::number(getOperatingPoint(operatingPoints, 0.001).TAR, 'f', 3)))); + lines.append(qPrintable(QString("BC,0.01,%1").arg(QString::number(result = getOperatingPoint(operatingPoints, 0.01).TAR, 'f', 3)))); qDebug("TAR @ FAR = 0.01: %.3f", result); QtUtils::writeFile(csv, lines); diff --git a/openbr/core/plot.cpp b/openbr/core/plot.cpp index 2a22219..17af833 100644 --- a/openbr/core/plot.cpp +++ b/openbr/core/plot.cpp @@ -108,7 +108,6 @@ struct RPlot pivotItems = QVector< QSet >(pivotHeaders.size()); foreach (const QString &fileName, files) { QStringList pivots = getPivots(fileName, false); - // If the number of pivots don't match, abandon the directory/filename labeling scheme if (pivots.size() != pivotHeaders.size()) { pivots.clear(); @@ -132,7 +131,6 @@ struct RPlot minor = Pivot(i, size, pivotHeaders[i]); } } - const QString &smooth = destination.get("smooth", ""); major.smooth = !smooth.isEmpty() && (major.header == smooth) && (major.size > 1); minor.smooth = !smooth.isEmpty() && (minor.header == smooth) && (minor.size > 1); @@ -177,29 +175,42 @@ struct RPlot "TS$Y <- as.character(TS$Y)\n" "CMC$Y <- as.numeric(as.character(CMC$Y))\n" "\n" + "if (%1) {\n\tsummarySE <- function(data=NULL, measurevar, groupvars=NULL, na.rm=FALSE, conf.interval=.95, .drop=TRUE) {\n\t\t" + "require(plyr)\n\n\t\tlength2 <- function (x, na.rm=FALSE) {\n\t\t\tif (na.rm) sum(!is.na(x))\n\t\t\telse length(x)" + "\n\t\t}\n\n\t\tdatac <- ddply(data, groupvars, .drop=.drop, .fun = function(xx, col) {\n\t\t\t" + "c(N=length2(xx[[col]], na.rm=na.rm), mean=mean(xx[[col]], na.rm=na.rm), sd=sd(xx[[col]], na.rm=na.rm))\n\t\t\t}," + "\n\t\t\tmeasurevar\n\t\t)\n\n\t\tdatac <- rename(datac, c(\"mean\" = measurevar))\n\t\tdatac$se <- datac$sd / sqrt(datac$N)" + "\n\t\tciMult <- qt(conf.interval/2 + .5, datac$N-1)\n\t\tdatac$ci <- datac$se * ciMult\n\n\t\treturn(datac)\n\t}\n\t" + "DET <- summarySE(DET, measurevar=\"Y\", groupvars=c(\"%2\", \"X\"))\n\t" + "ERR <- summarySE(ERR, measurevar=\"X\", groupvars=c(\"Error\", \"%2\", \"Y\"))\n\t" + "FT <- summarySE(FT, measurevar=\"Y\", groupvars=c(\"%2\", \"X\"))\n\t" + "CT <- summarySE(CT, measurevar=\"Y\", groupvars=c(\"%2\", \"X\"))\n}\n\n" "# Code to format FAR values\n" "far_names <- list('0.001'=\"FAR = 0.1%\", '0.01'=\"FAR = 1%\")\n" "far_labeller <- function(variable,value) { return(far_names[as.character(value)]) }\n" "\n" "# Code to format TAR@FAR table\n" - "algs <- unique(FT$%1)\n" + "algs <- unique(FT$%2)\n" "algs <- algs[!duplicated(algs)]\n" - "mat <- matrix(FT$Y,nrow=6,ncol=length(algs),byrow=FALSE)\n" + "mat <- matrix(%3,nrow=6,ncol=length(algs),byrow=FALSE)\n" "colnames(mat) <- algs \n" "rownames(mat) <- c(\"FAR = 1e-06\", \"FAR = 1e-05\", \"FAR = 1e-04\", \"FAR = 1e-03\", \"FAR = 1e-02\", \"FAR = 1e-01\")\n" "FTtable <- as.table(mat)\n" "\n" "# Code to format CMC Table\n" - "mat <- matrix(CT$Y,nrow=6,ncol=length(algs),byrow=FALSE)\n" + "mat <- matrix(%4,nrow=6,ncol=length(algs),byrow=FALSE)\n" "colnames(mat) <- algs \n" "rownames(mat) <- c(\" Rank 1\", \"Rank 5\", \"Rank 10\", \"Rank 20\", \"Rank 50\", \"Rank 100\")\n" "CMCtable <- as.table(mat)\n" "\n" "# Code to format Template Size Table\n" - "mat <- matrix(TS$Y,nrow=1,ncol=length(algs),byrow=FALSE)\n" - "colnames(mat) <- algs\n" - "rownames(mat) <- c(\"Template Size (bytes):\")\n" - "TStable <- as.table(mat)\n").arg(major.header))); + "if (nrow(TS) != 0) {\n\t" + "mat <- matrix(TS$Y,nrow=1,ncol=length(algs),byrow=FALSE)\n\t" + "colnames(mat) <- algs\n\t" + "rownames(mat) <- c(\"Template Size (bytes):\")\n\t" + "TStable <- as.table(mat)\n}\n").arg(((major.smooth || minor.smooth) ? "TRUE" : "FALSE"), major.size > 1 ? major.header : (minor.header.isEmpty() ? major.header : minor.header), + (major.smooth || minor.smooth) ? "paste(as.character(round(FT$Y, 3)), round(FT$ci, 3), sep=\"\\u00b1\")" : "FT$Y", + (major.smooth || minor.smooth) ? "paste(as.character(round(CT$Y, 3)), round(CT$ci, 3), sep=\"\\u00b1\")" : "CT$Y"))); // Open output device file.write(qPrintable(QString("\n" @@ -231,8 +242,9 @@ struct RPlot "print(title(\"Table of True Accept Rates at various False Accept Rates\"))\n" "print(textplot(CMCtable))\n" "print(title(\"Table of retrieval rate at various ranks\"))\n" - "print(textplot(TStable, cex=1.15))\n" - "print(title(\"Template Size by Algorithm\"))\n"; + "if (nrow(TS) != 0) {\n\t" + "print(textplot(TStable, cex=1.15))\n\t" + "print(title(\"Template Size by Algorithm\"))\n}\n"; file.write(qPrintable(textplot.arg(PRODUCT_NAME, PRODUCT_VERSION))); } @@ -281,11 +293,12 @@ bool Plot(const QStringList &files, const File &destination, bool show) RPlot p(files, destination); - p.file.write(qPrintable(QString("qplot(X, 1-Y, data=DET%1, main=\"%2\"").arg((p.major.smooth || p.minor.smooth) ? ", geom=\"smooth\", method=loess, level=0.99" : ", geom=\"line\"", rocOpts.get("title",QString())) + + p.file.write(qPrintable(QString("qplot(X, 1-Y, data=DET, geom=\"line\", main=\"%1\"").arg(rocOpts.get("title",QString())) + (p.major.size > 1 ? QString(", colour=factor(%1)").arg(p.major.header) : QString()) + (p.minor.size > 1 ? QString(", linetype=factor(%1)").arg(p.minor.header) : QString()) + QString(", xlab=\"False Accept Rate\", ylab=\"True Accept Rate\") + theme_minimal()") + - (p.major.size > 1 ? getScale("colour", "Algorithm", p.major.size) : QString()) + + ((p.major.smooth || p.minor.smooth) ? " + geom_errorbar(data=DET[seq(1, NROW(DET), by = 29),], aes(x=X, ymin=(1-Y)-ci, ymax=(1-Y)+ci), width=0.1, alpha=I(1/2))" : QString()) + + (p.major.size > 1 ? getScale("colour", p.major.header, p.major.size) : QString()) + (p.minor.size > 1 ? QString(" + scale_linetype_discrete(\"%1\")").arg(p.minor.header) : QString()) + QString(" + scale_x_log10(labels=trans_format(\"log10\", math_format()))") + (rocOpts.contains("yLimits") ? QString(" + scale_y_continuous(labels=percent) + coord_cartesian(ylim=%1)").arg("c"+QtUtils::toString(rocOpts.get("yLimits",QPointF()))) : QString(" + scale_y_continuous(labels=percent)")) + @@ -293,12 +306,14 @@ bool Plot(const QStringList &files, const File &destination, bool show) QString(" + theme(legend.title = element_text(size = %1), plot.title = element_text(size = %1), axis.text = element_text(size = %1), axis.title.x = element_text(size = %1), axis.title.y = element_text(size = %1)," " legend.position=%2, legend.background = element_rect(fill = 'white'), panel.grid.major = element_line(colour = \"gray\"), panel.grid.minor = element_line(colour = \"gray\", linetype = \"dashed\"), legend.text = element_text(size = %1))\n\n").arg(QString::number(rocOpts.get("textSize",12)), rocOpts.contains("legendPosition") ? "c"+QtUtils::toString(rocOpts.get("legendPosition")) : "'bottom'"))); - p.file.write(qPrintable(QString("qplot(X, Y, data=DET%1").arg((p.major.smooth || p.minor.smooth) ? ", geom=\"smooth\", method=loess, level=0.99" : ", geom=\"line\"") + + p.file.write(qPrintable(QString("qplot(X, Y, data=DET, geom=\"line\"") + (p.major.size > 1 ? QString(", colour=factor(%1)").arg(p.major.header) : QString()) + (p.minor.size > 1 ? QString(", linetype=factor(%1)").arg(p.minor.header) : QString()) + QString(", xlab=\"False Accept Rate\", ylab=\"False Reject Rate\") + geom_abline(alpha=0.5, colour=\"grey\", linetype=\"dashed\") + theme_minimal()") + - (p.major.size > 1 ? getScale("colour", "Algorithm", p.major.size) : QString()) + + ((p.major.smooth || p.minor.smooth) ? " + geom_errorbar(data=DET[seq(1, NROW(DET), by = 29),], aes(x=X, ymin=Y-ci, ymax=Y+ci), width=0.1, alpha=I(1/2))" : QString()) + + (p.major.size > 1 ? getScale("colour", p.major.header, p.major.size) : QString()) + (p.minor.size > 1 ? QString(" + scale_linetype_discrete(\"%1\")").arg(p.minor.header) : QString()) + + QString(" + theme(legend.position=%1)").arg(rocOpts.contains("legendPosition") ? "c"+QtUtils::toString(rocOpts.get("legendPosition")) : "'bottom'") + QString(" + scale_x_log10(labels=trans_format(\"log10\", math_format())) + scale_y_log10(labels=trans_format(\"log10\", math_format())) + annotation_logticks()\n\n"))); p.file.write(qPrintable(QString("qplot(X, data=SD, geom=\"histogram\", fill=Y, position=\"identity\", alpha=I(1/2)") + @@ -311,7 +326,7 @@ bool Plot(const QStringList &files, const File &destination, bool show) QString(((p.major.smooth || p.minor.smooth) ? (!uncertainty ? " + stat_summary(geom=\"line\", fun.y=mean, size=%1)" : " + stat_summary(geom=\"line\", fun.y=min, aes(linetype=\"Min/Max\"), size=%1) + stat_summary(geom=\"line\", " "fun.y=max, aes(linetype=\"Min/Max\"), size=%1) + stat_summary(geom=\"line\", fun.y=mean, aes(linetype=\"Mean\"), size=%1) + scale_linetype_manual(\"Legend\", values=c(\"Mean\"=1, \"Min/Max\"=2))") : " + geom_line(size=%1)")).arg(QString::number(cmcOpts.get("thickness",1))) + (minimalist ? "" : " + scale_x_log10(labels=c(1,5,10,50,100), breaks=c(1,5,10,50,100)) + annotation_logticks(sides=\"b\")") + - (p.major.size > 1 ? getScale("colour", "Algorithm", p.major.size) : QString()) + + (p.major.size > 1 ? getScale("colour", p.major.header, p.major.size) : QString()) + (p.minor.size > 1 ? QString(" + scale_linetype_discrete(\"%1\")").arg(p.minor.header) : QString()) + (cmcOpts.contains("yLimits") ? QString(" + scale_y_continuous(labels=percent) + coord_cartesian(ylim=%1)").arg("c"+QtUtils::toString(cmcOpts.get("yLimits",QPointF()))) : QString(" + scale_y_continuous(labels=percent)")) + QString(" + theme_minimal() + theme(legend.title = element_text(size = %1), plot.title = element_text(size = %1), axis.text = element_text(size = %1), axis.title.x = element_text(size = %1), axis.title.y = element_text(size = %1)," @@ -320,14 +335,14 @@ bool Plot(const QStringList &files, const File &destination, bool show) p.file.write(qPrintable(QString("qplot(factor(%1)%2, data=BC, %3").arg(p.major.smooth ? (p.minor.header.isEmpty() ? "Algorithm" : p.minor.header) : p.major.header, (p.major.smooth || p.minor.smooth) ? ", Y" : "", (p.major.smooth || p.minor.smooth) ? "geom=\"boxplot\"" : "geom=\"bar\", position=\"dodge\", weight=Y") + (p.major.size > 1 ? QString(", fill=factor(%1)").arg(p.major.header) : QString()) + QString(", xlab=\"False Accept Rate\", ylab=\"True Accept Rate\") + theme_minimal()") + - (p.major.size > 1 ? getScale("fill", "Algorithm", p.major.size) : QString()) + + (p.major.size > 1 ? getScale("fill", p.major.header, p.major.size) : QString()) + (p.minor.size > 1 ? QString(" + facet_grid(%2 ~ X)").arg(p.minor.header) : QString(" + facet_grid(. ~ X, labeller=far_labeller)")) + QString(" + scale_y_continuous(labels=percent) + theme(legend.position=\"none\", axis.text.x=element_text(angle=-90, hjust=0))%1").arg((p.major.smooth || p.minor.smooth) ? "" : " + geom_text(data=BC, aes(label=Y, y=0.05))") + "\n\n")); - p.file.write(qPrintable(QString("qplot(X, Y, data=ERR%1, linetype=Error").arg((p.major.smooth || p.minor.smooth) ? ", geom=\"smooth\", method=loess, level=0.99" : ", geom=\"line\"") + + p.file.write(qPrintable(QString("qplot(X, Y, data=ERR, geom=\"line\", linetype=Error") + ((p.flip ? p.major.size : p.minor.size) > 1 ? QString(", colour=factor(%1)").arg(p.flip ? p.major.header : p.minor.header) : QString()) + QString(", xlab=\"Score\", ylab=\"Error Rate\") + theme_minimal()") + - ((p.flip ? p.major.size : p.minor.size) > 1 ? getScale("colour", p.flip ? "Algorithm" : "Algorithm", p.flip ? p.major.size : p.minor.size) : QString()) + + ((p.flip ? p.major.size : p.minor.size) > 1 ? getScale("colour", p.flip ? p.major.header : p.minor.header, p.flip ? p.major.size : p.minor.size) : QString()) + QString(" + scale_y_log10(labels=percent) + annotation_logticks(sides=\"l\")") + ((p.flip ? p.minor.size : p.major.size) > 1 ? QString(" + facet_wrap(~ %1, scales=\"free_x\")").arg(p.flip ? p.minor.header : p.major.header) : QString()) + QString(" + theme(aspect.ratio=1)\n\n"))); -- libgit2 0.21.4