diff --git a/app/br/br.cpp b/app/br/br.cpp index 7528393..efd0483 100644 --- a/app/br/br.cpp +++ b/app/br/br.cpp @@ -135,8 +135,8 @@ public: check(parc == 2, "Incorrect parameter count for 'evalClustering'."); br_eval_clustering(parv[0], parv[1]); } else if (!strcmp(fun, "evalDetection")) { - check(parc == 2, "Incorrect parameter count for 'evalDetection'."); - br_eval_detection(parv[0], parv[1]); + check((parc >= 2) && (parc <= 3), "Incorrect parameter count for 'evalDetection'."); + br_eval_detection(parv[0], parv[1], parc == 3 ? parv[2] : ""); } else if (!strcmp(fun, "evalRegression")) { check(parc == 2, "Incorrect parameter count for 'evalRegression'."); br_eval_regression(parv[0], parv[1]); @@ -214,7 +214,7 @@ private: "-convert (Format|Gallery|Output) {output_file}\n" "-evalClassification \n" "-evalClustering \n" - "-evalDetection \n" + "-evalDetection [{csv}]\n" "-evalRegression \n" "-plotMetadata ... \n" "-getHeader \n" diff --git a/openbr/core/eval.cpp b/openbr/core/eval.cpp index 2f13e66..8f3e419 100644 --- a/openbr/core/eval.cpp +++ b/openbr/core/eval.cpp @@ -23,6 +23,8 @@ using namespace cv; namespace br { +static const int Max_Points = 500; // Maximum number of points to render on plots + struct Comparison { float score; @@ -100,7 +102,6 @@ float Evaluate(const Mat &simmat, const Mat &mask, const QString &csv) qFatal("Similarity matrix (%ix%i) differs in size from mask matrix (%ix%i).", simmat.rows, simmat.cols, mask.rows, mask.cols); - const int Max_Points = 500; float result = -1; // Make comparisons @@ -237,7 +238,7 @@ float Evaluate(const Mat &simmat, const Mat &mask, const QString &csv) if (i == Report_Retrieval) reportRetrievalRate = retrievalRate; } - if (!csv.isEmpty()) QtUtils::writeFile(csv, lines); + QtUtils::writeFile(csv, lines); qDebug("TAR @ FAR = 0.01: %.3f\nRetrieval Rate @ Rank = %d: %.3f", result, Report_Retrieval, reportRetrievalRate); return result; } @@ -325,7 +326,7 @@ struct Detection float overlap(const Detection &other) const { const Detection intersection(boundingBox.intersected(other.boundingBox)); - return intersection.area() / (area() + other.area() - 2*intersection.area()); + return intersection.area() / (area() + other.area() - intersection.area()); } }; @@ -334,14 +335,53 @@ struct Detections QList predicted, truth; }; -struct DetectionOperatingPoint +struct ResolvedDetection { float confidence, overlap; - DetectionOperatingPoint() : confidence(-1), overlap(-1) {} - DetectionOperatingPoint(float confidence_, float overlap_) : confidence(confidence_), overlap(overlap_) {} - inline bool operator<(const DetectionOperatingPoint &other) const { return confidence > other.confidence; } + ResolvedDetection() : confidence(-1), overlap(-1) {} + ResolvedDetection(float confidence_, float overlap_) : confidence(confidence_), overlap(overlap_) {} + inline bool operator<(const ResolvedDetection &other) const { return confidence > other.confidence; } +}; + +struct DetectionOperatingPoint +{ + float Recall, FalsePositives, Precision; + DetectionOperatingPoint() : Recall(-1), FalsePositives(-1), Precision(-1) {} + DetectionOperatingPoint(float TP, float FP, float totalPositives) + : Recall(TP/totalPositives), FalsePositives(FP), Precision(TP/(TP+FP)) {} }; +static QStringList computeDetectionResults(const QList &detections, int totalPositives, bool discrete) +{ + QList points; + float TP = 0, FP = 0, prevFP = 0; + for (int i=0; i= 0.5) TP++; + else FP++; + } else { + TP += detection.overlap; + FP += 1 - detection.overlap; + } + if ((i == detections.size()-1) || (detection.confidence > detections[i+1].confidence)) { + if (FP > prevFP) { + points.append(DetectionOperatingPoint(TP, FP, totalPositives)); + prevFP = FP; + } + } + } + + const int keep = qMin(points.size(), Max_Points); + QStringList lines; lines.reserve(keep); + for (int i=0; i allDetections; // Organized by file + QMap allDetections; // Organized by file, QMap used to preserve order foreach (const Template &t, predicted) allDetections[t.file.baseName()].predicted.append(Detection(t.file.get(detectKey), t.file.get("Confidence", -1))); foreach (const Template &t, truth) allDetections[t.file.baseName()].truth.append(Detection(t.file.get(detectKey))); - QList points; + QList resolvedDetections, falseNegativeDetections; foreach (Detections detections, allDetections.values()) { while (!detections.truth.isEmpty() && !detections.predicted.isEmpty()) { - Detection truth = detections.truth.takeFirst(); + const Detection truth = detections.truth.takeFirst(); int bestIndex = -1; - float bestOverlap = -1; + float bestOverlap = -std::numeric_limits::max(); for (int i=0; i bestOverlap) { @@ -377,25 +417,40 @@ float EvalDetection(const QString &predictedInput, const QString &truthInput, co bestIndex = i; } } - Detection predicted = detections.predicted.takeAt(bestIndex); - points.append(DetectionOperatingPoint(predicted.confidence, bestOverlap)); + const Detection predicted = detections.predicted.takeAt(bestIndex); + resolvedDetections.append(ResolvedDetection(predicted.confidence, bestOverlap)); } foreach (const Detection &detection, detections.predicted) - points.append(DetectionOperatingPoint(detection.confidence, 0)); + resolvedDetections.append(ResolvedDetection(detection.confidence, 0)); for (int i=0; i::max(), 0)); + falseNegativeDetections.append(ResolvedDetection(-std::numeric_limits::max(), 0)); } - std::sort(points.begin(), points.end()); + std::sort(resolvedDetections.begin(), resolvedDetections.end()); QStringList lines; lines.append("Plot, X, Y"); + lines.append(computeDetectionResults(resolvedDetections, truth.size(), true)); + lines.append(computeDetectionResults(resolvedDetections, truth.size(), false)); + + float averageOverlap; + { // Overlap Density + QList allDetections; allDetections << resolvedDetections << falseNegativeDetections; + const int keep = qMin(allDetections.size(), Max_Points); + lines.reserve(lines.size() + keep); + float totalOverlap = 0; + for (int i=0; i