diff --git a/.gitignore b/.gitignore index fc31dcc..6baacca 100644 --- a/.gitignore +++ b/.gitignore @@ -29,3 +29,6 @@ scripts/results ### OS X ### *.DS_Store + +### vim ### +*.swp diff --git a/app/br/br.cpp b/app/br/br.cpp index 133cd3c..7528393 100644 --- a/app/br/br.cpp +++ b/app/br/br.cpp @@ -140,13 +140,6 @@ public: } else if (!strcmp(fun, "evalRegression")) { check(parc == 2, "Incorrect parameter count for 'evalRegression'."); br_eval_regression(parv[0], parv[1]); - } else if (!strcmp(fun, "confusion")) { - check(parc == 2, "Incorrect parameter count for 'confusion'."); - int true_positives, false_positives, true_negatives, false_negatives; - br_confusion(parv[0], atof(parv[1]), - &true_positives, &false_positives, &true_negatives, &false_negatives); - printf("True Positives = %d\nFalse Positives = %d\nTrue Negatives = %d\nFalseNegatives = %d\n", - true_positives, false_positives, true_negatives, false_negatives); } else if (!strcmp(fun, "plotMetadata")) { check(parc >= 2, "Incorrect parameter count for 'plotMetadata'."); br_plot_metadata(parc-1, parv, parv[parc-1], true); @@ -223,7 +216,6 @@ private: "-evalClustering \n" "-evalDetection \n" "-evalRegression \n" - "-confusion \n" "-plotMetadata ... \n" "-getHeader \n" "-setHeader {} \n" diff --git a/data/KTH/README.md b/data/KTH/README.md new file mode 100644 index 0000000..a7493e7 --- /dev/null +++ b/data/KTH/README.md @@ -0,0 +1,3 @@ +## KTH Human Action Database +Grayscale human action videos. Six actions performed by 25 subjects in four scenarios, for a total of 600 160x120 videos. +* [Website](http://www.nada.kth.se/cvap/actions/) diff --git a/data/README.md b/data/README.md index 7f6bebc..6feb9f5 100644 --- a/data/README.md +++ b/data/README.md @@ -9,6 +9,7 @@ * [MEDS](MEDS/README.md) * [MNIST](MNIST/README.md) * [PCSO](PCSO/README.md) +* [KTH](KTH/README.md) For both practical and legal reasons we only include images for some of the datasets in this repository. Researchers should contact the respective owners of the other datasets in order to obtain a copy. diff --git a/openbr/core/bee.cpp b/openbr/core/bee.cpp index 18c498b..d9f320b 100644 --- a/openbr/core/bee.cpp +++ b/openbr/core/bee.cpp @@ -268,8 +268,8 @@ cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries, { // Would like to use indexProperty for this, but didn't make a version of that for Filelist yet // -cao - QList targetLabels = targets.get("Subject", "-1"); - QList queryLabels = queries.get("Subject", "-1"); + QList targetLabels = File::get(targets, "Subject", "-1"); + QList queryLabels = File::get(queries, "Subject", "-1"); QList targetPartitions = targets.crossValidationPartitions(); QList queryPartitions = queries.crossValidationPartitions(); diff --git a/openbr/core/classify.cpp b/openbr/core/classify.cpp deleted file mode 100644 index ea816be..0000000 --- a/openbr/core/classify.cpp +++ /dev/null @@ -1,134 +0,0 @@ -/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * - * Copyright 2012 The MITRE Corporation * - * * - * Licensed under the Apache License, Version 2.0 (the "License"); * - * you may not use this file except in compliance with the License. * - * You may obtain a copy of the License at * - * * - * http://www.apache.org/licenses/LICENSE-2.0 * - * * - * Unless required by applicable law or agreed to in writing, software * - * distributed under the License is distributed on an "AS IS" BASIS, * - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * - * See the License for the specific language governing permissions and * - * limitations under the License. * - * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ - -#include - -#include "classify.h" -#include "openbr/core/qtutils.h" - -// Helper struct for statistics accumulation -struct Counter -{ - float truePositive, falsePositive, falseNegative; - Counter() - { - truePositive = 0; - falsePositive = 0; - falseNegative = 0; - } -}; - -void br::EvalClassification(const QString &predictedInput, const QString &truthInput) -{ - qDebug("Evaluating classification of %s against %s", qPrintable(predictedInput), qPrintable(truthInput)); - - TemplateList predicted(TemplateList::fromGallery(predictedInput)); - TemplateList truth(TemplateList::fromGallery(truthInput)); - if (predicted.size() != truth.size()) qFatal("Input size mismatch."); - - QHash counters; - for (int i=0; i("Subject"); - QString trueSubject = truth[i].file.get("Subject"); - - QStringList predictedSubjects(predictedSubject); - QStringList trueSubjects(trueSubject); - - foreach (const QString &subject, trueSubjects.toVector() /* Hack to copy the list. */) { - if (predictedSubjects.contains(subject)) { - counters[subject].truePositive++; - trueSubjects.removeOne(subject); - predictedSubjects.removeOne(subject); - } else { - counters[subject].falseNegative++; - } - } - - for (int i=0; i output(Output::make("", FileList() << "Count" << "Precision" << "Recall" << "F-score", FileList(keys))); - - int tpc = 0; - int fnc = 0; - - for (int i=0; isetRelative(count, i, 0); - output->setRelative(precision, i, 1); - output->setRelative(recall, i, 2); - output->setRelative(fscore, i, 3); - } - - qDebug("Overall Accuracy = %f", (float)tpc / (float)(tpc + fnc)); -} - -void br::EvalDetection(const QString &predictedInput, const QString &truthInput) -{ - (void) predictedInput; - (void) truthInput; -} - -void br::EvalRegression(const QString &predictedInput, const QString &truthInput) -{ - qDebug("Evaluating regression of %s against %s", qPrintable(predictedInput), qPrintable(truthInput)); - - const TemplateList predicted(TemplateList::fromGallery(predictedInput)); - const TemplateList truth(TemplateList::fromGallery(truthInput)); - if (predicted.size() != truth.size()) qFatal("Input size mismatch."); - - float rmsError = 0; - QStringList truthValues, predictedValues; - for (int i=0; i("Subject")-truth[i].file.get("Subject"), 2.f); - truthValues.append(QString::number(truth[i].file.get("Subject"))); - predictedValues.append(QString::number(predicted[i].file.get("Subject"))); - } - - QStringList rSource; - rSource << "# Load libraries" << "library(ggplot2)" << "" << "# Set Data" - << "Actual <- c(" + truthValues.join(",") + ")" - << "Predicted <- c(" + predictedValues.join(",") + ")" - << "data <- data.frame(Actual, Predicted)" - << "" << "# Construct Plot" << "pdf(\"EvalRegression.pdf\")" - << "print(qplot(Actual, Predicted, data=data, geom=\"jitter\", alpha=I(2/3)) + geom_abline(intercept=0, slope=1, color=\"forestgreen\", size=I(1)) + geom_smooth(size=I(1), color=\"mediumblue\") + theme_bw())" - << "print(qplot(Actual, Predicted-Actual, data=data, geom=\"jitter\", alpha=I(2/3)) + geom_abline(intercept=0, slope=0, color=\"forestgreen\", size=I(1)) + geom_smooth(size=I(1), color=\"mediumblue\") + theme_bw())" - << "dev.off()"; - - - QString rFile = "EvalRegression.R"; - QtUtils::writeFile(rFile, rSource); - bool success = QtUtils::runRScript(rFile); - if (success) QtUtils::showFile("EvalRegression.pdf"); - - qDebug("RMS Error = %f", sqrt(rmsError/predicted.size())); -} diff --git a/openbr/core/cluster.cpp b/openbr/core/cluster.cpp index fdff3d8..6a6b954 100644 --- a/openbr/core/cluster.cpp +++ b/openbr/core/cluster.cpp @@ -280,7 +280,7 @@ void br::EvalClustering(const QString &csv, const QString &input) // We assume clustering algorithms store assigned cluster labels as integers (since the clusters are // not named). - QList labels = TemplateList::fromGallery(input).files().get("Subject"); + QList labels = File::get(TemplateList::fromGallery(input), "Subject"); QHash labelToIndex; int nClusters = 0; diff --git a/openbr/core/eval.cpp b/openbr/core/eval.cpp new file mode 100644 index 0000000..2f13e66 --- /dev/null +++ b/openbr/core/eval.cpp @@ -0,0 +1,437 @@ +/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * + * Copyright 2012 The MITRE Corporation * + * * + * Licensed under the Apache License, Version 2.0 (the "License"); * + * you may not use this file except in compliance with the License. * + * You may obtain a copy of the License at * + * * + * http://www.apache.org/licenses/LICENSE-2.0 * + * * + * Unless required by applicable law or agreed to in writing, software * + * distributed under the License is distributed on an "AS IS" BASIS, * + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * + * See the License for the specific language governing permissions and * + * limitations under the License. * + * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ + +#include "bee.h" +#include "eval.h" +#include "openbr/core/qtutils.h" + +using namespace cv; + +namespace br +{ + +struct Comparison +{ + float score; + int target, query; + bool genuine; + + Comparison() {} + Comparison(float _score, int _target, int _query, bool _genuine) + : score(_score), target(_target), query(_query), genuine(_genuine) {} + inline bool operator<(const Comparison &other) const { return score > other.score; } +}; + +#undef FAR // Windows preprecessor definition conflicts with variable name +struct OperatingPoint +{ + float score, FAR, TAR; + OperatingPoint() {} + OperatingPoint(float _score, float _FAR, float _TAR) + : score(_score), FAR(_FAR), TAR(_TAR) {} +}; + +static float getTAR(const QList &operatingPoints, float FAR) +{ + int index = 0; + while (operatingPoints[index].FAR < FAR) { + index++; + if (index == operatingPoints.size()) + return 1; + } + + const float x1 = (index == 0 ? 0 : operatingPoints[index-1].FAR); + const float y1 = (index == 0 ? 0 : operatingPoints[index-1].TAR); + const float x2 = operatingPoints[index].FAR; + const float y2 = operatingPoints[index].TAR; + const float m = (y2 - y1) / (x2 - x1); + const float b = y1 - m*x1; + return m * FAR + b; +} + +float Evaluate(const cv::Mat &scores, const FileList &target, const FileList &query, const QString &csv, int partition) +{ + return Evaluate(scores, BEE::makeMask(target, query, partition), csv); +} + +float Evaluate(const QString &simmat, const QString &mask, const QString &csv) +{ + qDebug("Evaluating %s%s%s", + qPrintable(simmat), + mask.isEmpty() ? "" : qPrintable(" with " + mask), + csv.isEmpty() ? "" : qPrintable(" to " + csv)); + + // Read similarity matrix + QString target, query; + const Mat scores = BEE::readSimmat(simmat, &target, &query); + + // Read mask matrix + Mat truth; + if (mask.isEmpty()) { + // Use the galleries specified in the similarity matrix + truth = BEE::makeMask(TemplateList::fromGallery(target).files(), + TemplateList::fromGallery(query).files()); + } else { + File maskFile(mask); + maskFile.set("rows", scores.rows); + maskFile.set("columns", scores.cols); + truth = BEE::readMask(maskFile); + } + + return Evaluate(scores, truth, csv); +} + +float Evaluate(const Mat &simmat, const Mat &mask, const QString &csv) +{ + if (simmat.size() != mask.size()) + qFatal("Similarity matrix (%ix%i) differs in size from mask matrix (%ix%i).", + simmat.rows, simmat.cols, mask.rows, mask.cols); + + const int Max_Points = 500; + float result = -1; + + // Make comparisons + QList comparisons; comparisons.reserve(simmat.rows*simmat.cols); + int genuineCount = 0, impostorCount = 0, numNaNs = 0; + for (int i=0; i(i,j); + const BEE::Simmat_t simmat_val = simmat.at(i,j); + if (mask_val == BEE::DontCare) continue; + if (simmat_val != simmat_val) { numNaNs++; continue; } + comparisons.append(Comparison(simmat_val, j, i, mask_val == BEE::Match)); + if (comparisons.last().genuine) genuineCount++; + else impostorCount++; + } + } + + if (numNaNs > 0) qWarning("Encountered %d NaN scores!", numNaNs); + if (genuineCount == 0) qFatal("No genuine scores!"); + if (impostorCount == 0) qFatal("No impostor scores!"); + + // Sort comparisons by simmat_val (score) + std::sort(comparisons.begin(), comparisons.end()); + + QList operatingPoints; + QList genuines; genuines.reserve(sqrt((float)comparisons.size())); + QList impostors; impostors.reserve(comparisons.size()); + QVector firstGenuineReturns(simmat.rows, 0); + + int falsePositives = 0, previousFalsePositives = 0; + int truePositives = 0, previousTruePositives = 0; + int index = 0; + float minGenuineScore = std::numeric_limits::max(); + float minImpostorScore = std::numeric_limits::max(); + + while (index < comparisons.size()) { + float thresh = comparisons[index].score; + // Compute genuine and imposter statistics at a threshold + while ((index < comparisons.size()) && + (comparisons[index].score == thresh)) { + const Comparison &comparison = comparisons[index]; + if (comparison.genuine) { + truePositives++; + genuines.append(comparison.score); + if (firstGenuineReturns[comparison.query] < 1) + firstGenuineReturns[comparison.query] = abs(firstGenuineReturns[comparison.query]) + 1; + if ((comparison.score != -std::numeric_limits::max()) && + (comparison.score < minGenuineScore)) + minGenuineScore = comparison.score; + } else { + falsePositives++; + impostors.append(comparison.score); + if (firstGenuineReturns[comparison.query] < 1) + firstGenuineReturns[comparison.query]--; + if ((comparison.score != -std::numeric_limits::max()) && + (comparison.score < minImpostorScore)) + minImpostorScore = comparison.score; + } + index++; + } + + if ((falsePositives > previousFalsePositives) && + (truePositives > previousTruePositives)) { + // Restrict the extreme ends of the curve + if ((truePositives >= 10) && (falsePositives < impostorCount/2)) + operatingPoints.append(OperatingPoint(thresh, float(falsePositives)/impostorCount, float(truePositives)/genuineCount)); + previousFalsePositives = falsePositives; + previousTruePositives = truePositives; + } + } + + if (operatingPoints.size() == 0) operatingPoints.append(OperatingPoint(1, 1, 1)); + if (operatingPoints.size() == 1) operatingPoints.prepend(OperatingPoint(0, 0, 0)); + if (operatingPoints.size() > 2) operatingPoints.takeLast(); // Remove point (1,1) + + // Write Metadata table + QStringList lines; + lines.append("Plot,X,Y"); + lines.append("Metadata,"+QString::number(simmat.cols)+",Gallery"); + lines.append("Metadata,"+QString::number(simmat.rows)+",Probe"); + lines.append("Metadata,"+QString::number(genuineCount)+",Genuine"); + lines.append("Metadata,"+QString::number(impostorCount)+",Impostor"); + lines.append("Metadata,"+QString::number(simmat.cols*simmat.rows-(genuineCount+impostorCount))+",Ignored"); + + // Write Detection Error Tradeoff (DET), PRE, REC + int points = qMin(operatingPoints.size(), Max_Points); + for (int i=0; i sampledGenuineScores; sampledGenuineScores.reserve(points); + QList sampledImpostorScores; sampledImpostorScores.reserve(points); + + if (points > 1) { + for (int i=0; i::max()) genuineScore = minGenuineScore; + if (impostorScore == -std::numeric_limits::max()) impostorScore = minImpostorScore; + lines.append(QString("SD,%1,Genuine").arg(QString::number(genuineScore))); + lines.append(QString("SD,%1,Impostor").arg(QString::number(impostorScore))); + sampledGenuineScores.append(genuineScore); + sampledImpostorScores.append(impostorScore); + } + } + + // Write Cumulative Match Characteristic (CMC) curve + const int Max_Retrieval = 200; + const int Report_Retrieval = 5; + + float reportRetrievalRate = -1; + for (int i=1; i<=Max_Retrieval; i++) { + int realizedReturns = 0, possibleReturns = 0; + foreach (int firstGenuineReturn, firstGenuineReturns) { + if (firstGenuineReturn > 0) { + possibleReturns++; + if (firstGenuineReturn <= i) realizedReturns++; + } + } + const float retrievalRate = float(realizedReturns)/possibleReturns; + lines.append(qPrintable(QString("CMC,%1,%2").arg(QString::number(i), QString::number(retrievalRate)))); + if (i == Report_Retrieval) reportRetrievalRate = retrievalRate; + } + + if (!csv.isEmpty()) QtUtils::writeFile(csv, lines); + qDebug("TAR @ FAR = 0.01: %.3f\nRetrieval Rate @ Rank = %d: %.3f", result, Report_Retrieval, reportRetrievalRate); + return result; +} + +// Helper struct for statistics accumulation +struct Counter +{ + float truePositive, falsePositive, falseNegative; + Counter() + { + truePositive = 0; + falsePositive = 0; + falseNegative = 0; + } +}; + +void EvalClassification(const QString &predictedInput, const QString &truthInput) +{ + qDebug("Evaluating classification of %s against %s", qPrintable(predictedInput), qPrintable(truthInput)); + TemplateList predicted(TemplateList::fromGallery(predictedInput)); + TemplateList truth(TemplateList::fromGallery(truthInput)); + if (predicted.size() != truth.size()) qFatal("Input size mismatch."); + + QHash counters; + for (int i=0; i("Subject"); + QString trueSubject = truth[i].file.get("Subject"); + + QStringList predictedSubjects(predictedSubject); + QStringList trueSubjects(trueSubject); + + foreach (const QString &subject, trueSubjects.toVector() /* Hack to copy the list. */) { + if (predictedSubjects.contains(subject)) { + counters[subject].truePositive++; + trueSubjects.removeOne(subject); + predictedSubjects.removeOne(subject); + } else { + counters[subject].falseNegative++; + } + } + + for (int i=0; i output(Output::make("", FileList() << "Count" << "Precision" << "Recall" << "F-score", FileList(keys))); + + int tpc = 0; + int fnc = 0; + + for (int i=0; isetRelative(count, i, 0); + output->setRelative(precision, i, 1); + output->setRelative(recall, i, 2); + output->setRelative(fscore, i, 3); + } + + qDebug("Overall Accuracy = %f", (float)tpc / (float)(tpc + fnc)); +} + +struct Detection +{ + QRectF boundingBox; + float confidence; + + Detection() {} + Detection(const QRectF &boundingBox_, float confidence_ = -1) + : boundingBox(boundingBox_), confidence(confidence_) {} + + float area() const { return boundingBox.width() * boundingBox.height(); } + float overlap(const Detection &other) const + { + const Detection intersection(boundingBox.intersected(other.boundingBox)); + return intersection.area() / (area() + other.area() - 2*intersection.area()); + } +}; + +struct Detections +{ + QList predicted, truth; +}; + +struct DetectionOperatingPoint +{ + float confidence, overlap; + DetectionOperatingPoint() : confidence(-1), overlap(-1) {} + DetectionOperatingPoint(float confidence_, float overlap_) : confidence(confidence_), overlap(overlap_) {} + inline bool operator<(const DetectionOperatingPoint &other) const { return confidence > other.confidence; } +}; + +float EvalDetection(const QString &predictedInput, const QString &truthInput, const QString &csv) +{ + qDebug("Evaluating detection of %s against %s", qPrintable(predictedInput), qPrintable(truthInput)); + const TemplateList predicted(TemplateList::fromGallery(predictedInput)); + const TemplateList truth(TemplateList::fromGallery(truthInput)); + + // Figure out which metadata field contains a bounding box + QString detectKey; + foreach (const QString &key, truth.first().file.localKeys()) + if (!truth.first().file.get(key, QRectF()).isNull()) { + detectKey = key; + break; + } + if (detectKey.isNull()) qFatal("No suitable metadata key found."); + else qDebug("Using metadata key: %s", qPrintable(detectKey)); + + QHash allDetections; // Organized by file + foreach (const Template &t, predicted) + allDetections[t.file.baseName()].predicted.append(Detection(t.file.get(detectKey), t.file.get("Confidence", -1))); + foreach (const Template &t, truth) + allDetections[t.file.baseName()].truth.append(Detection(t.file.get(detectKey))); + + QList points; + foreach (Detections detections, allDetections.values()) { + while (!detections.truth.isEmpty() && !detections.predicted.isEmpty()) { + Detection truth = detections.truth.takeFirst(); + int bestIndex = -1; + float bestOverlap = -1; + for (int i=0; i bestOverlap) { + bestOverlap = overlap; + bestIndex = i; + } + } + Detection predicted = detections.predicted.takeAt(bestIndex); + points.append(DetectionOperatingPoint(predicted.confidence, bestOverlap)); + } + + foreach (const Detection &detection, detections.predicted) + points.append(DetectionOperatingPoint(detection.confidence, 0)); + for (int i=0; i::max(), 0)); + } + + std::sort(points.begin(), points.end()); + + QStringList lines; + lines.append("Plot, X, Y"); + + // TODO: finish implementing + + (void) csv; + return 0; +} + +void EvalRegression(const QString &predictedInput, const QString &truthInput) +{ + qDebug("Evaluating regression of %s against %s", qPrintable(predictedInput), qPrintable(truthInput)); + const TemplateList predicted(TemplateList::fromGallery(predictedInput)); + const TemplateList truth(TemplateList::fromGallery(truthInput)); + if (predicted.size() != truth.size()) qFatal("Input size mismatch."); + + float rmsError = 0; + QStringList truthValues, predictedValues; + for (int i=0; i("Subject")-truth[i].file.get("Subject"), 2.f); + truthValues.append(QString::number(truth[i].file.get("Subject"))); + predictedValues.append(QString::number(predicted[i].file.get("Subject"))); + } + + QStringList rSource; + rSource << "# Load libraries" << "library(ggplot2)" << "" << "# Set Data" + << "Actual <- c(" + truthValues.join(",") + ")" + << "Predicted <- c(" + predictedValues.join(",") + ")" + << "data <- data.frame(Actual, Predicted)" + << "" << "# Construct Plot" << "pdf(\"EvalRegression.pdf\")" + << "print(qplot(Actual, Predicted, data=data, geom=\"jitter\", alpha=I(2/3)) + geom_abline(intercept=0, slope=1, color=\"forestgreen\", size=I(1)) + geom_smooth(size=I(1), color=\"mediumblue\") + theme_bw())" + << "print(qplot(Actual, Predicted-Actual, data=data, geom=\"jitter\", alpha=I(2/3)) + geom_abline(intercept=0, slope=0, color=\"forestgreen\", size=I(1)) + geom_smooth(size=I(1), color=\"mediumblue\") + theme_bw())" + << "dev.off()"; + + + QString rFile = "EvalRegression.R"; + QtUtils::writeFile(rFile, rSource); + bool success = QtUtils::runRScript(rFile); + if (success) QtUtils::showFile("EvalRegression.pdf"); + + qDebug("RMS Error = %f", sqrt(rmsError/predicted.size())); +} + +} // namespace br diff --git a/openbr/core/classify.h b/openbr/core/eval.h index 67f5340..d90d1a5 100644 --- a/openbr/core/classify.h +++ b/openbr/core/eval.h @@ -14,18 +14,22 @@ * limitations under the License. * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ -#ifndef __CLASSIFY_H -#define __CLASSIFY_H +#ifndef __EVAL_H +#define __EVAL_H #include #include +#include "openbr/openbr_plugin.h" namespace br { + float Evaluate(const QString &simmat, const QString &mask = "", const QString &csv = ""); // Returns TAR @ FAR = 0.001 + float Evaluate(const cv::Mat &scores, const FileList &target, const FileList &query, const QString &csv = "", int parition = 0); + float Evaluate(const cv::Mat &scores, const cv::Mat &masks, const QString &csv = ""); void EvalClassification(const QString &predictedInput, const QString &truthInput); - void EvalDetection(const QString &predictedInput, const QString &truthInput); + float EvalDetection(const QString &predictedInput, const QString &truthInput, const QString &csv = ""); // Return average overlap void EvalRegression(const QString &predictedInput, const QString &truthInput); } -#endif // __CLASSIFY_H +#endif // __EVAL_H diff --git a/openbr/core/plot.cpp b/openbr/core/plot.cpp index 0ef731c..65e8515 100644 --- a/openbr/core/plot.cpp +++ b/openbr/core/plot.cpp @@ -14,57 +14,15 @@ * limitations under the License. * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - #include "plot.h" #include "version.h" -#include "openbr/core/bee.h" -#include "openbr/core/common.h" -#include "openbr/core/opencvutils.h" #include "openbr/core/qtutils.h" -#undef FAR // Windows preprecessor definition - using namespace cv; namespace br { -void Confusion(const QString &file, float score, int &true_positives, int &false_positives, int &true_negatives, int &false_negatives) -{ - qDebug("Computing confusion matrix of %s at %f", qPrintable(file), score); - - QStringList lines = QtUtils::readLines(file); - true_positives = false_positives = true_negatives = false_negatives = 0; - foreach (const QString &line, lines) { - if (!line.startsWith("SD")) continue; - QStringList words = line.split(","); - bool ok; - float similarity = words[1].toFloat(&ok); assert(ok); - if (words[2] == "Genuine") { - if (similarity >= score) true_positives++; - else false_negatives++; - } else { - if (similarity >= score) false_positives++; - else true_negatives++; - } - } -} - static QStringList getPivots(const QString &file, bool headers) { QString str; @@ -73,224 +31,6 @@ static QStringList getPivots(const QString &file, bool headers) return str.split("_"); } -struct Comparison -{ - float score; - int target, query; - bool genuine; - - Comparison() {} - Comparison(float _score, int _target, int _query, bool _genuine) - : score(_score), target(_target), query(_query), genuine(_genuine) {} - inline bool operator<(const Comparison &other) const { return score > other.score; } -}; - -struct OperatingPoint -{ - float score, FAR, TAR; - OperatingPoint() {} - OperatingPoint(float _score, float _FAR, float _TAR) - : score(_score), FAR(_FAR), TAR(_TAR) {} -}; - -static float getTAR(const QList &operatingPoints, float FAR) -{ - int index = 0; - while (operatingPoints[index].FAR < FAR) { - index++; - if (index == operatingPoints.size()) - return 1; - } - - const float x1 = (index == 0 ? 0 : operatingPoints[index-1].FAR); - const float y1 = (index == 0 ? 0 : operatingPoints[index-1].TAR); - const float x2 = operatingPoints[index].FAR; - const float y2 = operatingPoints[index].TAR; - const float m = (y2 - y1) / (x2 - x1); - const float b = y1 - m*x1; - return m * FAR + b; -} - -float Evaluate(const cv::Mat &scores, const FileList &target, const FileList &query, const QString &csv, int partition) -{ - return Evaluate(scores, BEE::makeMask(target, query, partition), csv); -} - -float Evaluate(const QString &simmat, const QString &mask, const QString &csv) -{ - qDebug("Evaluating %s%s%s", - qPrintable(simmat), - mask.isEmpty() ? "" : qPrintable(" with " + mask), - csv.isEmpty() ? "" : qPrintable(" to " + csv)); - - // Read similarity matrix - QString target, query; - const Mat scores = BEE::readSimmat(simmat, &target, &query); - - // Read mask matrix - Mat truth; - if (mask.isEmpty()) { - // Use the galleries specified in the similarity matrix - truth = BEE::makeMask(TemplateList::fromGallery(target).files(), - TemplateList::fromGallery(query).files()); - } else { - File maskFile(mask); - maskFile.set("rows", scores.rows); - maskFile.set("columns", scores.cols); - truth = BEE::readMask(maskFile); - } - - return Evaluate(scores, truth, csv); -} - -float Evaluate(const Mat &simmat, const Mat &mask, const QString &csv) -{ - if (simmat.size() != mask.size()) - qFatal("Similarity matrix (%ix%i) differs in size from mask matrix (%ix%i).", - simmat.rows, simmat.cols, mask.rows, mask.cols); - - const int Max_Points = 500; - float result = -1; - - // Make comparisons - QList comparisons; comparisons.reserve(simmat.rows*simmat.cols); - int genuineCount = 0, impostorCount = 0, numNaNs = 0; - for (int i=0; i(i,j); - const BEE::Simmat_t simmat_val = simmat.at(i,j); - if (mask_val == BEE::DontCare) continue; - if (simmat_val != simmat_val) { numNaNs++; continue; } - comparisons.append(Comparison(simmat_val, j, i, mask_val == BEE::Match)); - if (comparisons.last().genuine) genuineCount++; - else impostorCount++; - } - } - - if (numNaNs > 0) qWarning("Encountered %d NaN scores!", numNaNs); - if (genuineCount == 0) qFatal("No genuine scores!"); - if (impostorCount == 0) qFatal("No impostor scores!"); - - // Sort comparisons by simmat_val (score) - std::sort(comparisons.begin(), comparisons.end()); - - QList operatingPoints; - QList genuines; genuines.reserve(sqrt((float)comparisons.size())); - QList impostors; impostors.reserve(comparisons.size()); - QVector firstGenuineReturns(simmat.rows, 0); - - int falsePositives = 0, previousFalsePositives = 0; - int truePositives = 0, previousTruePositives = 0; - int index = 0; - float minGenuineScore = std::numeric_limits::max(); - float minImpostorScore = std::numeric_limits::max(); - - while (index < comparisons.size()) { - float thresh = comparisons[index].score; - // Compute genuine and imposter statistics at a threshold - while ((index < comparisons.size()) && - (comparisons[index].score == thresh)) { - const Comparison &comparison = comparisons[index]; - if (comparison.genuine) { - truePositives++; - genuines.append(comparison.score); - if (firstGenuineReturns[comparison.query] < 1) - firstGenuineReturns[comparison.query] = abs(firstGenuineReturns[comparison.query]) + 1; - if ((comparison.score != -std::numeric_limits::max()) && - (comparison.score < minGenuineScore)) - minGenuineScore = comparison.score; - } else { - falsePositives++; - impostors.append(comparison.score); - if (firstGenuineReturns[comparison.query] < 1) - firstGenuineReturns[comparison.query]--; - if ((comparison.score != -std::numeric_limits::max()) && - (comparison.score < minImpostorScore)) - minImpostorScore = comparison.score; - } - index++; - } - - if ((falsePositives > previousFalsePositives) && - (truePositives > previousTruePositives)) { - // Restrict the extreme ends of the curve - if ((truePositives >= 10) && (falsePositives < impostorCount/2)) - operatingPoints.append(OperatingPoint(thresh, float(falsePositives)/impostorCount, float(truePositives)/genuineCount)); - previousFalsePositives = falsePositives; - previousTruePositives = truePositives; - } - } - - if (operatingPoints.size() == 0) operatingPoints.append(OperatingPoint(1, 1, 1)); - if (operatingPoints.size() == 1) operatingPoints.prepend(OperatingPoint(0, 0, 0)); - if (operatingPoints.size() > 2) operatingPoints.takeLast(); // Remove point (1,1) - - // Write Metadata table - QStringList lines; - lines.append("Plot,X,Y"); - lines.append("Metadata,"+QString::number(simmat.cols)+",Gallery"); - lines.append("Metadata,"+QString::number(simmat.rows)+",Probe"); - lines.append("Metadata,"+QString::number(genuineCount)+",Genuine"); - lines.append("Metadata,"+QString::number(impostorCount)+",Impostor"); - lines.append("Metadata,"+QString::number(simmat.cols*simmat.rows-(genuineCount+impostorCount))+",Ignored"); - - // Write Detection Error Tradeoff (DET), PRE, REC - int points = qMin(operatingPoints.size(), Max_Points); - for (int i=0; i sampledGenuineScores; sampledGenuineScores.reserve(points); - QList sampledImpostorScores; sampledImpostorScores.reserve(points); - - if (points > 1) { - for (int i=0; i::max()) genuineScore = minGenuineScore; - if (impostorScore == -std::numeric_limits::max()) impostorScore = minImpostorScore; - lines.append(QString("SD,%1,Genuine").arg(QString::number(genuineScore))); - lines.append(QString("SD,%1,Impostor").arg(QString::number(impostorScore))); - sampledGenuineScores.append(genuineScore); - sampledImpostorScores.append(impostorScore); - } - } - - // Write Cumulative Match Characteristic (CMC) curve - const int Max_Retrieval = 200; - const int Report_Retrieval = 5; - - float reportRetrievalRate = -1; - for (int i=1; i<=Max_Retrieval; i++) { - int realizedReturns = 0, possibleReturns = 0; - foreach (int firstGenuineReturn, firstGenuineReturns) { - if (firstGenuineReturn > 0) { - possibleReturns++; - if (firstGenuineReturn <= i) realizedReturns++; - } - } - const float retrievalRate = float(realizedReturns)/possibleReturns; - lines.append(qPrintable(QString("CMC,%1,%2").arg(QString::number(i), QString::number(retrievalRate)))); - if (i == Report_Retrieval) reportRetrievalRate = retrievalRate; - } - - if (!csv.isEmpty()) QtUtils::writeFile(csv, lines); - qDebug("TAR @ FAR = 0.01: %.3f\nRetrieval Rate @ Rank = %d: %.3f", result, Report_Retrieval, reportRetrievalRate); - return result; -} - static QString getScale(const QString &mode, const QString &title, int vals) { if (vals > 12) return " + scale_"+mode+"_discrete(\""+title+"\")"; @@ -474,7 +214,6 @@ struct RPlot }; // Does not work if dataset folder starts with a number - bool Plot(const QStringList &files, const br::File &destination, bool show) { qDebug("Plotting %d file(s) to %s", files.size(), qPrintable(destination)); diff --git a/openbr/core/plot.h b/openbr/core/plot.h index d672706..b015b24 100644 --- a/openbr/core/plot.h +++ b/openbr/core/plot.h @@ -24,14 +24,8 @@ namespace br { - -void Confusion(const QString &file, float score, int &true_positives, int &false_positives, int &true_negatives, int &false_negatives); -float Evaluate(const QString &simmat, const QString &mask = "", const QString &csv = ""); // Returns TAR @ FAR = 0.001 -float Evaluate(const cv::Mat &scores, const FileList &target, const FileList &query, const QString &csv = "", int parition = 0); -float Evaluate(const cv::Mat &scores, const cv::Mat &masks, const QString &csv = ""); -bool Plot(const QStringList &files, const br::File &destination, bool show = false); -bool PlotMetadata(const QStringList &files, const QString &destination, bool show = false); - + bool Plot(const QStringList &files, const br::File &destination, bool show = false); + bool PlotMetadata(const QStringList &files, const QString &destination, bool show = false); } #endif // __PLOT_H diff --git a/openbr/openbr.cpp b/openbr/openbr.cpp index 8f909ca..7a0a759 100644 --- a/openbr/openbr.cpp +++ b/openbr/openbr.cpp @@ -17,8 +17,8 @@ #include #include "core/bee.h" -#include "core/classify.h" #include "core/cluster.h" +#include "core/eval.h" #include "core/fuse.h" #include "core/plot.h" #include "core/qtutils.h" @@ -51,11 +51,6 @@ void br_compare(const char *target_gallery, const char *query_gallery, const cha Compare(File(target_gallery), File(query_gallery), File(output)); } -void br_confusion(const char *file, float score, int *true_positives, int *false_positives, int *true_negatives, int *false_negatives) -{ - return Confusion(file, score, *true_positives, *false_positives, *true_negatives, *false_negatives); -} - void br_convert(const char *file_type, const char *input_file, const char *output_file) { Convert(File(file_type), File(input_file), File(output_file)); diff --git a/openbr/openbr.h b/openbr/openbr.h index eca10dd..00e1a01 100644 --- a/openbr/openbr.h +++ b/openbr/openbr.h @@ -115,20 +115,6 @@ BR_EXPORT void br_combine_masks(int num_input_masks, const char *input_masks[], BR_EXPORT void br_compare(const char *target_gallery, const char *query_gallery, const char *output = ""); /*! - * \brief Computes the confusion matrix for a dataset at a particular threshold. - * - * Wikipedia Explanation - * \param file .csv file created using \ref br_eval. - * \param score The similarity score to threshold at. - * \param[out] true_positives The true positive count. - * \param[out] false_positives The false positive count. - * \param[out] true_negatives The true negative count. - * \param[out] false_negatives The false negative count. - */ -BR_EXPORT void br_confusion(const char *file, float score, - int *true_positives, int *false_positives, int *true_negatives, int *false_negatives); - -/*! * \brief Wraps br::Convert() */ BR_EXPORT void br_convert(const char *file_type, const char *input_file, const char *output_file); diff --git a/openbr/openbr_plugin.cpp b/openbr/openbr_plugin.cpp index 5a4e89a..5f5fbb3 100644 --- a/openbr/openbr_plugin.cpp +++ b/openbr/openbr_plugin.cpp @@ -436,7 +436,7 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) // stores the index values in "Label" of the output template list TemplateList TemplateList::relabel(const TemplateList &tl, const QString & propName) { - const QList originalLabels = tl.get(propName); + const QList originalLabels = File::get(tl, propName); QHash labelTable; foreach (const QString & label, originalLabels) if (!labelTable.contains(label)) @@ -464,7 +464,7 @@ QList TemplateList::indexProperty(const QString & propName, QHash originalLabels = values(propName); + const QList originalLabels = File::values(*this, propName); foreach (const QVariant & label, originalLabels) { QString labelString = label.toString(); if (!valueMap.contains(labelString)) { @@ -481,9 +481,9 @@ QList TemplateList::indexProperty(const QString & propName, QHash TemplateList::applyIndex(const QString & propName, const QHash & valueMap) const +QList TemplateList::applyIndex(const QString &propName, const QHash &valueMap) const { - const QList originalLabels = get(propName); + const QList originalLabels = File::get(*this, propName); QList result; for (int i=0; i(); } - /*!< \brief Returns a list of type T for the key, throwing an error if the key does not exist or if the value cannot be converted to the specified type. */ + /*!< \brief Returns a value for the key, returning \em defaultValue if the key does not exist or can't be converted. */ + template + T get(const QString &key, const T &defaultValue) const + { + if (!contains(key)) return defaultValue; + QVariant variant = value(key); + if (!variant.canConvert()) return defaultValue; + return variant.value(); + } + + /*!< \brief Specialization for boolean type. */ + bool getBool(const QString &key, bool defaultValue = false) const; + + /*!< \brief Specialization for list type. Returns a list of type T for the key, throwing an error if the key does not exist or if the value cannot be converted to the specified type. */ template QList getList(const QString &key) const { @@ -241,17 +254,31 @@ struct BR_EXPORT File return list; } - /*!< \brief Specialization for boolean type. */ - bool getBool(const QString &key, bool defaultValue = false) const; + /*!< \brief Returns the value for the specified key for every file in the list. */ + template + static QList values(const QList &fileList, const QString &key) + { + QList values; values.reserve(fileList.size()); + foreach (const U &f, fileList) values.append(((const File&)f).value(key)); + return values; + } - /*!< \brief Returns a value for the key, returning \em defaultValue if the key does not exist or can't be converted. */ - template - T get(const QString &key, const T &defaultValue) const + /*!< \brief Returns a value for the key for every file in the list, throwing an error if the key does not exist. */ + template + static QList get(const QList &fileList, const QString &key) { - if (!contains(key)) return defaultValue; - QVariant variant = value(key); - if (!variant.canConvert()) return defaultValue; - return variant.value(); + QList result; result.reserve(fileList.size()); + foreach (const U &f, fileList) result.append(((const File&)f).get(key)); + return result; + } + + /*!< \brief Returns a value for the key for every file in the list, returning \em defaultValue if the key does not exist or can't be converted. */ + template + static QList get(const QList &fileList, const QString &key, const T &defaultValue) + { + QList result; result.reserve(fileList.size()); + foreach (const U &f, fileList) result.append(static_cast(f).get(key, defaultValue)); + return result; } inline bool failed() const { return getBool("FTE") || getBool("FTO"); } /*!< \brief Returns \c true if the file failed to open or enroll, \c false otherwise. */ @@ -297,23 +324,6 @@ struct BR_EXPORT FileList : public QList QStringList flat() const; /*!< \brief Returns br::File::flat() for each file in the list. */ QStringList names() const; /*!< \brief Returns #br::File::name for each file in the list. */ void sort(const QString& key); /*!< \brief Sort the list based on metadata. */ - /*!< \brief Returns values associated with the input propName for each file in the list. */ - template - QList get(const QString & propName) const - { - QList values; values.reserve(size()); - foreach (const File &f, *this) - values.append(f.get(propName)); - return values; - } - template - QList get(const QString & propName, T defaultValue) const - { - QList values; values.reserve(size()); - foreach (const File &f, *this) - values.append(f.contains(propName) ? f.get(propName) : defaultValue); - return values; - } QList crossValidationPartitions() const; /*!< \brief Returns the cross-validation partition (default=0) for each file in the list. */ int failures() const; /*!< \brief Returns the number of files with br::File::failed(). */ @@ -344,6 +354,7 @@ struct Template : public QList inline const cv::Mat &m() const { static const cv::Mat NullMatrix; return isEmpty() ? qFatal("Empty template."), NullMatrix : last(); } /*!< \brief Idiom to treat the template as a matrix. */ inline cv::Mat &m() { return isEmpty() ? append(cv::Mat()), last() : last(); } /*!< \brief Idiom to treat the template as a matrix. */ + inline const File &operator()() const { return file; } inline cv::Mat &operator=(const cv::Mat &other) { return m() = other; } /*!< \brief Idiom to treat the template as a matrix. */ inline operator const cv::Mat&() const { return m(); } /*!< \brief Idiom to treat the template as a matrix. */ inline operator cv::Mat&() { return m(); } /*!< \brief Idiom to treat the template as a matrix. */ @@ -406,7 +417,6 @@ struct TemplateList : public QList