Commit 07e6c737ed3827b9627b840b9541a28ec47c8c60

Authored by Josh Klontz
1 parent cc1536f1

first draft of evalDetection complete

app/br/br.cpp
@@ -135,8 +135,8 @@ public: @@ -135,8 +135,8 @@ public:
135 check(parc == 2, "Incorrect parameter count for 'evalClustering'."); 135 check(parc == 2, "Incorrect parameter count for 'evalClustering'.");
136 br_eval_clustering(parv[0], parv[1]); 136 br_eval_clustering(parv[0], parv[1]);
137 } else if (!strcmp(fun, "evalDetection")) { 137 } else if (!strcmp(fun, "evalDetection")) {
138 - check(parc == 2, "Incorrect parameter count for 'evalDetection'.");  
139 - br_eval_detection(parv[0], parv[1]); 138 + check((parc >= 2) && (parc <= 3), "Incorrect parameter count for 'evalDetection'.");
  139 + br_eval_detection(parv[0], parv[1], parc == 3 ? parv[2] : "");
140 } else if (!strcmp(fun, "evalRegression")) { 140 } else if (!strcmp(fun, "evalRegression")) {
141 check(parc == 2, "Incorrect parameter count for 'evalRegression'."); 141 check(parc == 2, "Incorrect parameter count for 'evalRegression'.");
142 br_eval_regression(parv[0], parv[1]); 142 br_eval_regression(parv[0], parv[1]);
@@ -214,7 +214,7 @@ private: @@ -214,7 +214,7 @@ private:
214 "-convert (Format|Gallery|Output) <input_file> {output_file}\n" 214 "-convert (Format|Gallery|Output) <input_file> {output_file}\n"
215 "-evalClassification <predicted_gallery> <truth_gallery>\n" 215 "-evalClassification <predicted_gallery> <truth_gallery>\n"
216 "-evalClustering <clusters> <gallery>\n" 216 "-evalClustering <clusters> <gallery>\n"
217 - "-evalDetection <predicted_gallery> <truth_gallery>\n" 217 + "-evalDetection <predicted_gallery> <truth_gallery> [{csv}]\n"
218 "-evalRegression <predicted_gallery> <truth_gallery>\n" 218 "-evalRegression <predicted_gallery> <truth_gallery>\n"
219 "-plotMetadata <file> ... <file> <columns>\n" 219 "-plotMetadata <file> ... <file> <columns>\n"
220 "-getHeader <matrix>\n" 220 "-getHeader <matrix>\n"
openbr/core/eval.cpp
@@ -23,6 +23,8 @@ using namespace cv; @@ -23,6 +23,8 @@ using namespace cv;
23 namespace br 23 namespace br
24 { 24 {
25 25
  26 +static const int Max_Points = 500; // Maximum number of points to render on plots
  27 +
26 struct Comparison 28 struct Comparison
27 { 29 {
28 float score; 30 float score;
@@ -100,7 +102,6 @@ float Evaluate(const Mat &amp;simmat, const Mat &amp;mask, const QString &amp;csv) @@ -100,7 +102,6 @@ float Evaluate(const Mat &amp;simmat, const Mat &amp;mask, const QString &amp;csv)
100 qFatal("Similarity matrix (%ix%i) differs in size from mask matrix (%ix%i).", 102 qFatal("Similarity matrix (%ix%i) differs in size from mask matrix (%ix%i).",
101 simmat.rows, simmat.cols, mask.rows, mask.cols); 103 simmat.rows, simmat.cols, mask.rows, mask.cols);
102 104
103 - const int Max_Points = 500;  
104 float result = -1; 105 float result = -1;
105 106
106 // Make comparisons 107 // Make comparisons
@@ -237,7 +238,7 @@ float Evaluate(const Mat &amp;simmat, const Mat &amp;mask, const QString &amp;csv) @@ -237,7 +238,7 @@ float Evaluate(const Mat &amp;simmat, const Mat &amp;mask, const QString &amp;csv)
237 if (i == Report_Retrieval) reportRetrievalRate = retrievalRate; 238 if (i == Report_Retrieval) reportRetrievalRate = retrievalRate;
238 } 239 }
239 240
240 - if (!csv.isEmpty()) QtUtils::writeFile(csv, lines); 241 + QtUtils::writeFile(csv, lines);
241 qDebug("TAR @ FAR = 0.01: %.3f\nRetrieval Rate @ Rank = %d: %.3f", result, Report_Retrieval, reportRetrievalRate); 242 qDebug("TAR @ FAR = 0.01: %.3f\nRetrieval Rate @ Rank = %d: %.3f", result, Report_Retrieval, reportRetrievalRate);
242 return result; 243 return result;
243 } 244 }
@@ -325,7 +326,7 @@ struct Detection @@ -325,7 +326,7 @@ struct Detection
325 float overlap(const Detection &other) const 326 float overlap(const Detection &other) const
326 { 327 {
327 const Detection intersection(boundingBox.intersected(other.boundingBox)); 328 const Detection intersection(boundingBox.intersected(other.boundingBox));
328 - return intersection.area() / (area() + other.area() - 2*intersection.area()); 329 + return intersection.area() / (area() + other.area() - intersection.area());
329 } 330 }
330 }; 331 };
331 332
@@ -334,14 +335,53 @@ struct Detections @@ -334,14 +335,53 @@ struct Detections
334 QList<Detection> predicted, truth; 335 QList<Detection> predicted, truth;
335 }; 336 };
336 337
337 -struct DetectionOperatingPoint 338 +struct ResolvedDetection
338 { 339 {
339 float confidence, overlap; 340 float confidence, overlap;
340 - DetectionOperatingPoint() : confidence(-1), overlap(-1) {}  
341 - DetectionOperatingPoint(float confidence_, float overlap_) : confidence(confidence_), overlap(overlap_) {}  
342 - inline bool operator<(const DetectionOperatingPoint &other) const { return confidence > other.confidence; } 341 + ResolvedDetection() : confidence(-1), overlap(-1) {}
  342 + ResolvedDetection(float confidence_, float overlap_) : confidence(confidence_), overlap(overlap_) {}
  343 + inline bool operator<(const ResolvedDetection &other) const { return confidence > other.confidence; }
  344 +};
  345 +
  346 +struct DetectionOperatingPoint
  347 +{
  348 + float Recall, FalsePositives, Precision;
  349 + DetectionOperatingPoint() : Recall(-1), FalsePositives(-1), Precision(-1) {}
  350 + DetectionOperatingPoint(float TP, float FP, float totalPositives)
  351 + : Recall(TP/totalPositives), FalsePositives(FP), Precision(TP/(TP+FP)) {}
343 }; 352 };
344 353
  354 +static QStringList computeDetectionResults(const QList<ResolvedDetection> &detections, int totalPositives, bool discrete)
  355 +{
  356 + QList<DetectionOperatingPoint> points;
  357 + float TP = 0, FP = 0, prevFP = 0;
  358 + for (int i=0; i<detections.size(); i++) {
  359 + const ResolvedDetection &detection = detections[i];
  360 + if (discrete) {
  361 + if (detection.overlap >= 0.5) TP++;
  362 + else FP++;
  363 + } else {
  364 + TP += detection.overlap;
  365 + FP += 1 - detection.overlap;
  366 + }
  367 + if ((i == detections.size()-1) || (detection.confidence > detections[i+1].confidence)) {
  368 + if (FP > prevFP) {
  369 + points.append(DetectionOperatingPoint(TP, FP, totalPositives));
  370 + prevFP = FP;
  371 + }
  372 + }
  373 + }
  374 +
  375 + const int keep = qMin(points.size(), Max_Points);
  376 + QStringList lines; lines.reserve(keep);
  377 + for (int i=0; i<keep; i++) {
  378 + const DetectionOperatingPoint &point = points[double(i) / double(keep-1) * double(points.size()-1)];
  379 + lines.append(QString("%1ROC, %2, %3").arg(discrete ? "Discrete" : "Continuous", QString::number(point.FalsePositives), QString::number(point.Recall)));
  380 + lines.append(QString("%1PR, %2, %3").arg(discrete ? "Discrete" : "Continuous", QString::number(point.Precision), QString::number(point.Recall)));
  381 + }
  382 + return lines;
  383 +}
  384 +
345 float EvalDetection(const QString &predictedInput, const QString &truthInput, const QString &csv) 385 float EvalDetection(const QString &predictedInput, const QString &truthInput, const QString &csv)
346 { 386 {
347 qDebug("Evaluating detection of %s against %s", qPrintable(predictedInput), qPrintable(truthInput)); 387 qDebug("Evaluating detection of %s against %s", qPrintable(predictedInput), qPrintable(truthInput));
@@ -358,18 +398,18 @@ float EvalDetection(const QString &amp;predictedInput, const QString &amp;truthInput, co @@ -358,18 +398,18 @@ float EvalDetection(const QString &amp;predictedInput, const QString &amp;truthInput, co
358 if (detectKey.isNull()) qFatal("No suitable metadata key found."); 398 if (detectKey.isNull()) qFatal("No suitable metadata key found.");
359 else qDebug("Using metadata key: %s", qPrintable(detectKey)); 399 else qDebug("Using metadata key: %s", qPrintable(detectKey));
360 400
361 - QHash<QString, Detections> allDetections; // Organized by file 401 + QMap<QString, Detections> allDetections; // Organized by file, QMap used to preserve order
362 foreach (const Template &t, predicted) 402 foreach (const Template &t, predicted)
363 allDetections[t.file.baseName()].predicted.append(Detection(t.file.get<QRectF>(detectKey), t.file.get<float>("Confidence", -1))); 403 allDetections[t.file.baseName()].predicted.append(Detection(t.file.get<QRectF>(detectKey), t.file.get<float>("Confidence", -1)));
364 foreach (const Template &t, truth) 404 foreach (const Template &t, truth)
365 allDetections[t.file.baseName()].truth.append(Detection(t.file.get<QRectF>(detectKey))); 405 allDetections[t.file.baseName()].truth.append(Detection(t.file.get<QRectF>(detectKey)));
366 406
367 - QList<DetectionOperatingPoint> points; 407 + QList<ResolvedDetection> resolvedDetections, falseNegativeDetections;
368 foreach (Detections detections, allDetections.values()) { 408 foreach (Detections detections, allDetections.values()) {
369 while (!detections.truth.isEmpty() && !detections.predicted.isEmpty()) { 409 while (!detections.truth.isEmpty() && !detections.predicted.isEmpty()) {
370 - Detection truth = detections.truth.takeFirst(); 410 + const Detection truth = detections.truth.takeFirst();
371 int bestIndex = -1; 411 int bestIndex = -1;
372 - float bestOverlap = -1; 412 + float bestOverlap = -std::numeric_limits<float>::max();
373 for (int i=0; i<detections.predicted.size(); i++) { 413 for (int i=0; i<detections.predicted.size(); i++) {
374 const float overlap = truth.overlap(detections.predicted[i]); 414 const float overlap = truth.overlap(detections.predicted[i]);
375 if (overlap > bestOverlap) { 415 if (overlap > bestOverlap) {
@@ -377,25 +417,40 @@ float EvalDetection(const QString &amp;predictedInput, const QString &amp;truthInput, co @@ -377,25 +417,40 @@ float EvalDetection(const QString &amp;predictedInput, const QString &amp;truthInput, co
377 bestIndex = i; 417 bestIndex = i;
378 } 418 }
379 } 419 }
380 - Detection predicted = detections.predicted.takeAt(bestIndex);  
381 - points.append(DetectionOperatingPoint(predicted.confidence, bestOverlap)); 420 + const Detection predicted = detections.predicted.takeAt(bestIndex);
  421 + resolvedDetections.append(ResolvedDetection(predicted.confidence, bestOverlap));
382 } 422 }
383 423
384 foreach (const Detection &detection, detections.predicted) 424 foreach (const Detection &detection, detections.predicted)
385 - points.append(DetectionOperatingPoint(detection.confidence, 0)); 425 + resolvedDetections.append(ResolvedDetection(detection.confidence, 0));
386 for (int i=0; i<detections.truth.size(); i++) 426 for (int i=0; i<detections.truth.size(); i++)
387 - points.append(DetectionOperatingPoint(-std::numeric_limits<float>::max(), 0)); 427 + falseNegativeDetections.append(ResolvedDetection(-std::numeric_limits<float>::max(), 0));
388 } 428 }
389 429
390 - std::sort(points.begin(), points.end()); 430 + std::sort(resolvedDetections.begin(), resolvedDetections.end());
391 431
392 QStringList lines; 432 QStringList lines;
393 lines.append("Plot, X, Y"); 433 lines.append("Plot, X, Y");
  434 + lines.append(computeDetectionResults(resolvedDetections, truth.size(), true));
  435 + lines.append(computeDetectionResults(resolvedDetections, truth.size(), false));
  436 +
  437 + float averageOverlap;
  438 + { // Overlap Density
  439 + QList<ResolvedDetection> allDetections; allDetections << resolvedDetections << falseNegativeDetections;
  440 + const int keep = qMin(allDetections.size(), Max_Points);
  441 + lines.reserve(lines.size() + keep);
  442 + float totalOverlap = 0;
  443 + for (int i=0; i<keep; i++) {
  444 + const float overlap = allDetections[double(i) / double(keep-1) * double(allDetections.size()-1)].overlap;
  445 + totalOverlap += overlap;
  446 + lines.append(QString("Overlap,%1,1").arg(QString::number(allDetections[double(i) / double(keep-1) * double(allDetections.size()-1)].overlap)));
  447 + }
  448 + averageOverlap = totalOverlap / keep;
  449 + }
394 450
395 - // TODO: finish implementing  
396 -  
397 - (void) csv;  
398 - return 0; 451 + QtUtils::writeFile(csv, lines);
  452 + qDebug("Average Overlap = %.3f", averageOverlap);
  453 + return averageOverlap;
399 } 454 }
400 455
401 void EvalRegression(const QString &predictedInput, const QString &truthInput) 456 void EvalRegression(const QString &predictedInput, const QString &truthInput)
openbr/core/qtutils.cpp
@@ -105,6 +105,7 @@ void QtUtils::writeFile(const QString &amp;file, const QString &amp;data) @@ -105,6 +105,7 @@ void QtUtils::writeFile(const QString &amp;file, const QString &amp;data)
105 105
106 void QtUtils::writeFile(const QString &file, const QByteArray &data, int compression) 106 void QtUtils::writeFile(const QString &file, const QByteArray &data, int compression)
107 { 107 {
  108 + if (file.isEmpty()) return;
108 const QString baseName = QFileInfo(file).baseName(); 109 const QString baseName = QFileInfo(file).baseName();
109 const QByteArray contents = (compression == 0) ? data : qCompress(data, compression); 110 const QByteArray contents = (compression == 0) ? data : qCompress(data, compression);
110 if (baseName == "terminal") { 111 if (baseName == "terminal") {
openbr/openbr.cpp
@@ -82,9 +82,9 @@ void br_eval_clustering(const char *csv, const char *gallery) @@ -82,9 +82,9 @@ void br_eval_clustering(const char *csv, const char *gallery)
82 EvalClustering(csv, gallery); 82 EvalClustering(csv, gallery);
83 } 83 }
84 84
85 -void br_eval_detection(const char *predicted_gallery, const char *truth_gallery) 85 +float br_eval_detection(const char *predicted_gallery, const char *truth_gallery, const char *csv)
86 { 86 {
87 - EvalDetection(predicted_gallery, truth_gallery); 87 + return EvalDetection(predicted_gallery, truth_gallery, csv);
88 } 88 }
89 89
90 void br_eval_regression(const char *predicted_gallery, const char *truth_gallery) 90 void br_eval_regression(const char *predicted_gallery, const char *truth_gallery)
openbr/openbr.h
@@ -148,7 +148,6 @@ BR_EXPORT float br_eval(const char *simmat, const char *mask, const char *csv = @@ -148,7 +148,6 @@ BR_EXPORT float br_eval(const char *simmat, const char *mask, const char *csv =
148 * \brief Evaluates and prints classification accuracy to terminal. 148 * \brief Evaluates and prints classification accuracy to terminal.
149 * \param predicted_input The predicted br::Gallery. 149 * \param predicted_input The predicted br::Gallery.
150 * \param truth_input The ground truth br::Gallery. 150 * \param truth_input The ground truth br::Gallery.
151 - * \see br_enroll  
152 */ 151 */
153 BR_EXPORT void br_eval_classification(const char *predicted_gallery, const char *truth_gallery); 152 BR_EXPORT void br_eval_classification(const char *predicted_gallery, const char *truth_gallery);
154 153
@@ -164,15 +163,15 @@ BR_EXPORT void br_eval_clustering(const char *csv, const char *gallery); @@ -164,15 +163,15 @@ BR_EXPORT void br_eval_clustering(const char *csv, const char *gallery);
164 * \brief Evaluates and prints detection accuracy to terminal. 163 * \brief Evaluates and prints detection accuracy to terminal.
165 * \param predicted_gallery The predicted br::Gallery. 164 * \param predicted_gallery The predicted br::Gallery.
166 * \param truth_galery The ground truth br::Gallery. 165 * \param truth_galery The ground truth br::Gallery.
167 - * \see br_enroll 166 + * \param csv Optional \c .csv file to contain performance metrics.
  167 + * \return Average detection bounding box overlap.
168 */ 168 */
169 -BR_EXPORT void br_eval_detection(const char *predicted_gallery, const char *truth_gallery); 169 +BR_EXPORT float br_eval_detection(const char *predicted_gallery, const char *truth_gallery, const char *csv = "");
170 170
171 /*! 171 /*!
172 * \brief Evaluates regression accuracy to disk. 172 * \brief Evaluates regression accuracy to disk.
173 * \param predicted_input The predicted br::Gallery. 173 * \param predicted_input The predicted br::Gallery.
174 * \param truth_input The ground truth br::Gallery. 174 * \param truth_input The ground truth br::Gallery.
175 - * \see br_enroll  
176 */ 175 */
177 BR_EXPORT void br_eval_regression(const char *predicted_gallery, const char *truth_gallery); 176 BR_EXPORT void br_eval_regression(const char *predicted_gallery, const char *truth_gallery);
178 177