Commit 07e6c737ed3827b9627b840b9541a28ec47c8c60

Authored by Josh Klontz
1 parent cc1536f1

first draft of evalDetection complete

app/br/br.cpp
... ... @@ -135,8 +135,8 @@ public:
135 135 check(parc == 2, "Incorrect parameter count for 'evalClustering'.");
136 136 br_eval_clustering(parv[0], parv[1]);
137 137 } else if (!strcmp(fun, "evalDetection")) {
138   - check(parc == 2, "Incorrect parameter count for 'evalDetection'.");
139   - br_eval_detection(parv[0], parv[1]);
  138 + check((parc >= 2) && (parc <= 3), "Incorrect parameter count for 'evalDetection'.");
  139 + br_eval_detection(parv[0], parv[1], parc == 3 ? parv[2] : "");
140 140 } else if (!strcmp(fun, "evalRegression")) {
141 141 check(parc == 2, "Incorrect parameter count for 'evalRegression'.");
142 142 br_eval_regression(parv[0], parv[1]);
... ... @@ -214,7 +214,7 @@ private:
214 214 "-convert (Format|Gallery|Output) <input_file> {output_file}\n"
215 215 "-evalClassification <predicted_gallery> <truth_gallery>\n"
216 216 "-evalClustering <clusters> <gallery>\n"
217   - "-evalDetection <predicted_gallery> <truth_gallery>\n"
  217 + "-evalDetection <predicted_gallery> <truth_gallery> [{csv}]\n"
218 218 "-evalRegression <predicted_gallery> <truth_gallery>\n"
219 219 "-plotMetadata <file> ... <file> <columns>\n"
220 220 "-getHeader <matrix>\n"
... ...
openbr/core/eval.cpp
... ... @@ -23,6 +23,8 @@ using namespace cv;
23 23 namespace br
24 24 {
25 25  
  26 +static const int Max_Points = 500; // Maximum number of points to render on plots
  27 +
26 28 struct Comparison
27 29 {
28 30 float score;
... ... @@ -100,7 +102,6 @@ float Evaluate(const Mat &amp;simmat, const Mat &amp;mask, const QString &amp;csv)
100 102 qFatal("Similarity matrix (%ix%i) differs in size from mask matrix (%ix%i).",
101 103 simmat.rows, simmat.cols, mask.rows, mask.cols);
102 104  
103   - const int Max_Points = 500;
104 105 float result = -1;
105 106  
106 107 // Make comparisons
... ... @@ -237,7 +238,7 @@ float Evaluate(const Mat &amp;simmat, const Mat &amp;mask, const QString &amp;csv)
237 238 if (i == Report_Retrieval) reportRetrievalRate = retrievalRate;
238 239 }
239 240  
240   - if (!csv.isEmpty()) QtUtils::writeFile(csv, lines);
  241 + QtUtils::writeFile(csv, lines);
241 242 qDebug("TAR @ FAR = 0.01: %.3f\nRetrieval Rate @ Rank = %d: %.3f", result, Report_Retrieval, reportRetrievalRate);
242 243 return result;
243 244 }
... ... @@ -325,7 +326,7 @@ struct Detection
325 326 float overlap(const Detection &other) const
326 327 {
327 328 const Detection intersection(boundingBox.intersected(other.boundingBox));
328   - return intersection.area() / (area() + other.area() - 2*intersection.area());
  329 + return intersection.area() / (area() + other.area() - intersection.area());
329 330 }
330 331 };
331 332  
... ... @@ -334,14 +335,53 @@ struct Detections
334 335 QList<Detection> predicted, truth;
335 336 };
336 337  
337   -struct DetectionOperatingPoint
  338 +struct ResolvedDetection
338 339 {
339 340 float confidence, overlap;
340   - DetectionOperatingPoint() : confidence(-1), overlap(-1) {}
341   - DetectionOperatingPoint(float confidence_, float overlap_) : confidence(confidence_), overlap(overlap_) {}
342   - inline bool operator<(const DetectionOperatingPoint &other) const { return confidence > other.confidence; }
  341 + ResolvedDetection() : confidence(-1), overlap(-1) {}
  342 + ResolvedDetection(float confidence_, float overlap_) : confidence(confidence_), overlap(overlap_) {}
  343 + inline bool operator<(const ResolvedDetection &other) const { return confidence > other.confidence; }
  344 +};
  345 +
  346 +struct DetectionOperatingPoint
  347 +{
  348 + float Recall, FalsePositives, Precision;
  349 + DetectionOperatingPoint() : Recall(-1), FalsePositives(-1), Precision(-1) {}
  350 + DetectionOperatingPoint(float TP, float FP, float totalPositives)
  351 + : Recall(TP/totalPositives), FalsePositives(FP), Precision(TP/(TP+FP)) {}
343 352 };
344 353  
  354 +static QStringList computeDetectionResults(const QList<ResolvedDetection> &detections, int totalPositives, bool discrete)
  355 +{
  356 + QList<DetectionOperatingPoint> points;
  357 + float TP = 0, FP = 0, prevFP = 0;
  358 + for (int i=0; i<detections.size(); i++) {
  359 + const ResolvedDetection &detection = detections[i];
  360 + if (discrete) {
  361 + if (detection.overlap >= 0.5) TP++;
  362 + else FP++;
  363 + } else {
  364 + TP += detection.overlap;
  365 + FP += 1 - detection.overlap;
  366 + }
  367 + if ((i == detections.size()-1) || (detection.confidence > detections[i+1].confidence)) {
  368 + if (FP > prevFP) {
  369 + points.append(DetectionOperatingPoint(TP, FP, totalPositives));
  370 + prevFP = FP;
  371 + }
  372 + }
  373 + }
  374 +
  375 + const int keep = qMin(points.size(), Max_Points);
  376 + QStringList lines; lines.reserve(keep);
  377 + for (int i=0; i<keep; i++) {
  378 + const DetectionOperatingPoint &point = points[double(i) / double(keep-1) * double(points.size()-1)];
  379 + lines.append(QString("%1ROC, %2, %3").arg(discrete ? "Discrete" : "Continuous", QString::number(point.FalsePositives), QString::number(point.Recall)));
  380 + lines.append(QString("%1PR, %2, %3").arg(discrete ? "Discrete" : "Continuous", QString::number(point.Precision), QString::number(point.Recall)));
  381 + }
  382 + return lines;
  383 +}
  384 +
345 385 float EvalDetection(const QString &predictedInput, const QString &truthInput, const QString &csv)
346 386 {
347 387 qDebug("Evaluating detection of %s against %s", qPrintable(predictedInput), qPrintable(truthInput));
... ... @@ -358,18 +398,18 @@ float EvalDetection(const QString &amp;predictedInput, const QString &amp;truthInput, co
358 398 if (detectKey.isNull()) qFatal("No suitable metadata key found.");
359 399 else qDebug("Using metadata key: %s", qPrintable(detectKey));
360 400  
361   - QHash<QString, Detections> allDetections; // Organized by file
  401 + QMap<QString, Detections> allDetections; // Organized by file, QMap used to preserve order
362 402 foreach (const Template &t, predicted)
363 403 allDetections[t.file.baseName()].predicted.append(Detection(t.file.get<QRectF>(detectKey), t.file.get<float>("Confidence", -1)));
364 404 foreach (const Template &t, truth)
365 405 allDetections[t.file.baseName()].truth.append(Detection(t.file.get<QRectF>(detectKey)));
366 406  
367   - QList<DetectionOperatingPoint> points;
  407 + QList<ResolvedDetection> resolvedDetections, falseNegativeDetections;
368 408 foreach (Detections detections, allDetections.values()) {
369 409 while (!detections.truth.isEmpty() && !detections.predicted.isEmpty()) {
370   - Detection truth = detections.truth.takeFirst();
  410 + const Detection truth = detections.truth.takeFirst();
371 411 int bestIndex = -1;
372   - float bestOverlap = -1;
  412 + float bestOverlap = -std::numeric_limits<float>::max();
373 413 for (int i=0; i<detections.predicted.size(); i++) {
374 414 const float overlap = truth.overlap(detections.predicted[i]);
375 415 if (overlap > bestOverlap) {
... ... @@ -377,25 +417,40 @@ float EvalDetection(const QString &amp;predictedInput, const QString &amp;truthInput, co
377 417 bestIndex = i;
378 418 }
379 419 }
380   - Detection predicted = detections.predicted.takeAt(bestIndex);
381   - points.append(DetectionOperatingPoint(predicted.confidence, bestOverlap));
  420 + const Detection predicted = detections.predicted.takeAt(bestIndex);
  421 + resolvedDetections.append(ResolvedDetection(predicted.confidence, bestOverlap));
382 422 }
383 423  
384 424 foreach (const Detection &detection, detections.predicted)
385   - points.append(DetectionOperatingPoint(detection.confidence, 0));
  425 + resolvedDetections.append(ResolvedDetection(detection.confidence, 0));
386 426 for (int i=0; i<detections.truth.size(); i++)
387   - points.append(DetectionOperatingPoint(-std::numeric_limits<float>::max(), 0));
  427 + falseNegativeDetections.append(ResolvedDetection(-std::numeric_limits<float>::max(), 0));
388 428 }
389 429  
390   - std::sort(points.begin(), points.end());
  430 + std::sort(resolvedDetections.begin(), resolvedDetections.end());
391 431  
392 432 QStringList lines;
393 433 lines.append("Plot, X, Y");
  434 + lines.append(computeDetectionResults(resolvedDetections, truth.size(), true));
  435 + lines.append(computeDetectionResults(resolvedDetections, truth.size(), false));
  436 +
  437 + float averageOverlap;
  438 + { // Overlap Density
  439 + QList<ResolvedDetection> allDetections; allDetections << resolvedDetections << falseNegativeDetections;
  440 + const int keep = qMin(allDetections.size(), Max_Points);
  441 + lines.reserve(lines.size() + keep);
  442 + float totalOverlap = 0;
  443 + for (int i=0; i<keep; i++) {
  444 + const float overlap = allDetections[double(i) / double(keep-1) * double(allDetections.size()-1)].overlap;
  445 + totalOverlap += overlap;
  446 + lines.append(QString("Overlap,%1,1").arg(QString::number(allDetections[double(i) / double(keep-1) * double(allDetections.size()-1)].overlap)));
  447 + }
  448 + averageOverlap = totalOverlap / keep;
  449 + }
394 450  
395   - // TODO: finish implementing
396   -
397   - (void) csv;
398   - return 0;
  451 + QtUtils::writeFile(csv, lines);
  452 + qDebug("Average Overlap = %.3f", averageOverlap);
  453 + return averageOverlap;
399 454 }
400 455  
401 456 void EvalRegression(const QString &predictedInput, const QString &truthInput)
... ...
openbr/core/qtutils.cpp
... ... @@ -105,6 +105,7 @@ void QtUtils::writeFile(const QString &amp;file, const QString &amp;data)
105 105  
106 106 void QtUtils::writeFile(const QString &file, const QByteArray &data, int compression)
107 107 {
  108 + if (file.isEmpty()) return;
108 109 const QString baseName = QFileInfo(file).baseName();
109 110 const QByteArray contents = (compression == 0) ? data : qCompress(data, compression);
110 111 if (baseName == "terminal") {
... ...
openbr/openbr.cpp
... ... @@ -82,9 +82,9 @@ void br_eval_clustering(const char *csv, const char *gallery)
82 82 EvalClustering(csv, gallery);
83 83 }
84 84  
85   -void br_eval_detection(const char *predicted_gallery, const char *truth_gallery)
  85 +float br_eval_detection(const char *predicted_gallery, const char *truth_gallery, const char *csv)
86 86 {
87   - EvalDetection(predicted_gallery, truth_gallery);
  87 + return EvalDetection(predicted_gallery, truth_gallery, csv);
88 88 }
89 89  
90 90 void br_eval_regression(const char *predicted_gallery, const char *truth_gallery)
... ...
openbr/openbr.h
... ... @@ -148,7 +148,6 @@ BR_EXPORT float br_eval(const char *simmat, const char *mask, const char *csv =
148 148 * \brief Evaluates and prints classification accuracy to terminal.
149 149 * \param predicted_input The predicted br::Gallery.
150 150 * \param truth_input The ground truth br::Gallery.
151   - * \see br_enroll
152 151 */
153 152 BR_EXPORT void br_eval_classification(const char *predicted_gallery, const char *truth_gallery);
154 153  
... ... @@ -164,15 +163,15 @@ BR_EXPORT void br_eval_clustering(const char *csv, const char *gallery);
164 163 * \brief Evaluates and prints detection accuracy to terminal.
165 164 * \param predicted_gallery The predicted br::Gallery.
166 165 * \param truth_galery The ground truth br::Gallery.
167   - * \see br_enroll
  166 + * \param csv Optional \c .csv file to contain performance metrics.
  167 + * \return Average detection bounding box overlap.
168 168 */
169   -BR_EXPORT void br_eval_detection(const char *predicted_gallery, const char *truth_gallery);
  169 +BR_EXPORT float br_eval_detection(const char *predicted_gallery, const char *truth_gallery, const char *csv = "");
170 170  
171 171 /*!
172 172 * \brief Evaluates regression accuracy to disk.
173 173 * \param predicted_input The predicted br::Gallery.
174 174 * \param truth_input The ground truth br::Gallery.
175   - * \see br_enroll
176 175 */
177 176 BR_EXPORT void br_eval_regression(const char *predicted_gallery, const char *truth_gallery);
178 177  
... ...