Commit 07e6c737ed3827b9627b840b9541a28ec47c8c60
1 parent
cc1536f1
first draft of evalDetection complete
Showing
5 changed files
with
84 additions
and
29 deletions
app/br/br.cpp
| ... | ... | @@ -135,8 +135,8 @@ public: |
| 135 | 135 | check(parc == 2, "Incorrect parameter count for 'evalClustering'."); |
| 136 | 136 | br_eval_clustering(parv[0], parv[1]); |
| 137 | 137 | } else if (!strcmp(fun, "evalDetection")) { |
| 138 | - check(parc == 2, "Incorrect parameter count for 'evalDetection'."); | |
| 139 | - br_eval_detection(parv[0], parv[1]); | |
| 138 | + check((parc >= 2) && (parc <= 3), "Incorrect parameter count for 'evalDetection'."); | |
| 139 | + br_eval_detection(parv[0], parv[1], parc == 3 ? parv[2] : ""); | |
| 140 | 140 | } else if (!strcmp(fun, "evalRegression")) { |
| 141 | 141 | check(parc == 2, "Incorrect parameter count for 'evalRegression'."); |
| 142 | 142 | br_eval_regression(parv[0], parv[1]); |
| ... | ... | @@ -214,7 +214,7 @@ private: |
| 214 | 214 | "-convert (Format|Gallery|Output) <input_file> {output_file}\n" |
| 215 | 215 | "-evalClassification <predicted_gallery> <truth_gallery>\n" |
| 216 | 216 | "-evalClustering <clusters> <gallery>\n" |
| 217 | - "-evalDetection <predicted_gallery> <truth_gallery>\n" | |
| 217 | + "-evalDetection <predicted_gallery> <truth_gallery> [{csv}]\n" | |
| 218 | 218 | "-evalRegression <predicted_gallery> <truth_gallery>\n" |
| 219 | 219 | "-plotMetadata <file> ... <file> <columns>\n" |
| 220 | 220 | "-getHeader <matrix>\n" | ... | ... |
openbr/core/eval.cpp
| ... | ... | @@ -23,6 +23,8 @@ using namespace cv; |
| 23 | 23 | namespace br |
| 24 | 24 | { |
| 25 | 25 | |
| 26 | +static const int Max_Points = 500; // Maximum number of points to render on plots | |
| 27 | + | |
| 26 | 28 | struct Comparison |
| 27 | 29 | { |
| 28 | 30 | float score; |
| ... | ... | @@ -100,7 +102,6 @@ float Evaluate(const Mat &simmat, const Mat &mask, const QString &csv) |
| 100 | 102 | qFatal("Similarity matrix (%ix%i) differs in size from mask matrix (%ix%i).", |
| 101 | 103 | simmat.rows, simmat.cols, mask.rows, mask.cols); |
| 102 | 104 | |
| 103 | - const int Max_Points = 500; | |
| 104 | 105 | float result = -1; |
| 105 | 106 | |
| 106 | 107 | // Make comparisons |
| ... | ... | @@ -237,7 +238,7 @@ float Evaluate(const Mat &simmat, const Mat &mask, const QString &csv) |
| 237 | 238 | if (i == Report_Retrieval) reportRetrievalRate = retrievalRate; |
| 238 | 239 | } |
| 239 | 240 | |
| 240 | - if (!csv.isEmpty()) QtUtils::writeFile(csv, lines); | |
| 241 | + QtUtils::writeFile(csv, lines); | |
| 241 | 242 | qDebug("TAR @ FAR = 0.01: %.3f\nRetrieval Rate @ Rank = %d: %.3f", result, Report_Retrieval, reportRetrievalRate); |
| 242 | 243 | return result; |
| 243 | 244 | } |
| ... | ... | @@ -325,7 +326,7 @@ struct Detection |
| 325 | 326 | float overlap(const Detection &other) const |
| 326 | 327 | { |
| 327 | 328 | const Detection intersection(boundingBox.intersected(other.boundingBox)); |
| 328 | - return intersection.area() / (area() + other.area() - 2*intersection.area()); | |
| 329 | + return intersection.area() / (area() + other.area() - intersection.area()); | |
| 329 | 330 | } |
| 330 | 331 | }; |
| 331 | 332 | |
| ... | ... | @@ -334,14 +335,53 @@ struct Detections |
| 334 | 335 | QList<Detection> predicted, truth; |
| 335 | 336 | }; |
| 336 | 337 | |
| 337 | -struct DetectionOperatingPoint | |
| 338 | +struct ResolvedDetection | |
| 338 | 339 | { |
| 339 | 340 | float confidence, overlap; |
| 340 | - DetectionOperatingPoint() : confidence(-1), overlap(-1) {} | |
| 341 | - DetectionOperatingPoint(float confidence_, float overlap_) : confidence(confidence_), overlap(overlap_) {} | |
| 342 | - inline bool operator<(const DetectionOperatingPoint &other) const { return confidence > other.confidence; } | |
| 341 | + ResolvedDetection() : confidence(-1), overlap(-1) {} | |
| 342 | + ResolvedDetection(float confidence_, float overlap_) : confidence(confidence_), overlap(overlap_) {} | |
| 343 | + inline bool operator<(const ResolvedDetection &other) const { return confidence > other.confidence; } | |
| 344 | +}; | |
| 345 | + | |
| 346 | +struct DetectionOperatingPoint | |
| 347 | +{ | |
| 348 | + float Recall, FalsePositives, Precision; | |
| 349 | + DetectionOperatingPoint() : Recall(-1), FalsePositives(-1), Precision(-1) {} | |
| 350 | + DetectionOperatingPoint(float TP, float FP, float totalPositives) | |
| 351 | + : Recall(TP/totalPositives), FalsePositives(FP), Precision(TP/(TP+FP)) {} | |
| 343 | 352 | }; |
| 344 | 353 | |
| 354 | +static QStringList computeDetectionResults(const QList<ResolvedDetection> &detections, int totalPositives, bool discrete) | |
| 355 | +{ | |
| 356 | + QList<DetectionOperatingPoint> points; | |
| 357 | + float TP = 0, FP = 0, prevFP = 0; | |
| 358 | + for (int i=0; i<detections.size(); i++) { | |
| 359 | + const ResolvedDetection &detection = detections[i]; | |
| 360 | + if (discrete) { | |
| 361 | + if (detection.overlap >= 0.5) TP++; | |
| 362 | + else FP++; | |
| 363 | + } else { | |
| 364 | + TP += detection.overlap; | |
| 365 | + FP += 1 - detection.overlap; | |
| 366 | + } | |
| 367 | + if ((i == detections.size()-1) || (detection.confidence > detections[i+1].confidence)) { | |
| 368 | + if (FP > prevFP) { | |
| 369 | + points.append(DetectionOperatingPoint(TP, FP, totalPositives)); | |
| 370 | + prevFP = FP; | |
| 371 | + } | |
| 372 | + } | |
| 373 | + } | |
| 374 | + | |
| 375 | + const int keep = qMin(points.size(), Max_Points); | |
| 376 | + QStringList lines; lines.reserve(keep); | |
| 377 | + for (int i=0; i<keep; i++) { | |
| 378 | + const DetectionOperatingPoint &point = points[double(i) / double(keep-1) * double(points.size()-1)]; | |
| 379 | + lines.append(QString("%1ROC, %2, %3").arg(discrete ? "Discrete" : "Continuous", QString::number(point.FalsePositives), QString::number(point.Recall))); | |
| 380 | + lines.append(QString("%1PR, %2, %3").arg(discrete ? "Discrete" : "Continuous", QString::number(point.Precision), QString::number(point.Recall))); | |
| 381 | + } | |
| 382 | + return lines; | |
| 383 | +} | |
| 384 | + | |
| 345 | 385 | float EvalDetection(const QString &predictedInput, const QString &truthInput, const QString &csv) |
| 346 | 386 | { |
| 347 | 387 | qDebug("Evaluating detection of %s against %s", qPrintable(predictedInput), qPrintable(truthInput)); |
| ... | ... | @@ -358,18 +398,18 @@ float EvalDetection(const QString &predictedInput, const QString &truthInput, co |
| 358 | 398 | if (detectKey.isNull()) qFatal("No suitable metadata key found."); |
| 359 | 399 | else qDebug("Using metadata key: %s", qPrintable(detectKey)); |
| 360 | 400 | |
| 361 | - QHash<QString, Detections> allDetections; // Organized by file | |
| 401 | + QMap<QString, Detections> allDetections; // Organized by file, QMap used to preserve order | |
| 362 | 402 | foreach (const Template &t, predicted) |
| 363 | 403 | allDetections[t.file.baseName()].predicted.append(Detection(t.file.get<QRectF>(detectKey), t.file.get<float>("Confidence", -1))); |
| 364 | 404 | foreach (const Template &t, truth) |
| 365 | 405 | allDetections[t.file.baseName()].truth.append(Detection(t.file.get<QRectF>(detectKey))); |
| 366 | 406 | |
| 367 | - QList<DetectionOperatingPoint> points; | |
| 407 | + QList<ResolvedDetection> resolvedDetections, falseNegativeDetections; | |
| 368 | 408 | foreach (Detections detections, allDetections.values()) { |
| 369 | 409 | while (!detections.truth.isEmpty() && !detections.predicted.isEmpty()) { |
| 370 | - Detection truth = detections.truth.takeFirst(); | |
| 410 | + const Detection truth = detections.truth.takeFirst(); | |
| 371 | 411 | int bestIndex = -1; |
| 372 | - float bestOverlap = -1; | |
| 412 | + float bestOverlap = -std::numeric_limits<float>::max(); | |
| 373 | 413 | for (int i=0; i<detections.predicted.size(); i++) { |
| 374 | 414 | const float overlap = truth.overlap(detections.predicted[i]); |
| 375 | 415 | if (overlap > bestOverlap) { |
| ... | ... | @@ -377,25 +417,40 @@ float EvalDetection(const QString &predictedInput, const QString &truthInput, co |
| 377 | 417 | bestIndex = i; |
| 378 | 418 | } |
| 379 | 419 | } |
| 380 | - Detection predicted = detections.predicted.takeAt(bestIndex); | |
| 381 | - points.append(DetectionOperatingPoint(predicted.confidence, bestOverlap)); | |
| 420 | + const Detection predicted = detections.predicted.takeAt(bestIndex); | |
| 421 | + resolvedDetections.append(ResolvedDetection(predicted.confidence, bestOverlap)); | |
| 382 | 422 | } |
| 383 | 423 | |
| 384 | 424 | foreach (const Detection &detection, detections.predicted) |
| 385 | - points.append(DetectionOperatingPoint(detection.confidence, 0)); | |
| 425 | + resolvedDetections.append(ResolvedDetection(detection.confidence, 0)); | |
| 386 | 426 | for (int i=0; i<detections.truth.size(); i++) |
| 387 | - points.append(DetectionOperatingPoint(-std::numeric_limits<float>::max(), 0)); | |
| 427 | + falseNegativeDetections.append(ResolvedDetection(-std::numeric_limits<float>::max(), 0)); | |
| 388 | 428 | } |
| 389 | 429 | |
| 390 | - std::sort(points.begin(), points.end()); | |
| 430 | + std::sort(resolvedDetections.begin(), resolvedDetections.end()); | |
| 391 | 431 | |
| 392 | 432 | QStringList lines; |
| 393 | 433 | lines.append("Plot, X, Y"); |
| 434 | + lines.append(computeDetectionResults(resolvedDetections, truth.size(), true)); | |
| 435 | + lines.append(computeDetectionResults(resolvedDetections, truth.size(), false)); | |
| 436 | + | |
| 437 | + float averageOverlap; | |
| 438 | + { // Overlap Density | |
| 439 | + QList<ResolvedDetection> allDetections; allDetections << resolvedDetections << falseNegativeDetections; | |
| 440 | + const int keep = qMin(allDetections.size(), Max_Points); | |
| 441 | + lines.reserve(lines.size() + keep); | |
| 442 | + float totalOverlap = 0; | |
| 443 | + for (int i=0; i<keep; i++) { | |
| 444 | + const float overlap = allDetections[double(i) / double(keep-1) * double(allDetections.size()-1)].overlap; | |
| 445 | + totalOverlap += overlap; | |
| 446 | + lines.append(QString("Overlap,%1,1").arg(QString::number(allDetections[double(i) / double(keep-1) * double(allDetections.size()-1)].overlap))); | |
| 447 | + } | |
| 448 | + averageOverlap = totalOverlap / keep; | |
| 449 | + } | |
| 394 | 450 | |
| 395 | - // TODO: finish implementing | |
| 396 | - | |
| 397 | - (void) csv; | |
| 398 | - return 0; | |
| 451 | + QtUtils::writeFile(csv, lines); | |
| 452 | + qDebug("Average Overlap = %.3f", averageOverlap); | |
| 453 | + return averageOverlap; | |
| 399 | 454 | } |
| 400 | 455 | |
| 401 | 456 | void EvalRegression(const QString &predictedInput, const QString &truthInput) | ... | ... |
openbr/core/qtutils.cpp
| ... | ... | @@ -105,6 +105,7 @@ void QtUtils::writeFile(const QString &file, const QString &data) |
| 105 | 105 | |
| 106 | 106 | void QtUtils::writeFile(const QString &file, const QByteArray &data, int compression) |
| 107 | 107 | { |
| 108 | + if (file.isEmpty()) return; | |
| 108 | 109 | const QString baseName = QFileInfo(file).baseName(); |
| 109 | 110 | const QByteArray contents = (compression == 0) ? data : qCompress(data, compression); |
| 110 | 111 | if (baseName == "terminal") { | ... | ... |
openbr/openbr.cpp
| ... | ... | @@ -82,9 +82,9 @@ void br_eval_clustering(const char *csv, const char *gallery) |
| 82 | 82 | EvalClustering(csv, gallery); |
| 83 | 83 | } |
| 84 | 84 | |
| 85 | -void br_eval_detection(const char *predicted_gallery, const char *truth_gallery) | |
| 85 | +float br_eval_detection(const char *predicted_gallery, const char *truth_gallery, const char *csv) | |
| 86 | 86 | { |
| 87 | - EvalDetection(predicted_gallery, truth_gallery); | |
| 87 | + return EvalDetection(predicted_gallery, truth_gallery, csv); | |
| 88 | 88 | } |
| 89 | 89 | |
| 90 | 90 | void br_eval_regression(const char *predicted_gallery, const char *truth_gallery) | ... | ... |
openbr/openbr.h
| ... | ... | @@ -148,7 +148,6 @@ BR_EXPORT float br_eval(const char *simmat, const char *mask, const char *csv = |
| 148 | 148 | * \brief Evaluates and prints classification accuracy to terminal. |
| 149 | 149 | * \param predicted_input The predicted br::Gallery. |
| 150 | 150 | * \param truth_input The ground truth br::Gallery. |
| 151 | - * \see br_enroll | |
| 152 | 151 | */ |
| 153 | 152 | BR_EXPORT void br_eval_classification(const char *predicted_gallery, const char *truth_gallery); |
| 154 | 153 | |
| ... | ... | @@ -164,15 +163,15 @@ BR_EXPORT void br_eval_clustering(const char *csv, const char *gallery); |
| 164 | 163 | * \brief Evaluates and prints detection accuracy to terminal. |
| 165 | 164 | * \param predicted_gallery The predicted br::Gallery. |
| 166 | 165 | * \param truth_galery The ground truth br::Gallery. |
| 167 | - * \see br_enroll | |
| 166 | + * \param csv Optional \c .csv file to contain performance metrics. | |
| 167 | + * \return Average detection bounding box overlap. | |
| 168 | 168 | */ |
| 169 | -BR_EXPORT void br_eval_detection(const char *predicted_gallery, const char *truth_gallery); | |
| 169 | +BR_EXPORT float br_eval_detection(const char *predicted_gallery, const char *truth_gallery, const char *csv = ""); | |
| 170 | 170 | |
| 171 | 171 | /*! |
| 172 | 172 | * \brief Evaluates regression accuracy to disk. |
| 173 | 173 | * \param predicted_input The predicted br::Gallery. |
| 174 | 174 | * \param truth_input The ground truth br::Gallery. |
| 175 | - * \see br_enroll | |
| 176 | 175 | */ |
| 177 | 176 | BR_EXPORT void br_eval_regression(const char *predicted_gallery, const char *truth_gallery); |
| 178 | 177 | ... | ... |