Commit ec4af7799c2f589d03adea56955c8a2faac9f6a4
Merge branch 'master' of https://github.com/biometrics/openbr
Showing
21 changed files
with
564 additions
and
495 deletions
.gitignore
app/br/br.cpp
| @@ -140,13 +140,6 @@ public: | @@ -140,13 +140,6 @@ public: | ||
| 140 | } else if (!strcmp(fun, "evalRegression")) { | 140 | } else if (!strcmp(fun, "evalRegression")) { |
| 141 | check(parc == 2, "Incorrect parameter count for 'evalRegression'."); | 141 | check(parc == 2, "Incorrect parameter count for 'evalRegression'."); |
| 142 | br_eval_regression(parv[0], parv[1]); | 142 | br_eval_regression(parv[0], parv[1]); |
| 143 | - } else if (!strcmp(fun, "confusion")) { | ||
| 144 | - check(parc == 2, "Incorrect parameter count for 'confusion'."); | ||
| 145 | - int true_positives, false_positives, true_negatives, false_negatives; | ||
| 146 | - br_confusion(parv[0], atof(parv[1]), | ||
| 147 | - &true_positives, &false_positives, &true_negatives, &false_negatives); | ||
| 148 | - printf("True Positives = %d\nFalse Positives = %d\nTrue Negatives = %d\nFalseNegatives = %d\n", | ||
| 149 | - true_positives, false_positives, true_negatives, false_negatives); | ||
| 150 | } else if (!strcmp(fun, "plotMetadata")) { | 143 | } else if (!strcmp(fun, "plotMetadata")) { |
| 151 | check(parc >= 2, "Incorrect parameter count for 'plotMetadata'."); | 144 | check(parc >= 2, "Incorrect parameter count for 'plotMetadata'."); |
| 152 | br_plot_metadata(parc-1, parv, parv[parc-1], true); | 145 | br_plot_metadata(parc-1, parv, parv[parc-1], true); |
| @@ -223,7 +216,6 @@ private: | @@ -223,7 +216,6 @@ private: | ||
| 223 | "-evalClustering <clusters> <gallery>\n" | 216 | "-evalClustering <clusters> <gallery>\n" |
| 224 | "-evalDetection <predicted_gallery> <truth_gallery>\n" | 217 | "-evalDetection <predicted_gallery> <truth_gallery>\n" |
| 225 | "-evalRegression <predicted_gallery> <truth_gallery>\n" | 218 | "-evalRegression <predicted_gallery> <truth_gallery>\n" |
| 226 | - "-confusion <file> <score>\n" | ||
| 227 | "-plotMetadata <file> ... <file> <columns>\n" | 219 | "-plotMetadata <file> ... <file> <columns>\n" |
| 228 | "-getHeader <matrix>\n" | 220 | "-getHeader <matrix>\n" |
| 229 | "-setHeader {<matrix>} <target_gallery> <query_gallery>\n" | 221 | "-setHeader {<matrix>} <target_gallery> <query_gallery>\n" |
data/KTH/README.md
0 → 100644
data/README.md
| @@ -9,6 +9,7 @@ | @@ -9,6 +9,7 @@ | ||
| 9 | * [MEDS](MEDS/README.md) | 9 | * [MEDS](MEDS/README.md) |
| 10 | * [MNIST](MNIST/README.md) | 10 | * [MNIST](MNIST/README.md) |
| 11 | * [PCSO](PCSO/README.md) | 11 | * [PCSO](PCSO/README.md) |
| 12 | +* [KTH](KTH/README.md) | ||
| 12 | 13 | ||
| 13 | For both practical and legal reasons we only include images for some of the datasets in this repository. | 14 | For both practical and legal reasons we only include images for some of the datasets in this repository. |
| 14 | Researchers should contact the respective owners of the other datasets in order to obtain a copy. | 15 | Researchers should contact the respective owners of the other datasets in order to obtain a copy. |
openbr/core/bee.cpp
| @@ -268,8 +268,8 @@ cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries, | @@ -268,8 +268,8 @@ cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries, | ||
| 268 | { | 268 | { |
| 269 | // Would like to use indexProperty for this, but didn't make a version of that for Filelist yet | 269 | // Would like to use indexProperty for this, but didn't make a version of that for Filelist yet |
| 270 | // -cao | 270 | // -cao |
| 271 | - QList<QString> targetLabels = targets.get<QString>("Subject", "-1"); | ||
| 272 | - QList<QString> queryLabels = queries.get<QString>("Subject", "-1"); | 271 | + QList<QString> targetLabels = File::get<QString>(targets, "Subject", "-1"); |
| 272 | + QList<QString> queryLabels = File::get<QString>(queries, "Subject", "-1"); | ||
| 273 | QList<int> targetPartitions = targets.crossValidationPartitions(); | 273 | QList<int> targetPartitions = targets.crossValidationPartitions(); |
| 274 | QList<int> queryPartitions = queries.crossValidationPartitions(); | 274 | QList<int> queryPartitions = queries.crossValidationPartitions(); |
| 275 | 275 |
openbr/core/classify.cpp deleted
| 1 | -/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * | ||
| 2 | - * Copyright 2012 The MITRE Corporation * | ||
| 3 | - * * | ||
| 4 | - * Licensed under the Apache License, Version 2.0 (the "License"); * | ||
| 5 | - * you may not use this file except in compliance with the License. * | ||
| 6 | - * You may obtain a copy of the License at * | ||
| 7 | - * * | ||
| 8 | - * http://www.apache.org/licenses/LICENSE-2.0 * | ||
| 9 | - * * | ||
| 10 | - * Unless required by applicable law or agreed to in writing, software * | ||
| 11 | - * distributed under the License is distributed on an "AS IS" BASIS, * | ||
| 12 | - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * | ||
| 13 | - * See the License for the specific language governing permissions and * | ||
| 14 | - * limitations under the License. * | ||
| 15 | - * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ | ||
| 16 | - | ||
| 17 | -#include <openbr/openbr_plugin.h> | ||
| 18 | - | ||
| 19 | -#include "classify.h" | ||
| 20 | -#include "openbr/core/qtutils.h" | ||
| 21 | - | ||
| 22 | -// Helper struct for statistics accumulation | ||
| 23 | -struct Counter | ||
| 24 | -{ | ||
| 25 | - float truePositive, falsePositive, falseNegative; | ||
| 26 | - Counter() | ||
| 27 | - { | ||
| 28 | - truePositive = 0; | ||
| 29 | - falsePositive = 0; | ||
| 30 | - falseNegative = 0; | ||
| 31 | - } | ||
| 32 | -}; | ||
| 33 | - | ||
| 34 | -void br::EvalClassification(const QString &predictedInput, const QString &truthInput) | ||
| 35 | -{ | ||
| 36 | - qDebug("Evaluating classification of %s against %s", qPrintable(predictedInput), qPrintable(truthInput)); | ||
| 37 | - | ||
| 38 | - TemplateList predicted(TemplateList::fromGallery(predictedInput)); | ||
| 39 | - TemplateList truth(TemplateList::fromGallery(truthInput)); | ||
| 40 | - if (predicted.size() != truth.size()) qFatal("Input size mismatch."); | ||
| 41 | - | ||
| 42 | - QHash<QString, Counter> counters; | ||
| 43 | - for (int i=0; i<predicted.size(); i++) { | ||
| 44 | - if (predicted[i].file.name != truth[i].file.name) | ||
| 45 | - qFatal("Input order mismatch."); | ||
| 46 | - | ||
| 47 | - // Typically these lists will be of length one, but this generalization allows measuring multi-class labeling accuracy. | ||
| 48 | - QString predictedSubject = predicted[i].file.get<QString>("Subject"); | ||
| 49 | - QString trueSubject = truth[i].file.get<QString>("Subject"); | ||
| 50 | - | ||
| 51 | - QStringList predictedSubjects(predictedSubject); | ||
| 52 | - QStringList trueSubjects(trueSubject); | ||
| 53 | - | ||
| 54 | - foreach (const QString &subject, trueSubjects.toVector() /* Hack to copy the list. */) { | ||
| 55 | - if (predictedSubjects.contains(subject)) { | ||
| 56 | - counters[subject].truePositive++; | ||
| 57 | - trueSubjects.removeOne(subject); | ||
| 58 | - predictedSubjects.removeOne(subject); | ||
| 59 | - } else { | ||
| 60 | - counters[subject].falseNegative++; | ||
| 61 | - } | ||
| 62 | - } | ||
| 63 | - | ||
| 64 | - for (int i=0; i<trueSubjects.size(); i++) | ||
| 65 | - foreach (const QString &subject, predictedSubjects) | ||
| 66 | - counters[subject].falsePositive += 1.f / predictedSubjects.size(); | ||
| 67 | - } | ||
| 68 | - | ||
| 69 | - const QStringList keys = counters.keys(); | ||
| 70 | - QSharedPointer<Output> output(Output::make("", FileList() << "Count" << "Precision" << "Recall" << "F-score", FileList(keys))); | ||
| 71 | - | ||
| 72 | - int tpc = 0; | ||
| 73 | - int fnc = 0; | ||
| 74 | - | ||
| 75 | - for (int i=0; i<counters.size(); i++) { | ||
| 76 | - const QString &subject = keys[i]; | ||
| 77 | - const Counter &counter = counters[subject]; | ||
| 78 | - tpc += counter.truePositive; | ||
| 79 | - fnc += counter.falseNegative; | ||
| 80 | - const int count = counter.truePositive + counter.falseNegative; | ||
| 81 | - const float precision = counter.truePositive / (float)(counter.truePositive + counter.falsePositive); | ||
| 82 | - const float recall = counter.truePositive / (float)(counter.truePositive + counter.falseNegative); | ||
| 83 | - const float fscore = 2 * precision * recall / (precision + recall); | ||
| 84 | - output->setRelative(count, i, 0); | ||
| 85 | - output->setRelative(precision, i, 1); | ||
| 86 | - output->setRelative(recall, i, 2); | ||
| 87 | - output->setRelative(fscore, i, 3); | ||
| 88 | - } | ||
| 89 | - | ||
| 90 | - qDebug("Overall Accuracy = %f", (float)tpc / (float)(tpc + fnc)); | ||
| 91 | -} | ||
| 92 | - | ||
| 93 | -void br::EvalDetection(const QString &predictedInput, const QString &truthInput) | ||
| 94 | -{ | ||
| 95 | - (void) predictedInput; | ||
| 96 | - (void) truthInput; | ||
| 97 | -} | ||
| 98 | - | ||
| 99 | -void br::EvalRegression(const QString &predictedInput, const QString &truthInput) | ||
| 100 | -{ | ||
| 101 | - qDebug("Evaluating regression of %s against %s", qPrintable(predictedInput), qPrintable(truthInput)); | ||
| 102 | - | ||
| 103 | - const TemplateList predicted(TemplateList::fromGallery(predictedInput)); | ||
| 104 | - const TemplateList truth(TemplateList::fromGallery(truthInput)); | ||
| 105 | - if (predicted.size() != truth.size()) qFatal("Input size mismatch."); | ||
| 106 | - | ||
| 107 | - float rmsError = 0; | ||
| 108 | - QStringList truthValues, predictedValues; | ||
| 109 | - for (int i=0; i<predicted.size(); i++) { | ||
| 110 | - if (predicted[i].file.name != truth[i].file.name) | ||
| 111 | - qFatal("Input order mismatch."); | ||
| 112 | - rmsError += pow(predicted[i].file.get<float>("Subject")-truth[i].file.get<float>("Subject"), 2.f); | ||
| 113 | - truthValues.append(QString::number(truth[i].file.get<float>("Subject"))); | ||
| 114 | - predictedValues.append(QString::number(predicted[i].file.get<float>("Subject"))); | ||
| 115 | - } | ||
| 116 | - | ||
| 117 | - QStringList rSource; | ||
| 118 | - rSource << "# Load libraries" << "library(ggplot2)" << "" << "# Set Data" | ||
| 119 | - << "Actual <- c(" + truthValues.join(",") + ")" | ||
| 120 | - << "Predicted <- c(" + predictedValues.join(",") + ")" | ||
| 121 | - << "data <- data.frame(Actual, Predicted)" | ||
| 122 | - << "" << "# Construct Plot" << "pdf(\"EvalRegression.pdf\")" | ||
| 123 | - << "print(qplot(Actual, Predicted, data=data, geom=\"jitter\", alpha=I(2/3)) + geom_abline(intercept=0, slope=1, color=\"forestgreen\", size=I(1)) + geom_smooth(size=I(1), color=\"mediumblue\") + theme_bw())" | ||
| 124 | - << "print(qplot(Actual, Predicted-Actual, data=data, geom=\"jitter\", alpha=I(2/3)) + geom_abline(intercept=0, slope=0, color=\"forestgreen\", size=I(1)) + geom_smooth(size=I(1), color=\"mediumblue\") + theme_bw())" | ||
| 125 | - << "dev.off()"; | ||
| 126 | - | ||
| 127 | - | ||
| 128 | - QString rFile = "EvalRegression.R"; | ||
| 129 | - QtUtils::writeFile(rFile, rSource); | ||
| 130 | - bool success = QtUtils::runRScript(rFile); | ||
| 131 | - if (success) QtUtils::showFile("EvalRegression.pdf"); | ||
| 132 | - | ||
| 133 | - qDebug("RMS Error = %f", sqrt(rmsError/predicted.size())); | ||
| 134 | -} |
openbr/core/cluster.cpp
| @@ -280,7 +280,7 @@ void br::EvalClustering(const QString &csv, const QString &input) | @@ -280,7 +280,7 @@ void br::EvalClustering(const QString &csv, const QString &input) | ||
| 280 | 280 | ||
| 281 | // We assume clustering algorithms store assigned cluster labels as integers (since the clusters are | 281 | // We assume clustering algorithms store assigned cluster labels as integers (since the clusters are |
| 282 | // not named). | 282 | // not named). |
| 283 | - QList<int> labels = TemplateList::fromGallery(input).files().get<int>("Subject"); | 283 | + QList<int> labels = File::get<int>(TemplateList::fromGallery(input), "Subject"); |
| 284 | 284 | ||
| 285 | QHash<int, int> labelToIndex; | 285 | QHash<int, int> labelToIndex; |
| 286 | int nClusters = 0; | 286 | int nClusters = 0; |
openbr/core/eval.cpp
0 → 100644
| 1 | +/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * | ||
| 2 | + * Copyright 2012 The MITRE Corporation * | ||
| 3 | + * * | ||
| 4 | + * Licensed under the Apache License, Version 2.0 (the "License"); * | ||
| 5 | + * you may not use this file except in compliance with the License. * | ||
| 6 | + * You may obtain a copy of the License at * | ||
| 7 | + * * | ||
| 8 | + * http://www.apache.org/licenses/LICENSE-2.0 * | ||
| 9 | + * * | ||
| 10 | + * Unless required by applicable law or agreed to in writing, software * | ||
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, * | ||
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * | ||
| 13 | + * See the License for the specific language governing permissions and * | ||
| 14 | + * limitations under the License. * | ||
| 15 | + * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ | ||
| 16 | + | ||
| 17 | +#include "bee.h" | ||
| 18 | +#include "eval.h" | ||
| 19 | +#include "openbr/core/qtutils.h" | ||
| 20 | + | ||
| 21 | +using namespace cv; | ||
| 22 | + | ||
| 23 | +namespace br | ||
| 24 | +{ | ||
| 25 | + | ||
| 26 | +struct Comparison | ||
| 27 | +{ | ||
| 28 | + float score; | ||
| 29 | + int target, query; | ||
| 30 | + bool genuine; | ||
| 31 | + | ||
| 32 | + Comparison() {} | ||
| 33 | + Comparison(float _score, int _target, int _query, bool _genuine) | ||
| 34 | + : score(_score), target(_target), query(_query), genuine(_genuine) {} | ||
| 35 | + inline bool operator<(const Comparison &other) const { return score > other.score; } | ||
| 36 | +}; | ||
| 37 | + | ||
| 38 | +#undef FAR // Windows preprecessor definition conflicts with variable name | ||
| 39 | +struct OperatingPoint | ||
| 40 | +{ | ||
| 41 | + float score, FAR, TAR; | ||
| 42 | + OperatingPoint() {} | ||
| 43 | + OperatingPoint(float _score, float _FAR, float _TAR) | ||
| 44 | + : score(_score), FAR(_FAR), TAR(_TAR) {} | ||
| 45 | +}; | ||
| 46 | + | ||
| 47 | +static float getTAR(const QList<OperatingPoint> &operatingPoints, float FAR) | ||
| 48 | +{ | ||
| 49 | + int index = 0; | ||
| 50 | + while (operatingPoints[index].FAR < FAR) { | ||
| 51 | + index++; | ||
| 52 | + if (index == operatingPoints.size()) | ||
| 53 | + return 1; | ||
| 54 | + } | ||
| 55 | + | ||
| 56 | + const float x1 = (index == 0 ? 0 : operatingPoints[index-1].FAR); | ||
| 57 | + const float y1 = (index == 0 ? 0 : operatingPoints[index-1].TAR); | ||
| 58 | + const float x2 = operatingPoints[index].FAR; | ||
| 59 | + const float y2 = operatingPoints[index].TAR; | ||
| 60 | + const float m = (y2 - y1) / (x2 - x1); | ||
| 61 | + const float b = y1 - m*x1; | ||
| 62 | + return m * FAR + b; | ||
| 63 | +} | ||
| 64 | + | ||
| 65 | +float Evaluate(const cv::Mat &scores, const FileList &target, const FileList &query, const QString &csv, int partition) | ||
| 66 | +{ | ||
| 67 | + return Evaluate(scores, BEE::makeMask(target, query, partition), csv); | ||
| 68 | +} | ||
| 69 | + | ||
| 70 | +float Evaluate(const QString &simmat, const QString &mask, const QString &csv) | ||
| 71 | +{ | ||
| 72 | + qDebug("Evaluating %s%s%s", | ||
| 73 | + qPrintable(simmat), | ||
| 74 | + mask.isEmpty() ? "" : qPrintable(" with " + mask), | ||
| 75 | + csv.isEmpty() ? "" : qPrintable(" to " + csv)); | ||
| 76 | + | ||
| 77 | + // Read similarity matrix | ||
| 78 | + QString target, query; | ||
| 79 | + const Mat scores = BEE::readSimmat(simmat, &target, &query); | ||
| 80 | + | ||
| 81 | + // Read mask matrix | ||
| 82 | + Mat truth; | ||
| 83 | + if (mask.isEmpty()) { | ||
| 84 | + // Use the galleries specified in the similarity matrix | ||
| 85 | + truth = BEE::makeMask(TemplateList::fromGallery(target).files(), | ||
| 86 | + TemplateList::fromGallery(query).files()); | ||
| 87 | + } else { | ||
| 88 | + File maskFile(mask); | ||
| 89 | + maskFile.set("rows", scores.rows); | ||
| 90 | + maskFile.set("columns", scores.cols); | ||
| 91 | + truth = BEE::readMask(maskFile); | ||
| 92 | + } | ||
| 93 | + | ||
| 94 | + return Evaluate(scores, truth, csv); | ||
| 95 | +} | ||
| 96 | + | ||
| 97 | +float Evaluate(const Mat &simmat, const Mat &mask, const QString &csv) | ||
| 98 | +{ | ||
| 99 | + if (simmat.size() != mask.size()) | ||
| 100 | + qFatal("Similarity matrix (%ix%i) differs in size from mask matrix (%ix%i).", | ||
| 101 | + simmat.rows, simmat.cols, mask.rows, mask.cols); | ||
| 102 | + | ||
| 103 | + const int Max_Points = 500; | ||
| 104 | + float result = -1; | ||
| 105 | + | ||
| 106 | + // Make comparisons | ||
| 107 | + QList<Comparison> comparisons; comparisons.reserve(simmat.rows*simmat.cols); | ||
| 108 | + int genuineCount = 0, impostorCount = 0, numNaNs = 0; | ||
| 109 | + for (int i=0; i<simmat.rows; i++) { | ||
| 110 | + for (int j=0; j<simmat.cols; j++) { | ||
| 111 | + const BEE::Mask_t mask_val = mask.at<BEE::Mask_t>(i,j); | ||
| 112 | + const BEE::Simmat_t simmat_val = simmat.at<BEE::Simmat_t>(i,j); | ||
| 113 | + if (mask_val == BEE::DontCare) continue; | ||
| 114 | + if (simmat_val != simmat_val) { numNaNs++; continue; } | ||
| 115 | + comparisons.append(Comparison(simmat_val, j, i, mask_val == BEE::Match)); | ||
| 116 | + if (comparisons.last().genuine) genuineCount++; | ||
| 117 | + else impostorCount++; | ||
| 118 | + } | ||
| 119 | + } | ||
| 120 | + | ||
| 121 | + if (numNaNs > 0) qWarning("Encountered %d NaN scores!", numNaNs); | ||
| 122 | + if (genuineCount == 0) qFatal("No genuine scores!"); | ||
| 123 | + if (impostorCount == 0) qFatal("No impostor scores!"); | ||
| 124 | + | ||
| 125 | + // Sort comparisons by simmat_val (score) | ||
| 126 | + std::sort(comparisons.begin(), comparisons.end()); | ||
| 127 | + | ||
| 128 | + QList<OperatingPoint> operatingPoints; | ||
| 129 | + QList<float> genuines; genuines.reserve(sqrt((float)comparisons.size())); | ||
| 130 | + QList<float> impostors; impostors.reserve(comparisons.size()); | ||
| 131 | + QVector<int> firstGenuineReturns(simmat.rows, 0); | ||
| 132 | + | ||
| 133 | + int falsePositives = 0, previousFalsePositives = 0; | ||
| 134 | + int truePositives = 0, previousTruePositives = 0; | ||
| 135 | + int index = 0; | ||
| 136 | + float minGenuineScore = std::numeric_limits<float>::max(); | ||
| 137 | + float minImpostorScore = std::numeric_limits<float>::max(); | ||
| 138 | + | ||
| 139 | + while (index < comparisons.size()) { | ||
| 140 | + float thresh = comparisons[index].score; | ||
| 141 | + // Compute genuine and imposter statistics at a threshold | ||
| 142 | + while ((index < comparisons.size()) && | ||
| 143 | + (comparisons[index].score == thresh)) { | ||
| 144 | + const Comparison &comparison = comparisons[index]; | ||
| 145 | + if (comparison.genuine) { | ||
| 146 | + truePositives++; | ||
| 147 | + genuines.append(comparison.score); | ||
| 148 | + if (firstGenuineReturns[comparison.query] < 1) | ||
| 149 | + firstGenuineReturns[comparison.query] = abs(firstGenuineReturns[comparison.query]) + 1; | ||
| 150 | + if ((comparison.score != -std::numeric_limits<float>::max()) && | ||
| 151 | + (comparison.score < minGenuineScore)) | ||
| 152 | + minGenuineScore = comparison.score; | ||
| 153 | + } else { | ||
| 154 | + falsePositives++; | ||
| 155 | + impostors.append(comparison.score); | ||
| 156 | + if (firstGenuineReturns[comparison.query] < 1) | ||
| 157 | + firstGenuineReturns[comparison.query]--; | ||
| 158 | + if ((comparison.score != -std::numeric_limits<float>::max()) && | ||
| 159 | + (comparison.score < minImpostorScore)) | ||
| 160 | + minImpostorScore = comparison.score; | ||
| 161 | + } | ||
| 162 | + index++; | ||
| 163 | + } | ||
| 164 | + | ||
| 165 | + if ((falsePositives > previousFalsePositives) && | ||
| 166 | + (truePositives > previousTruePositives)) { | ||
| 167 | + // Restrict the extreme ends of the curve | ||
| 168 | + if ((truePositives >= 10) && (falsePositives < impostorCount/2)) | ||
| 169 | + operatingPoints.append(OperatingPoint(thresh, float(falsePositives)/impostorCount, float(truePositives)/genuineCount)); | ||
| 170 | + previousFalsePositives = falsePositives; | ||
| 171 | + previousTruePositives = truePositives; | ||
| 172 | + } | ||
| 173 | + } | ||
| 174 | + | ||
| 175 | + if (operatingPoints.size() == 0) operatingPoints.append(OperatingPoint(1, 1, 1)); | ||
| 176 | + if (operatingPoints.size() == 1) operatingPoints.prepend(OperatingPoint(0, 0, 0)); | ||
| 177 | + if (operatingPoints.size() > 2) operatingPoints.takeLast(); // Remove point (1,1) | ||
| 178 | + | ||
| 179 | + // Write Metadata table | ||
| 180 | + QStringList lines; | ||
| 181 | + lines.append("Plot,X,Y"); | ||
| 182 | + lines.append("Metadata,"+QString::number(simmat.cols)+",Gallery"); | ||
| 183 | + lines.append("Metadata,"+QString::number(simmat.rows)+",Probe"); | ||
| 184 | + lines.append("Metadata,"+QString::number(genuineCount)+",Genuine"); | ||
| 185 | + lines.append("Metadata,"+QString::number(impostorCount)+",Impostor"); | ||
| 186 | + lines.append("Metadata,"+QString::number(simmat.cols*simmat.rows-(genuineCount+impostorCount))+",Ignored"); | ||
| 187 | + | ||
| 188 | + // Write Detection Error Tradeoff (DET), PRE, REC | ||
| 189 | + int points = qMin(operatingPoints.size(), Max_Points); | ||
| 190 | + for (int i=0; i<points; i++) { | ||
| 191 | + const OperatingPoint &operatingPoint = operatingPoints[double(i) / double(points-1) * double(operatingPoints.size()-1)]; | ||
| 192 | + lines.append(QString("DET,%1,%2").arg(QString::number(operatingPoint.FAR), | ||
| 193 | + QString::number(1-operatingPoint.TAR))); | ||
| 194 | + lines.append(QString("FAR,%1,%2").arg(QString::number(operatingPoint.score), | ||
| 195 | + QString::number(operatingPoint.FAR))); | ||
| 196 | + lines.append(QString("FRR,%1,%2").arg(QString::number(operatingPoint.score), | ||
| 197 | + QString::number(1-operatingPoint.TAR))); | ||
| 198 | + } | ||
| 199 | + | ||
| 200 | + // Write FAR/TAR Bar Chart (BC) | ||
| 201 | + lines.append(qPrintable(QString("BC,0.001,%1").arg(QString::number(getTAR(operatingPoints, 0.001), 'f', 3)))); | ||
| 202 | + lines.append(qPrintable(QString("BC,0.01,%1").arg(QString::number(result = getTAR(operatingPoints, 0.01), 'f', 3)))); | ||
| 203 | + | ||
| 204 | + // Write SD & KDE | ||
| 205 | + points = qMin(qMin(Max_Points, genuines.size()), impostors.size()); | ||
| 206 | + QList<double> sampledGenuineScores; sampledGenuineScores.reserve(points); | ||
| 207 | + QList<double> sampledImpostorScores; sampledImpostorScores.reserve(points); | ||
| 208 | + | ||
| 209 | + if (points > 1) { | ||
| 210 | + for (int i=0; i<points; i++) { | ||
| 211 | + float genuineScore = genuines[double(i) / double(points-1) * double(genuines.size()-1)]; | ||
| 212 | + float impostorScore = impostors[double(i) / double(points-1) * double(impostors.size()-1)]; | ||
| 213 | + if (genuineScore == -std::numeric_limits<float>::max()) genuineScore = minGenuineScore; | ||
| 214 | + if (impostorScore == -std::numeric_limits<float>::max()) impostorScore = minImpostorScore; | ||
| 215 | + lines.append(QString("SD,%1,Genuine").arg(QString::number(genuineScore))); | ||
| 216 | + lines.append(QString("SD,%1,Impostor").arg(QString::number(impostorScore))); | ||
| 217 | + sampledGenuineScores.append(genuineScore); | ||
| 218 | + sampledImpostorScores.append(impostorScore); | ||
| 219 | + } | ||
| 220 | + } | ||
| 221 | + | ||
| 222 | + // Write Cumulative Match Characteristic (CMC) curve | ||
| 223 | + const int Max_Retrieval = 200; | ||
| 224 | + const int Report_Retrieval = 5; | ||
| 225 | + | ||
| 226 | + float reportRetrievalRate = -1; | ||
| 227 | + for (int i=1; i<=Max_Retrieval; i++) { | ||
| 228 | + int realizedReturns = 0, possibleReturns = 0; | ||
| 229 | + foreach (int firstGenuineReturn, firstGenuineReturns) { | ||
| 230 | + if (firstGenuineReturn > 0) { | ||
| 231 | + possibleReturns++; | ||
| 232 | + if (firstGenuineReturn <= i) realizedReturns++; | ||
| 233 | + } | ||
| 234 | + } | ||
| 235 | + const float retrievalRate = float(realizedReturns)/possibleReturns; | ||
| 236 | + lines.append(qPrintable(QString("CMC,%1,%2").arg(QString::number(i), QString::number(retrievalRate)))); | ||
| 237 | + if (i == Report_Retrieval) reportRetrievalRate = retrievalRate; | ||
| 238 | + } | ||
| 239 | + | ||
| 240 | + if (!csv.isEmpty()) QtUtils::writeFile(csv, lines); | ||
| 241 | + qDebug("TAR @ FAR = 0.01: %.3f\nRetrieval Rate @ Rank = %d: %.3f", result, Report_Retrieval, reportRetrievalRate); | ||
| 242 | + return result; | ||
| 243 | +} | ||
| 244 | + | ||
| 245 | +// Helper struct for statistics accumulation | ||
| 246 | +struct Counter | ||
| 247 | +{ | ||
| 248 | + float truePositive, falsePositive, falseNegative; | ||
| 249 | + Counter() | ||
| 250 | + { | ||
| 251 | + truePositive = 0; | ||
| 252 | + falsePositive = 0; | ||
| 253 | + falseNegative = 0; | ||
| 254 | + } | ||
| 255 | +}; | ||
| 256 | + | ||
| 257 | +void EvalClassification(const QString &predictedInput, const QString &truthInput) | ||
| 258 | +{ | ||
| 259 | + qDebug("Evaluating classification of %s against %s", qPrintable(predictedInput), qPrintable(truthInput)); | ||
| 260 | + TemplateList predicted(TemplateList::fromGallery(predictedInput)); | ||
| 261 | + TemplateList truth(TemplateList::fromGallery(truthInput)); | ||
| 262 | + if (predicted.size() != truth.size()) qFatal("Input size mismatch."); | ||
| 263 | + | ||
| 264 | + QHash<QString, Counter> counters; | ||
| 265 | + for (int i=0; i<predicted.size(); i++) { | ||
| 266 | + if (predicted[i].file.name != truth[i].file.name) | ||
| 267 | + qFatal("Input order mismatch."); | ||
| 268 | + | ||
| 269 | + // Typically these lists will be of length one, but this generalization allows measuring multi-class labeling accuracy. | ||
| 270 | + QString predictedSubject = predicted[i].file.get<QString>("Subject"); | ||
| 271 | + QString trueSubject = truth[i].file.get<QString>("Subject"); | ||
| 272 | + | ||
| 273 | + QStringList predictedSubjects(predictedSubject); | ||
| 274 | + QStringList trueSubjects(trueSubject); | ||
| 275 | + | ||
| 276 | + foreach (const QString &subject, trueSubjects.toVector() /* Hack to copy the list. */) { | ||
| 277 | + if (predictedSubjects.contains(subject)) { | ||
| 278 | + counters[subject].truePositive++; | ||
| 279 | + trueSubjects.removeOne(subject); | ||
| 280 | + predictedSubjects.removeOne(subject); | ||
| 281 | + } else { | ||
| 282 | + counters[subject].falseNegative++; | ||
| 283 | + } | ||
| 284 | + } | ||
| 285 | + | ||
| 286 | + for (int i=0; i<trueSubjects.size(); i++) | ||
| 287 | + foreach (const QString &subject, predictedSubjects) | ||
| 288 | + counters[subject].falsePositive += 1.f / predictedSubjects.size(); | ||
| 289 | + } | ||
| 290 | + | ||
| 291 | + const QStringList keys = counters.keys(); | ||
| 292 | + QSharedPointer<Output> output(Output::make("", FileList() << "Count" << "Precision" << "Recall" << "F-score", FileList(keys))); | ||
| 293 | + | ||
| 294 | + int tpc = 0; | ||
| 295 | + int fnc = 0; | ||
| 296 | + | ||
| 297 | + for (int i=0; i<counters.size(); i++) { | ||
| 298 | + const QString &subject = keys[i]; | ||
| 299 | + const Counter &counter = counters[subject]; | ||
| 300 | + tpc += counter.truePositive; | ||
| 301 | + fnc += counter.falseNegative; | ||
| 302 | + const int count = counter.truePositive + counter.falseNegative; | ||
| 303 | + const float precision = counter.truePositive / (float)(counter.truePositive + counter.falsePositive); | ||
| 304 | + const float recall = counter.truePositive / (float)(counter.truePositive + counter.falseNegative); | ||
| 305 | + const float fscore = 2 * precision * recall / (precision + recall); | ||
| 306 | + output->setRelative(count, i, 0); | ||
| 307 | + output->setRelative(precision, i, 1); | ||
| 308 | + output->setRelative(recall, i, 2); | ||
| 309 | + output->setRelative(fscore, i, 3); | ||
| 310 | + } | ||
| 311 | + | ||
| 312 | + qDebug("Overall Accuracy = %f", (float)tpc / (float)(tpc + fnc)); | ||
| 313 | +} | ||
| 314 | + | ||
| 315 | +struct Detection | ||
| 316 | +{ | ||
| 317 | + QRectF boundingBox; | ||
| 318 | + float confidence; | ||
| 319 | + | ||
| 320 | + Detection() {} | ||
| 321 | + Detection(const QRectF &boundingBox_, float confidence_ = -1) | ||
| 322 | + : boundingBox(boundingBox_), confidence(confidence_) {} | ||
| 323 | + | ||
| 324 | + float area() const { return boundingBox.width() * boundingBox.height(); } | ||
| 325 | + float overlap(const Detection &other) const | ||
| 326 | + { | ||
| 327 | + const Detection intersection(boundingBox.intersected(other.boundingBox)); | ||
| 328 | + return intersection.area() / (area() + other.area() - 2*intersection.area()); | ||
| 329 | + } | ||
| 330 | +}; | ||
| 331 | + | ||
| 332 | +struct Detections | ||
| 333 | +{ | ||
| 334 | + QList<Detection> predicted, truth; | ||
| 335 | +}; | ||
| 336 | + | ||
| 337 | +struct DetectionOperatingPoint | ||
| 338 | +{ | ||
| 339 | + float confidence, overlap; | ||
| 340 | + DetectionOperatingPoint() : confidence(-1), overlap(-1) {} | ||
| 341 | + DetectionOperatingPoint(float confidence_, float overlap_) : confidence(confidence_), overlap(overlap_) {} | ||
| 342 | + inline bool operator<(const DetectionOperatingPoint &other) const { return confidence > other.confidence; } | ||
| 343 | +}; | ||
| 344 | + | ||
| 345 | +float EvalDetection(const QString &predictedInput, const QString &truthInput, const QString &csv) | ||
| 346 | +{ | ||
| 347 | + qDebug("Evaluating detection of %s against %s", qPrintable(predictedInput), qPrintable(truthInput)); | ||
| 348 | + const TemplateList predicted(TemplateList::fromGallery(predictedInput)); | ||
| 349 | + const TemplateList truth(TemplateList::fromGallery(truthInput)); | ||
| 350 | + | ||
| 351 | + // Figure out which metadata field contains a bounding box | ||
| 352 | + QString detectKey; | ||
| 353 | + foreach (const QString &key, truth.first().file.localKeys()) | ||
| 354 | + if (!truth.first().file.get<QRectF>(key, QRectF()).isNull()) { | ||
| 355 | + detectKey = key; | ||
| 356 | + break; | ||
| 357 | + } | ||
| 358 | + if (detectKey.isNull()) qFatal("No suitable metadata key found."); | ||
| 359 | + else qDebug("Using metadata key: %s", qPrintable(detectKey)); | ||
| 360 | + | ||
| 361 | + QHash<QString, Detections> allDetections; // Organized by file | ||
| 362 | + foreach (const Template &t, predicted) | ||
| 363 | + allDetections[t.file.baseName()].predicted.append(Detection(t.file.get<QRectF>(detectKey), t.file.get<float>("Confidence", -1))); | ||
| 364 | + foreach (const Template &t, truth) | ||
| 365 | + allDetections[t.file.baseName()].truth.append(Detection(t.file.get<QRectF>(detectKey))); | ||
| 366 | + | ||
| 367 | + QList<DetectionOperatingPoint> points; | ||
| 368 | + foreach (Detections detections, allDetections.values()) { | ||
| 369 | + while (!detections.truth.isEmpty() && !detections.predicted.isEmpty()) { | ||
| 370 | + Detection truth = detections.truth.takeFirst(); | ||
| 371 | + int bestIndex = -1; | ||
| 372 | + float bestOverlap = -1; | ||
| 373 | + for (int i=0; i<detections.predicted.size(); i++) { | ||
| 374 | + const float overlap = truth.overlap(detections.predicted[i]); | ||
| 375 | + if (overlap > bestOverlap) { | ||
| 376 | + bestOverlap = overlap; | ||
| 377 | + bestIndex = i; | ||
| 378 | + } | ||
| 379 | + } | ||
| 380 | + Detection predicted = detections.predicted.takeAt(bestIndex); | ||
| 381 | + points.append(DetectionOperatingPoint(predicted.confidence, bestOverlap)); | ||
| 382 | + } | ||
| 383 | + | ||
| 384 | + foreach (const Detection &detection, detections.predicted) | ||
| 385 | + points.append(DetectionOperatingPoint(detection.confidence, 0)); | ||
| 386 | + for (int i=0; i<detections.truth.size(); i++) | ||
| 387 | + points.append(DetectionOperatingPoint(-std::numeric_limits<float>::max(), 0)); | ||
| 388 | + } | ||
| 389 | + | ||
| 390 | + std::sort(points.begin(), points.end()); | ||
| 391 | + | ||
| 392 | + QStringList lines; | ||
| 393 | + lines.append("Plot, X, Y"); | ||
| 394 | + | ||
| 395 | + // TODO: finish implementing | ||
| 396 | + | ||
| 397 | + (void) csv; | ||
| 398 | + return 0; | ||
| 399 | +} | ||
| 400 | + | ||
| 401 | +void EvalRegression(const QString &predictedInput, const QString &truthInput) | ||
| 402 | +{ | ||
| 403 | + qDebug("Evaluating regression of %s against %s", qPrintable(predictedInput), qPrintable(truthInput)); | ||
| 404 | + const TemplateList predicted(TemplateList::fromGallery(predictedInput)); | ||
| 405 | + const TemplateList truth(TemplateList::fromGallery(truthInput)); | ||
| 406 | + if (predicted.size() != truth.size()) qFatal("Input size mismatch."); | ||
| 407 | + | ||
| 408 | + float rmsError = 0; | ||
| 409 | + QStringList truthValues, predictedValues; | ||
| 410 | + for (int i=0; i<predicted.size(); i++) { | ||
| 411 | + if (predicted[i].file.name != truth[i].file.name) | ||
| 412 | + qFatal("Input order mismatch."); | ||
| 413 | + rmsError += pow(predicted[i].file.get<float>("Subject")-truth[i].file.get<float>("Subject"), 2.f); | ||
| 414 | + truthValues.append(QString::number(truth[i].file.get<float>("Subject"))); | ||
| 415 | + predictedValues.append(QString::number(predicted[i].file.get<float>("Subject"))); | ||
| 416 | + } | ||
| 417 | + | ||
| 418 | + QStringList rSource; | ||
| 419 | + rSource << "# Load libraries" << "library(ggplot2)" << "" << "# Set Data" | ||
| 420 | + << "Actual <- c(" + truthValues.join(",") + ")" | ||
| 421 | + << "Predicted <- c(" + predictedValues.join(",") + ")" | ||
| 422 | + << "data <- data.frame(Actual, Predicted)" | ||
| 423 | + << "" << "# Construct Plot" << "pdf(\"EvalRegression.pdf\")" | ||
| 424 | + << "print(qplot(Actual, Predicted, data=data, geom=\"jitter\", alpha=I(2/3)) + geom_abline(intercept=0, slope=1, color=\"forestgreen\", size=I(1)) + geom_smooth(size=I(1), color=\"mediumblue\") + theme_bw())" | ||
| 425 | + << "print(qplot(Actual, Predicted-Actual, data=data, geom=\"jitter\", alpha=I(2/3)) + geom_abline(intercept=0, slope=0, color=\"forestgreen\", size=I(1)) + geom_smooth(size=I(1), color=\"mediumblue\") + theme_bw())" | ||
| 426 | + << "dev.off()"; | ||
| 427 | + | ||
| 428 | + | ||
| 429 | + QString rFile = "EvalRegression.R"; | ||
| 430 | + QtUtils::writeFile(rFile, rSource); | ||
| 431 | + bool success = QtUtils::runRScript(rFile); | ||
| 432 | + if (success) QtUtils::showFile("EvalRegression.pdf"); | ||
| 433 | + | ||
| 434 | + qDebug("RMS Error = %f", sqrt(rmsError/predicted.size())); | ||
| 435 | +} | ||
| 436 | + | ||
| 437 | +} // namespace br |
openbr/core/classify.h renamed to openbr/core/eval.h
| @@ -14,18 +14,22 @@ | @@ -14,18 +14,22 @@ | ||
| 14 | * limitations under the License. * | 14 | * limitations under the License. * |
| 15 | * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ | 15 | * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ |
| 16 | 16 | ||
| 17 | -#ifndef __CLASSIFY_H | ||
| 18 | -#define __CLASSIFY_H | 17 | +#ifndef __EVAL_H |
| 18 | +#define __EVAL_H | ||
| 19 | 19 | ||
| 20 | #include <QList> | 20 | #include <QList> |
| 21 | #include <QString> | 21 | #include <QString> |
| 22 | +#include "openbr/openbr_plugin.h" | ||
| 22 | 23 | ||
| 23 | namespace br | 24 | namespace br |
| 24 | { | 25 | { |
| 26 | + float Evaluate(const QString &simmat, const QString &mask = "", const QString &csv = ""); // Returns TAR @ FAR = 0.001 | ||
| 27 | + float Evaluate(const cv::Mat &scores, const FileList &target, const FileList &query, const QString &csv = "", int parition = 0); | ||
| 28 | + float Evaluate(const cv::Mat &scores, const cv::Mat &masks, const QString &csv = ""); | ||
| 25 | void EvalClassification(const QString &predictedInput, const QString &truthInput); | 29 | void EvalClassification(const QString &predictedInput, const QString &truthInput); |
| 26 | - void EvalDetection(const QString &predictedInput, const QString &truthInput); | 30 | + float EvalDetection(const QString &predictedInput, const QString &truthInput, const QString &csv = ""); // Return average overlap |
| 27 | void EvalRegression(const QString &predictedInput, const QString &truthInput); | 31 | void EvalRegression(const QString &predictedInput, const QString &truthInput); |
| 28 | } | 32 | } |
| 29 | 33 | ||
| 30 | -#endif // __CLASSIFY_H | 34 | +#endif // __EVAL_H |
| 31 | 35 |
openbr/core/plot.cpp
| @@ -14,57 +14,15 @@ | @@ -14,57 +14,15 @@ | ||
| 14 | * limitations under the License. * | 14 | * limitations under the License. * |
| 15 | * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ | 15 | * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ |
| 16 | 16 | ||
| 17 | -#include <QDebug> | ||
| 18 | -#include <QDir> | ||
| 19 | -#include <QFile> | ||
| 20 | -#include <QFileInfo> | ||
| 21 | -#include <QFuture> | ||
| 22 | -#include <QList> | ||
| 23 | -#include <QPair> | ||
| 24 | -#include <QPointF> | ||
| 25 | -#include <QRegExp> | ||
| 26 | -#include <QSet> | ||
| 27 | -#include <QStringList> | ||
| 28 | -#include <QVector> | ||
| 29 | -#include <QtAlgorithms> | ||
| 30 | -#include <opencv2/core/core.hpp> | ||
| 31 | -#include <assert.h> | ||
| 32 | - | ||
| 33 | #include "plot.h" | 17 | #include "plot.h" |
| 34 | #include "version.h" | 18 | #include "version.h" |
| 35 | -#include "openbr/core/bee.h" | ||
| 36 | -#include "openbr/core/common.h" | ||
| 37 | -#include "openbr/core/opencvutils.h" | ||
| 38 | #include "openbr/core/qtutils.h" | 19 | #include "openbr/core/qtutils.h" |
| 39 | 20 | ||
| 40 | -#undef FAR // Windows preprecessor definition | ||
| 41 | - | ||
| 42 | using namespace cv; | 21 | using namespace cv; |
| 43 | 22 | ||
| 44 | namespace br | 23 | namespace br |
| 45 | { | 24 | { |
| 46 | 25 | ||
| 47 | -void Confusion(const QString &file, float score, int &true_positives, int &false_positives, int &true_negatives, int &false_negatives) | ||
| 48 | -{ | ||
| 49 | - qDebug("Computing confusion matrix of %s at %f", qPrintable(file), score); | ||
| 50 | - | ||
| 51 | - QStringList lines = QtUtils::readLines(file); | ||
| 52 | - true_positives = false_positives = true_negatives = false_negatives = 0; | ||
| 53 | - foreach (const QString &line, lines) { | ||
| 54 | - if (!line.startsWith("SD")) continue; | ||
| 55 | - QStringList words = line.split(","); | ||
| 56 | - bool ok; | ||
| 57 | - float similarity = words[1].toFloat(&ok); assert(ok); | ||
| 58 | - if (words[2] == "Genuine") { | ||
| 59 | - if (similarity >= score) true_positives++; | ||
| 60 | - else false_negatives++; | ||
| 61 | - } else { | ||
| 62 | - if (similarity >= score) false_positives++; | ||
| 63 | - else true_negatives++; | ||
| 64 | - } | ||
| 65 | - } | ||
| 66 | -} | ||
| 67 | - | ||
| 68 | static QStringList getPivots(const QString &file, bool headers) | 26 | static QStringList getPivots(const QString &file, bool headers) |
| 69 | { | 27 | { |
| 70 | QString str; | 28 | QString str; |
| @@ -73,224 +31,6 @@ static QStringList getPivots(const QString &file, bool headers) | @@ -73,224 +31,6 @@ static QStringList getPivots(const QString &file, bool headers) | ||
| 73 | return str.split("_"); | 31 | return str.split("_"); |
| 74 | } | 32 | } |
| 75 | 33 | ||
| 76 | -struct Comparison | ||
| 77 | -{ | ||
| 78 | - float score; | ||
| 79 | - int target, query; | ||
| 80 | - bool genuine; | ||
| 81 | - | ||
| 82 | - Comparison() {} | ||
| 83 | - Comparison(float _score, int _target, int _query, bool _genuine) | ||
| 84 | - : score(_score), target(_target), query(_query), genuine(_genuine) {} | ||
| 85 | - inline bool operator<(const Comparison &other) const { return score > other.score; } | ||
| 86 | -}; | ||
| 87 | - | ||
| 88 | -struct OperatingPoint | ||
| 89 | -{ | ||
| 90 | - float score, FAR, TAR; | ||
| 91 | - OperatingPoint() {} | ||
| 92 | - OperatingPoint(float _score, float _FAR, float _TAR) | ||
| 93 | - : score(_score), FAR(_FAR), TAR(_TAR) {} | ||
| 94 | -}; | ||
| 95 | - | ||
| 96 | -static float getTAR(const QList<OperatingPoint> &operatingPoints, float FAR) | ||
| 97 | -{ | ||
| 98 | - int index = 0; | ||
| 99 | - while (operatingPoints[index].FAR < FAR) { | ||
| 100 | - index++; | ||
| 101 | - if (index == operatingPoints.size()) | ||
| 102 | - return 1; | ||
| 103 | - } | ||
| 104 | - | ||
| 105 | - const float x1 = (index == 0 ? 0 : operatingPoints[index-1].FAR); | ||
| 106 | - const float y1 = (index == 0 ? 0 : operatingPoints[index-1].TAR); | ||
| 107 | - const float x2 = operatingPoints[index].FAR; | ||
| 108 | - const float y2 = operatingPoints[index].TAR; | ||
| 109 | - const float m = (y2 - y1) / (x2 - x1); | ||
| 110 | - const float b = y1 - m*x1; | ||
| 111 | - return m * FAR + b; | ||
| 112 | -} | ||
| 113 | - | ||
| 114 | -float Evaluate(const cv::Mat &scores, const FileList &target, const FileList &query, const QString &csv, int partition) | ||
| 115 | -{ | ||
| 116 | - return Evaluate(scores, BEE::makeMask(target, query, partition), csv); | ||
| 117 | -} | ||
| 118 | - | ||
| 119 | -float Evaluate(const QString &simmat, const QString &mask, const QString &csv) | ||
| 120 | -{ | ||
| 121 | - qDebug("Evaluating %s%s%s", | ||
| 122 | - qPrintable(simmat), | ||
| 123 | - mask.isEmpty() ? "" : qPrintable(" with " + mask), | ||
| 124 | - csv.isEmpty() ? "" : qPrintable(" to " + csv)); | ||
| 125 | - | ||
| 126 | - // Read similarity matrix | ||
| 127 | - QString target, query; | ||
| 128 | - const Mat scores = BEE::readSimmat(simmat, &target, &query); | ||
| 129 | - | ||
| 130 | - // Read mask matrix | ||
| 131 | - Mat truth; | ||
| 132 | - if (mask.isEmpty()) { | ||
| 133 | - // Use the galleries specified in the similarity matrix | ||
| 134 | - truth = BEE::makeMask(TemplateList::fromGallery(target).files(), | ||
| 135 | - TemplateList::fromGallery(query).files()); | ||
| 136 | - } else { | ||
| 137 | - File maskFile(mask); | ||
| 138 | - maskFile.set("rows", scores.rows); | ||
| 139 | - maskFile.set("columns", scores.cols); | ||
| 140 | - truth = BEE::readMask(maskFile); | ||
| 141 | - } | ||
| 142 | - | ||
| 143 | - return Evaluate(scores, truth, csv); | ||
| 144 | -} | ||
| 145 | - | ||
| 146 | -float Evaluate(const Mat &simmat, const Mat &mask, const QString &csv) | ||
| 147 | -{ | ||
| 148 | - if (simmat.size() != mask.size()) | ||
| 149 | - qFatal("Similarity matrix (%ix%i) differs in size from mask matrix (%ix%i).", | ||
| 150 | - simmat.rows, simmat.cols, mask.rows, mask.cols); | ||
| 151 | - | ||
| 152 | - const int Max_Points = 500; | ||
| 153 | - float result = -1; | ||
| 154 | - | ||
| 155 | - // Make comparisons | ||
| 156 | - QList<Comparison> comparisons; comparisons.reserve(simmat.rows*simmat.cols); | ||
| 157 | - int genuineCount = 0, impostorCount = 0, numNaNs = 0; | ||
| 158 | - for (int i=0; i<simmat.rows; i++) { | ||
| 159 | - for (int j=0; j<simmat.cols; j++) { | ||
| 160 | - const BEE::Mask_t mask_val = mask.at<BEE::Mask_t>(i,j); | ||
| 161 | - const BEE::Simmat_t simmat_val = simmat.at<BEE::Simmat_t>(i,j); | ||
| 162 | - if (mask_val == BEE::DontCare) continue; | ||
| 163 | - if (simmat_val != simmat_val) { numNaNs++; continue; } | ||
| 164 | - comparisons.append(Comparison(simmat_val, j, i, mask_val == BEE::Match)); | ||
| 165 | - if (comparisons.last().genuine) genuineCount++; | ||
| 166 | - else impostorCount++; | ||
| 167 | - } | ||
| 168 | - } | ||
| 169 | - | ||
| 170 | - if (numNaNs > 0) qWarning("Encountered %d NaN scores!", numNaNs); | ||
| 171 | - if (genuineCount == 0) qFatal("No genuine scores!"); | ||
| 172 | - if (impostorCount == 0) qFatal("No impostor scores!"); | ||
| 173 | - | ||
| 174 | - // Sort comparisons by simmat_val (score) | ||
| 175 | - std::sort(comparisons.begin(), comparisons.end()); | ||
| 176 | - | ||
| 177 | - QList<OperatingPoint> operatingPoints; | ||
| 178 | - QList<float> genuines; genuines.reserve(sqrt((float)comparisons.size())); | ||
| 179 | - QList<float> impostors; impostors.reserve(comparisons.size()); | ||
| 180 | - QVector<int> firstGenuineReturns(simmat.rows, 0); | ||
| 181 | - | ||
| 182 | - int falsePositives = 0, previousFalsePositives = 0; | ||
| 183 | - int truePositives = 0, previousTruePositives = 0; | ||
| 184 | - int index = 0; | ||
| 185 | - float minGenuineScore = std::numeric_limits<float>::max(); | ||
| 186 | - float minImpostorScore = std::numeric_limits<float>::max(); | ||
| 187 | - | ||
| 188 | - while (index < comparisons.size()) { | ||
| 189 | - float thresh = comparisons[index].score; | ||
| 190 | - // Compute genuine and imposter statistics at a threshold | ||
| 191 | - while ((index < comparisons.size()) && | ||
| 192 | - (comparisons[index].score == thresh)) { | ||
| 193 | - const Comparison &comparison = comparisons[index]; | ||
| 194 | - if (comparison.genuine) { | ||
| 195 | - truePositives++; | ||
| 196 | - genuines.append(comparison.score); | ||
| 197 | - if (firstGenuineReturns[comparison.query] < 1) | ||
| 198 | - firstGenuineReturns[comparison.query] = abs(firstGenuineReturns[comparison.query]) + 1; | ||
| 199 | - if ((comparison.score != -std::numeric_limits<float>::max()) && | ||
| 200 | - (comparison.score < minGenuineScore)) | ||
| 201 | - minGenuineScore = comparison.score; | ||
| 202 | - } else { | ||
| 203 | - falsePositives++; | ||
| 204 | - impostors.append(comparison.score); | ||
| 205 | - if (firstGenuineReturns[comparison.query] < 1) | ||
| 206 | - firstGenuineReturns[comparison.query]--; | ||
| 207 | - if ((comparison.score != -std::numeric_limits<float>::max()) && | ||
| 208 | - (comparison.score < minImpostorScore)) | ||
| 209 | - minImpostorScore = comparison.score; | ||
| 210 | - } | ||
| 211 | - index++; | ||
| 212 | - } | ||
| 213 | - | ||
| 214 | - if ((falsePositives > previousFalsePositives) && | ||
| 215 | - (truePositives > previousTruePositives)) { | ||
| 216 | - // Restrict the extreme ends of the curve | ||
| 217 | - if ((truePositives >= 10) && (falsePositives < impostorCount/2)) | ||
| 218 | - operatingPoints.append(OperatingPoint(thresh, float(falsePositives)/impostorCount, float(truePositives)/genuineCount)); | ||
| 219 | - previousFalsePositives = falsePositives; | ||
| 220 | - previousTruePositives = truePositives; | ||
| 221 | - } | ||
| 222 | - } | ||
| 223 | - | ||
| 224 | - if (operatingPoints.size() == 0) operatingPoints.append(OperatingPoint(1, 1, 1)); | ||
| 225 | - if (operatingPoints.size() == 1) operatingPoints.prepend(OperatingPoint(0, 0, 0)); | ||
| 226 | - if (operatingPoints.size() > 2) operatingPoints.takeLast(); // Remove point (1,1) | ||
| 227 | - | ||
| 228 | - // Write Metadata table | ||
| 229 | - QStringList lines; | ||
| 230 | - lines.append("Plot,X,Y"); | ||
| 231 | - lines.append("Metadata,"+QString::number(simmat.cols)+",Gallery"); | ||
| 232 | - lines.append("Metadata,"+QString::number(simmat.rows)+",Probe"); | ||
| 233 | - lines.append("Metadata,"+QString::number(genuineCount)+",Genuine"); | ||
| 234 | - lines.append("Metadata,"+QString::number(impostorCount)+",Impostor"); | ||
| 235 | - lines.append("Metadata,"+QString::number(simmat.cols*simmat.rows-(genuineCount+impostorCount))+",Ignored"); | ||
| 236 | - | ||
| 237 | - // Write Detection Error Tradeoff (DET), PRE, REC | ||
| 238 | - int points = qMin(operatingPoints.size(), Max_Points); | ||
| 239 | - for (int i=0; i<points; i++) { | ||
| 240 | - const OperatingPoint &operatingPoint = operatingPoints[double(i) / double(points-1) * double(operatingPoints.size()-1)]; | ||
| 241 | - lines.append(QString("DET,%1,%2").arg(QString::number(operatingPoint.FAR), | ||
| 242 | - QString::number(1-operatingPoint.TAR))); | ||
| 243 | - lines.append(QString("FAR,%1,%2").arg(QString::number(operatingPoint.score), | ||
| 244 | - QString::number(operatingPoint.FAR))); | ||
| 245 | - lines.append(QString("FRR,%1,%2").arg(QString::number(operatingPoint.score), | ||
| 246 | - QString::number(1-operatingPoint.TAR))); | ||
| 247 | - } | ||
| 248 | - | ||
| 249 | - // Write FAR/TAR Bar Chart (BC) | ||
| 250 | - lines.append(qPrintable(QString("BC,0.001,%1").arg(QString::number(getTAR(operatingPoints, 0.001), 'f', 3)))); | ||
| 251 | - lines.append(qPrintable(QString("BC,0.01,%1").arg(QString::number(result = getTAR(operatingPoints, 0.01), 'f', 3)))); | ||
| 252 | - | ||
| 253 | - // Write SD & KDE | ||
| 254 | - points = qMin(qMin(Max_Points, genuines.size()), impostors.size()); | ||
| 255 | - QList<double> sampledGenuineScores; sampledGenuineScores.reserve(points); | ||
| 256 | - QList<double> sampledImpostorScores; sampledImpostorScores.reserve(points); | ||
| 257 | - | ||
| 258 | - if (points > 1) { | ||
| 259 | - for (int i=0; i<points; i++) { | ||
| 260 | - float genuineScore = genuines[double(i) / double(points-1) * double(genuines.size()-1)]; | ||
| 261 | - float impostorScore = impostors[double(i) / double(points-1) * double(impostors.size()-1)]; | ||
| 262 | - if (genuineScore == -std::numeric_limits<float>::max()) genuineScore = minGenuineScore; | ||
| 263 | - if (impostorScore == -std::numeric_limits<float>::max()) impostorScore = minImpostorScore; | ||
| 264 | - lines.append(QString("SD,%1,Genuine").arg(QString::number(genuineScore))); | ||
| 265 | - lines.append(QString("SD,%1,Impostor").arg(QString::number(impostorScore))); | ||
| 266 | - sampledGenuineScores.append(genuineScore); | ||
| 267 | - sampledImpostorScores.append(impostorScore); | ||
| 268 | - } | ||
| 269 | - } | ||
| 270 | - | ||
| 271 | - // Write Cumulative Match Characteristic (CMC) curve | ||
| 272 | - const int Max_Retrieval = 200; | ||
| 273 | - const int Report_Retrieval = 5; | ||
| 274 | - | ||
| 275 | - float reportRetrievalRate = -1; | ||
| 276 | - for (int i=1; i<=Max_Retrieval; i++) { | ||
| 277 | - int realizedReturns = 0, possibleReturns = 0; | ||
| 278 | - foreach (int firstGenuineReturn, firstGenuineReturns) { | ||
| 279 | - if (firstGenuineReturn > 0) { | ||
| 280 | - possibleReturns++; | ||
| 281 | - if (firstGenuineReturn <= i) realizedReturns++; | ||
| 282 | - } | ||
| 283 | - } | ||
| 284 | - const float retrievalRate = float(realizedReturns)/possibleReturns; | ||
| 285 | - lines.append(qPrintable(QString("CMC,%1,%2").arg(QString::number(i), QString::number(retrievalRate)))); | ||
| 286 | - if (i == Report_Retrieval) reportRetrievalRate = retrievalRate; | ||
| 287 | - } | ||
| 288 | - | ||
| 289 | - if (!csv.isEmpty()) QtUtils::writeFile(csv, lines); | ||
| 290 | - qDebug("TAR @ FAR = 0.01: %.3f\nRetrieval Rate @ Rank = %d: %.3f", result, Report_Retrieval, reportRetrievalRate); | ||
| 291 | - return result; | ||
| 292 | -} | ||
| 293 | - | ||
| 294 | static QString getScale(const QString &mode, const QString &title, int vals) | 34 | static QString getScale(const QString &mode, const QString &title, int vals) |
| 295 | { | 35 | { |
| 296 | if (vals > 12) return " + scale_"+mode+"_discrete(\""+title+"\")"; | 36 | if (vals > 12) return " + scale_"+mode+"_discrete(\""+title+"\")"; |
| @@ -474,7 +214,6 @@ struct RPlot | @@ -474,7 +214,6 @@ struct RPlot | ||
| 474 | }; | 214 | }; |
| 475 | 215 | ||
| 476 | // Does not work if dataset folder starts with a number | 216 | // Does not work if dataset folder starts with a number |
| 477 | - | ||
| 478 | bool Plot(const QStringList &files, const br::File &destination, bool show) | 217 | bool Plot(const QStringList &files, const br::File &destination, bool show) |
| 479 | { | 218 | { |
| 480 | qDebug("Plotting %d file(s) to %s", files.size(), qPrintable(destination)); | 219 | qDebug("Plotting %d file(s) to %s", files.size(), qPrintable(destination)); |
openbr/core/plot.h
| @@ -24,14 +24,8 @@ | @@ -24,14 +24,8 @@ | ||
| 24 | 24 | ||
| 25 | namespace br | 25 | namespace br |
| 26 | { | 26 | { |
| 27 | - | ||
| 28 | -void Confusion(const QString &file, float score, int &true_positives, int &false_positives, int &true_negatives, int &false_negatives); | ||
| 29 | -float Evaluate(const QString &simmat, const QString &mask = "", const QString &csv = ""); // Returns TAR @ FAR = 0.001 | ||
| 30 | -float Evaluate(const cv::Mat &scores, const FileList &target, const FileList &query, const QString &csv = "", int parition = 0); | ||
| 31 | -float Evaluate(const cv::Mat &scores, const cv::Mat &masks, const QString &csv = ""); | ||
| 32 | -bool Plot(const QStringList &files, const br::File &destination, bool show = false); | ||
| 33 | -bool PlotMetadata(const QStringList &files, const QString &destination, bool show = false); | ||
| 34 | - | 27 | + bool Plot(const QStringList &files, const br::File &destination, bool show = false); |
| 28 | + bool PlotMetadata(const QStringList &files, const QString &destination, bool show = false); | ||
| 35 | } | 29 | } |
| 36 | 30 | ||
| 37 | #endif // __PLOT_H | 31 | #endif // __PLOT_H |
openbr/openbr.cpp
| @@ -17,8 +17,8 @@ | @@ -17,8 +17,8 @@ | ||
| 17 | #include <openbr/openbr_plugin.h> | 17 | #include <openbr/openbr_plugin.h> |
| 18 | 18 | ||
| 19 | #include "core/bee.h" | 19 | #include "core/bee.h" |
| 20 | -#include "core/classify.h" | ||
| 21 | #include "core/cluster.h" | 20 | #include "core/cluster.h" |
| 21 | +#include "core/eval.h" | ||
| 22 | #include "core/fuse.h" | 22 | #include "core/fuse.h" |
| 23 | #include "core/plot.h" | 23 | #include "core/plot.h" |
| 24 | #include "core/qtutils.h" | 24 | #include "core/qtutils.h" |
| @@ -51,11 +51,6 @@ void br_compare(const char *target_gallery, const char *query_gallery, const cha | @@ -51,11 +51,6 @@ void br_compare(const char *target_gallery, const char *query_gallery, const cha | ||
| 51 | Compare(File(target_gallery), File(query_gallery), File(output)); | 51 | Compare(File(target_gallery), File(query_gallery), File(output)); |
| 52 | } | 52 | } |
| 53 | 53 | ||
| 54 | -void br_confusion(const char *file, float score, int *true_positives, int *false_positives, int *true_negatives, int *false_negatives) | ||
| 55 | -{ | ||
| 56 | - return Confusion(file, score, *true_positives, *false_positives, *true_negatives, *false_negatives); | ||
| 57 | -} | ||
| 58 | - | ||
| 59 | void br_convert(const char *file_type, const char *input_file, const char *output_file) | 54 | void br_convert(const char *file_type, const char *input_file, const char *output_file) |
| 60 | { | 55 | { |
| 61 | Convert(File(file_type), File(input_file), File(output_file)); | 56 | Convert(File(file_type), File(input_file), File(output_file)); |
openbr/openbr.h
| @@ -115,20 +115,6 @@ BR_EXPORT void br_combine_masks(int num_input_masks, const char *input_masks[], | @@ -115,20 +115,6 @@ BR_EXPORT void br_combine_masks(int num_input_masks, const char *input_masks[], | ||
| 115 | BR_EXPORT void br_compare(const char *target_gallery, const char *query_gallery, const char *output = ""); | 115 | BR_EXPORT void br_compare(const char *target_gallery, const char *query_gallery, const char *output = ""); |
| 116 | 116 | ||
| 117 | /*! | 117 | /*! |
| 118 | - * \brief Computes the confusion matrix for a dataset at a particular threshold. | ||
| 119 | - * | ||
| 120 | - * <a href="http://en.wikipedia.org/wiki/Confusion_matrix">Wikipedia Explanation</a> | ||
| 121 | - * \param file <tt>.csv</tt> file created using \ref br_eval. | ||
| 122 | - * \param score The similarity score to threshold at. | ||
| 123 | - * \param[out] true_positives The true positive count. | ||
| 124 | - * \param[out] false_positives The false positive count. | ||
| 125 | - * \param[out] true_negatives The true negative count. | ||
| 126 | - * \param[out] false_negatives The false negative count. | ||
| 127 | - */ | ||
| 128 | -BR_EXPORT void br_confusion(const char *file, float score, | ||
| 129 | - int *true_positives, int *false_positives, int *true_negatives, int *false_negatives); | ||
| 130 | - | ||
| 131 | -/*! | ||
| 132 | * \brief Wraps br::Convert() | 118 | * \brief Wraps br::Convert() |
| 133 | */ | 119 | */ |
| 134 | BR_EXPORT void br_convert(const char *file_type, const char *input_file, const char *output_file); | 120 | BR_EXPORT void br_convert(const char *file_type, const char *input_file, const char *output_file); |
openbr/openbr_plugin.cpp
| @@ -436,7 +436,7 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) | @@ -436,7 +436,7 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) | ||
| 436 | // stores the index values in "Label" of the output template list | 436 | // stores the index values in "Label" of the output template list |
| 437 | TemplateList TemplateList::relabel(const TemplateList &tl, const QString & propName) | 437 | TemplateList TemplateList::relabel(const TemplateList &tl, const QString & propName) |
| 438 | { | 438 | { |
| 439 | - const QList<QString> originalLabels = tl.get<QString>(propName); | 439 | + const QList<QString> originalLabels = File::get<QString>(tl, propName); |
| 440 | QHash<QString,int> labelTable; | 440 | QHash<QString,int> labelTable; |
| 441 | foreach (const QString & label, originalLabels) | 441 | foreach (const QString & label, originalLabels) |
| 442 | if (!labelTable.contains(label)) | 442 | if (!labelTable.contains(label)) |
| @@ -464,7 +464,7 @@ QList<int> TemplateList::indexProperty(const QString & propName, QHash<QString, | @@ -464,7 +464,7 @@ QList<int> TemplateList::indexProperty(const QString & propName, QHash<QString, | ||
| 464 | valueMap.clear(); | 464 | valueMap.clear(); |
| 465 | reverseLookup.clear(); | 465 | reverseLookup.clear(); |
| 466 | 466 | ||
| 467 | - const QList<QVariant> originalLabels = values(propName); | 467 | + const QList<QVariant> originalLabels = File::values(*this, propName); |
| 468 | foreach (const QVariant & label, originalLabels) { | 468 | foreach (const QVariant & label, originalLabels) { |
| 469 | QString labelString = label.toString(); | 469 | QString labelString = label.toString(); |
| 470 | if (!valueMap.contains(labelString)) { | 470 | if (!valueMap.contains(labelString)) { |
| @@ -481,9 +481,9 @@ QList<int> TemplateList::indexProperty(const QString & propName, QHash<QString, | @@ -481,9 +481,9 @@ QList<int> TemplateList::indexProperty(const QString & propName, QHash<QString, | ||
| 481 | } | 481 | } |
| 482 | 482 | ||
| 483 | // uses -1 for missing values | 483 | // uses -1 for missing values |
| 484 | -QList<int> TemplateList::applyIndex(const QString & propName, const QHash<QString, int> & valueMap) const | 484 | +QList<int> TemplateList::applyIndex(const QString &propName, const QHash<QString, int> &valueMap) const |
| 485 | { | 485 | { |
| 486 | - const QList<QString> originalLabels = get<QString>(propName); | 486 | + const QList<QString> originalLabels = File::get<QString>(*this, propName); |
| 487 | 487 | ||
| 488 | QList<int> result; | 488 | QList<int> result; |
| 489 | for (int i=0; i<originalLabels.size(); i++) { | 489 | for (int i=0; i<originalLabels.size(); i++) { |
openbr/openbr_plugin.h
| @@ -228,7 +228,20 @@ struct BR_EXPORT File | @@ -228,7 +228,20 @@ struct BR_EXPORT File | ||
| 228 | return variant.value<T>(); | 228 | return variant.value<T>(); |
| 229 | } | 229 | } |
| 230 | 230 | ||
| 231 | - /*!< \brief Returns a list of type T for the key, throwing an error if the key does not exist or if the value cannot be converted to the specified type. */ | 231 | + /*!< \brief Returns a value for the key, returning \em defaultValue if the key does not exist or can't be converted. */ |
| 232 | + template <typename T> | ||
| 233 | + T get(const QString &key, const T &defaultValue) const | ||
| 234 | + { | ||
| 235 | + if (!contains(key)) return defaultValue; | ||
| 236 | + QVariant variant = value(key); | ||
| 237 | + if (!variant.canConvert<T>()) return defaultValue; | ||
| 238 | + return variant.value<T>(); | ||
| 239 | + } | ||
| 240 | + | ||
| 241 | + /*!< \brief Specialization for boolean type. */ | ||
| 242 | + bool getBool(const QString &key, bool defaultValue = false) const; | ||
| 243 | + | ||
| 244 | + /*!< \brief Specialization for list type. Returns a list of type T for the key, throwing an error if the key does not exist or if the value cannot be converted to the specified type. */ | ||
| 232 | template <typename T> | 245 | template <typename T> |
| 233 | QList<T> getList(const QString &key) const | 246 | QList<T> getList(const QString &key) const |
| 234 | { | 247 | { |
| @@ -241,17 +254,31 @@ struct BR_EXPORT File | @@ -241,17 +254,31 @@ struct BR_EXPORT File | ||
| 241 | return list; | 254 | return list; |
| 242 | } | 255 | } |
| 243 | 256 | ||
| 244 | - /*!< \brief Specialization for boolean type. */ | ||
| 245 | - bool getBool(const QString &key, bool defaultValue = false) const; | 257 | + /*!< \brief Returns the value for the specified key for every file in the list. */ |
| 258 | + template<class U> | ||
| 259 | + static QList<QVariant> values(const QList<U> &fileList, const QString &key) | ||
| 260 | + { | ||
| 261 | + QList<QVariant> values; values.reserve(fileList.size()); | ||
| 262 | + foreach (const U &f, fileList) values.append(((const File&)f).value(key)); | ||
| 263 | + return values; | ||
| 264 | + } | ||
| 246 | 265 | ||
| 247 | - /*!< \brief Returns a value for the key, returning \em defaultValue if the key does not exist or can't be converted. */ | ||
| 248 | - template <typename T> | ||
| 249 | - T get(const QString &key, const T &defaultValue) const | 266 | + /*!< \brief Returns a value for the key for every file in the list, throwing an error if the key does not exist. */ |
| 267 | + template<class T, class U> | ||
| 268 | + static QList<T> get(const QList<U> &fileList, const QString &key) | ||
| 250 | { | 269 | { |
| 251 | - if (!contains(key)) return defaultValue; | ||
| 252 | - QVariant variant = value(key); | ||
| 253 | - if (!variant.canConvert<T>()) return defaultValue; | ||
| 254 | - return variant.value<T>(); | 270 | + QList<T> result; result.reserve(fileList.size()); |
| 271 | + foreach (const U &f, fileList) result.append(((const File&)f).get<T>(key)); | ||
| 272 | + return result; | ||
| 273 | + } | ||
| 274 | + | ||
| 275 | + /*!< \brief Returns a value for the key for every file in the list, returning \em defaultValue if the key does not exist or can't be converted. */ | ||
| 276 | + template<class T, class U> | ||
| 277 | + static QList<T> get(const QList<U> &fileList, const QString &key, const T &defaultValue) | ||
| 278 | + { | ||
| 279 | + QList<T> result; result.reserve(fileList.size()); | ||
| 280 | + foreach (const U &f, fileList) result.append(static_cast<const File&>(f).get<T>(key, defaultValue)); | ||
| 281 | + return result; | ||
| 255 | } | 282 | } |
| 256 | 283 | ||
| 257 | inline bool failed() const { return getBool("FTE") || getBool("FTO"); } /*!< \brief Returns \c true if the file failed to open or enroll, \c false otherwise. */ | 284 | inline bool failed() const { return getBool("FTE") || getBool("FTO"); } /*!< \brief Returns \c true if the file failed to open or enroll, \c false otherwise. */ |
| @@ -297,23 +324,6 @@ struct BR_EXPORT FileList : public QList<File> | @@ -297,23 +324,6 @@ struct BR_EXPORT FileList : public QList<File> | ||
| 297 | QStringList flat() const; /*!< \brief Returns br::File::flat() for each file in the list. */ | 324 | QStringList flat() const; /*!< \brief Returns br::File::flat() for each file in the list. */ |
| 298 | QStringList names() const; /*!< \brief Returns #br::File::name for each file in the list. */ | 325 | QStringList names() const; /*!< \brief Returns #br::File::name for each file in the list. */ |
| 299 | void sort(const QString& key); /*!< \brief Sort the list based on metadata. */ | 326 | void sort(const QString& key); /*!< \brief Sort the list based on metadata. */ |
| 300 | - /*!< \brief Returns values associated with the input propName for each file in the list. */ | ||
| 301 | - template<typename T> | ||
| 302 | - QList<T> get(const QString & propName) const | ||
| 303 | - { | ||
| 304 | - QList<T> values; values.reserve(size()); | ||
| 305 | - foreach (const File &f, *this) | ||
| 306 | - values.append(f.get<T>(propName)); | ||
| 307 | - return values; | ||
| 308 | - } | ||
| 309 | - template<typename T> | ||
| 310 | - QList<T> get(const QString & propName, T defaultValue) const | ||
| 311 | - { | ||
| 312 | - QList<T> values; values.reserve(size()); | ||
| 313 | - foreach (const File &f, *this) | ||
| 314 | - values.append(f.contains(propName) ? f.get<T>(propName) : defaultValue); | ||
| 315 | - return values; | ||
| 316 | - } | ||
| 317 | 327 | ||
| 318 | QList<int> crossValidationPartitions() const; /*!< \brief Returns the cross-validation partition (default=0) for each file in the list. */ | 328 | QList<int> crossValidationPartitions() const; /*!< \brief Returns the cross-validation partition (default=0) for each file in the list. */ |
| 319 | int failures() const; /*!< \brief Returns the number of files with br::File::failed(). */ | 329 | int failures() const; /*!< \brief Returns the number of files with br::File::failed(). */ |
| @@ -344,6 +354,7 @@ struct Template : public QList<cv::Mat> | @@ -344,6 +354,7 @@ struct Template : public QList<cv::Mat> | ||
| 344 | inline const cv::Mat &m() const { static const cv::Mat NullMatrix; | 354 | inline const cv::Mat &m() const { static const cv::Mat NullMatrix; |
| 345 | return isEmpty() ? qFatal("Empty template."), NullMatrix : last(); } /*!< \brief Idiom to treat the template as a matrix. */ | 355 | return isEmpty() ? qFatal("Empty template."), NullMatrix : last(); } /*!< \brief Idiom to treat the template as a matrix. */ |
| 346 | inline cv::Mat &m() { return isEmpty() ? append(cv::Mat()), last() : last(); } /*!< \brief Idiom to treat the template as a matrix. */ | 356 | inline cv::Mat &m() { return isEmpty() ? append(cv::Mat()), last() : last(); } /*!< \brief Idiom to treat the template as a matrix. */ |
| 357 | + inline const File &operator()() const { return file; } | ||
| 347 | inline cv::Mat &operator=(const cv::Mat &other) { return m() = other; } /*!< \brief Idiom to treat the template as a matrix. */ | 358 | inline cv::Mat &operator=(const cv::Mat &other) { return m() = other; } /*!< \brief Idiom to treat the template as a matrix. */ |
| 348 | inline operator const cv::Mat&() const { return m(); } /*!< \brief Idiom to treat the template as a matrix. */ | 359 | inline operator const cv::Mat&() const { return m(); } /*!< \brief Idiom to treat the template as a matrix. */ |
| 349 | inline operator cv::Mat&() { return m(); } /*!< \brief Idiom to treat the template as a matrix. */ | 360 | inline operator cv::Mat&() { return m(); } /*!< \brief Idiom to treat the template as a matrix. */ |
| @@ -406,7 +417,6 @@ struct TemplateList : public QList<Template> | @@ -406,7 +417,6 @@ struct TemplateList : public QList<Template> | ||
| 406 | QList<int> indexProperty(const QString & propName, QHash<QString, int> & valueMap, QHash<int, QVariant> & reverseLookup) const; | 417 | QList<int> indexProperty(const QString & propName, QHash<QString, int> & valueMap, QHash<int, QVariant> & reverseLookup) const; |
| 407 | QList<int> applyIndex(const QString & propName, const QHash<QString, int> & valueMap) const; | 418 | QList<int> applyIndex(const QString & propName, const QHash<QString, int> & valueMap) const; |
| 408 | 419 | ||
| 409 | - | ||
| 410 | /*! | 420 | /*! |
| 411 | * \brief Returns the total number of bytes in all the templates. | 421 | * \brief Returns the total number of bytes in all the templates. |
| 412 | */ | 422 | */ |
| @@ -477,23 +487,6 @@ struct TemplateList : public QList<Template> | @@ -477,23 +487,6 @@ struct TemplateList : public QList<Template> | ||
| 477 | FileList operator()() const { return files(); } | 487 | FileList operator()() const { return files(); } |
| 478 | 488 | ||
| 479 | /*! | 489 | /*! |
| 480 | - * \brief Returns br::Template::label() for each template in the list. | ||
| 481 | - */ | ||
| 482 | - template<typename T> | ||
| 483 | - QList<T> get(const QString & propName) const | ||
| 484 | - { | ||
| 485 | - QList<T> values; values.reserve(size()); | ||
| 486 | - foreach (const Template &t, *this) values.append(t.file.get<T>(propName)); | ||
| 487 | - return values; | ||
| 488 | - } | ||
| 489 | - QList<QVariant> values(const QString & propName) const | ||
| 490 | - { | ||
| 491 | - QList<QVariant> values; values.reserve(size()); | ||
| 492 | - foreach (const Template &t, *this) values.append(t.file.value(propName)); | ||
| 493 | - return values; | ||
| 494 | - } | ||
| 495 | - | ||
| 496 | - /*! | ||
| 497 | * \brief Returns the number of occurences for each label in the list. | 490 | * \brief Returns the number of occurences for each label in the list. |
| 498 | */ | 491 | */ |
| 499 | template<typename T> | 492 | template<typename T> |
| @@ -814,7 +807,8 @@ struct Factory | @@ -814,7 +807,8 @@ struct Factory | ||
| 814 | //! [Factory make] | 807 | //! [Factory make] |
| 815 | static T *make(const File &file) | 808 | static T *make(const File &file) |
| 816 | { | 809 | { |
| 817 | - QString name = file.suffix(); | 810 | + QString name = file.get<QString>("plugin", ""); |
| 811 | + if (name.isEmpty()) name = file.suffix(); | ||
| 818 | if (!names().contains(name)) { | 812 | if (!names().contains(name)) { |
| 819 | if (names().contains("Empty") && name.isEmpty()) name = "Empty"; | 813 | if (names().contains("Empty") && name.isEmpty()) name = "Empty"; |
| 820 | else if (names().contains("Default")) name = "Default"; | 814 | else if (names().contains("Default")) name = "Default"; |
| @@ -1024,7 +1018,7 @@ public: | @@ -1024,7 +1018,7 @@ public: | ||
| 1024 | virtual TemplateList readBlock(bool *done) = 0; /*!< \brief Retrieve a portion of the stored templates. */ | 1018 | virtual TemplateList readBlock(bool *done) = 0; /*!< \brief Retrieve a portion of the stored templates. */ |
| 1025 | void writeBlock(const TemplateList &templates); /*!< \brief Serialize a template list. */ | 1019 | void writeBlock(const TemplateList &templates); /*!< \brief Serialize a template list. */ |
| 1026 | virtual void write(const Template &t) = 0; /*!< \brief Serialize a template. */ | 1020 | virtual void write(const Template &t) = 0; /*!< \brief Serialize a template. */ |
| 1027 | - static Gallery *make(const File &file); /*!< \brief Make a gallery from a file list. */ | 1021 | + static Gallery *make(const File &file); /*!< \brief Make a gallery to/from a file on disk. */ |
| 1028 | 1022 | ||
| 1029 | private: | 1023 | private: |
| 1030 | QSharedPointer<Gallery> next; | 1024 | QSharedPointer<Gallery> next; |
openbr/plugins/eigen3.cpp
| @@ -348,7 +348,7 @@ class LDATransform : public Transform | @@ -348,7 +348,7 @@ class LDATransform : public Transform | ||
| 348 | 348 | ||
| 349 | // OpenBR ensures that class values range from 0 to numClasses-1. | 349 | // OpenBR ensures that class values range from 0 to numClasses-1. |
| 350 | // Label exists because we created it earlier with relabel | 350 | // Label exists because we created it earlier with relabel |
| 351 | - QList<int> classes = trainingSet.get<int>("Label"); | 351 | + QList<int> classes = File::get<int>(trainingSet, "Label"); |
| 352 | QMap<int, int> classCounts = trainingSet.countValues<int>("Label"); | 352 | QMap<int, int> classCounts = trainingSet.countValues<int>("Label"); |
| 353 | const int numClasses = classCounts.size(); | 353 | const int numClasses = classCounts.size(); |
| 354 | 354 |
openbr/plugins/gallery.cpp
| @@ -881,6 +881,45 @@ class statGallery : public Gallery | @@ -881,6 +881,45 @@ class statGallery : public Gallery | ||
| 881 | 881 | ||
| 882 | BR_REGISTER(Gallery, statGallery) | 882 | BR_REGISTER(Gallery, statGallery) |
| 883 | 883 | ||
| 884 | +/*! | ||
| 885 | + * \ingroup galleries | ||
| 886 | + * \brief Implements the FDDB detection format. | ||
| 887 | + * \author Josh Klontz \cite jklontz | ||
| 888 | + * | ||
| 889 | + * http://vis-www.cs.umass.edu/fddb/README.txt | ||
| 890 | + */ | ||
| 891 | +class FDDBGallery : public Gallery | ||
| 892 | +{ | ||
| 893 | + Q_OBJECT | ||
| 894 | + | ||
| 895 | + TemplateList readBlock(bool *done) | ||
| 896 | + { | ||
| 897 | + *done = true; | ||
| 898 | + QStringList lines = QtUtils::readLines(file); | ||
| 899 | + TemplateList templates; | ||
| 900 | + while (!lines.empty()) { | ||
| 901 | + const QString fileName = lines.takeFirst(); | ||
| 902 | + int numDetects = lines.takeFirst().toInt(); | ||
| 903 | + for (int i=0; i<numDetects; i++) { | ||
| 904 | + const QStringList detect = lines.takeFirst().split(' '); | ||
| 905 | + Template t(fileName); | ||
| 906 | + t.file.set("Face", QRectF(detect[0].toFloat(), detect[1].toFloat(), detect[2].toFloat(), detect[3].toFloat())); | ||
| 907 | + t.file.set("Confidence", detect[4].toFloat()); | ||
| 908 | + templates.append(t); | ||
| 909 | + } | ||
| 910 | + } | ||
| 911 | + return templates; | ||
| 912 | + } | ||
| 913 | + | ||
| 914 | + void write(const Template &t) | ||
| 915 | + { | ||
| 916 | + (void) t; | ||
| 917 | + qFatal("Not implemented."); | ||
| 918 | + } | ||
| 919 | +}; | ||
| 920 | + | ||
| 921 | +BR_REGISTER(Gallery, FDDBGallery) | ||
| 922 | + | ||
| 884 | } // namespace br | 923 | } // namespace br |
| 885 | 924 | ||
| 886 | #include "gallery.moc" | 925 | #include "gallery.moc" |
openbr/plugins/independent.cpp
| @@ -20,7 +20,7 @@ static TemplateList Downsample(const TemplateList &templates, const Transform *t | @@ -20,7 +20,7 @@ static TemplateList Downsample(const TemplateList &templates, const Transform *t | ||
| 20 | const bool atLeast = transform->instances < 0; | 20 | const bool atLeast = transform->instances < 0; |
| 21 | const int instances = abs(transform->instances); | 21 | const int instances = abs(transform->instances); |
| 22 | 22 | ||
| 23 | - QList<QString> allLabels = templates.get<QString>("Subject"); | 23 | + QList<QString> allLabels = File::get<QString>(templates, "Subject"); |
| 24 | QList<QString> uniqueLabels = allLabels.toSet().toList(); | 24 | QList<QString> uniqueLabels = allLabels.toSet().toList(); |
| 25 | qSort(uniqueLabels); | 25 | qSort(uniqueLabels); |
| 26 | 26 |
openbr/plugins/output.cpp
| @@ -35,9 +35,9 @@ | @@ -35,9 +35,9 @@ | ||
| 35 | #include "openbr_internal.h" | 35 | #include "openbr_internal.h" |
| 36 | 36 | ||
| 37 | #include "openbr/core/bee.h" | 37 | #include "openbr/core/bee.h" |
| 38 | +#include "openbr/core/eval.h" | ||
| 38 | #include "openbr/core/common.h" | 39 | #include "openbr/core/common.h" |
| 39 | #include "openbr/core/opencvutils.h" | 40 | #include "openbr/core/opencvutils.h" |
| 40 | -#include "openbr/core/plot.h" | ||
| 41 | #include "openbr/core/qtutils.h" | 41 | #include "openbr/core/qtutils.h" |
| 42 | 42 | ||
| 43 | namespace br | 43 | namespace br |
| @@ -146,8 +146,8 @@ class meltOutput : public MatrixOutput | @@ -146,8 +146,8 @@ class meltOutput : public MatrixOutput | ||
| 146 | QStringList lines; | 146 | QStringList lines; |
| 147 | if (file.baseName() != "terminal") lines.append(QString("Query,Target,Mask,Similarity%1").arg(keys)); | 147 | if (file.baseName() != "terminal") lines.append(QString("Query,Target,Mask,Similarity%1").arg(keys)); |
| 148 | 148 | ||
| 149 | - QList<QString> queryLabels = queryFiles.get<QString>("Subject"); | ||
| 150 | - QList<QString> targetLabels = targetFiles.get<QString>("Subject"); | 149 | + QList<QString> queryLabels = File::get<QString>(queryFiles, "Subject"); |
| 150 | + QList<QString> targetLabels = File::get<QString>(targetFiles, "Subject"); | ||
| 151 | 151 | ||
| 152 | for (int i=0; i<queryFiles.size(); i++) { | 152 | for (int i=0; i<queryFiles.size(); i++) { |
| 153 | for (int j=(selfSimilar ? i+1 : 0); j<targetFiles.size(); j++) { | 153 | for (int j=(selfSimilar ? i+1 : 0); j<targetFiles.size(); j++) { |
openbr/plugins/svm.cpp
| @@ -130,7 +130,7 @@ private: | @@ -130,7 +130,7 @@ private: | ||
| 130 | Mat lab; | 130 | Mat lab; |
| 131 | // If we are doing regression, assume subject has float values | 131 | // If we are doing regression, assume subject has float values |
| 132 | if (type == EPS_SVR || type == NU_SVR) { | 132 | if (type == EPS_SVR || type == NU_SVR) { |
| 133 | - lab = OpenCVUtils::toMat(_data.get<float>("Subject")); | 133 | + lab = OpenCVUtils::toMat(File::get<float>(_data, "Subject")); |
| 134 | } | 134 | } |
| 135 | // If we are doing classification, assume subject has discrete values, map them | 135 | // If we are doing classification, assume subject has discrete values, map them |
| 136 | // and store the mapping data | 136 | // and store the mapping data |
scripts/downloadDatasets.sh
| @@ -48,3 +48,19 @@ if [ ! -d ../data/MEDS/img ]; then | @@ -48,3 +48,19 @@ if [ ! -d ../data/MEDS/img ]; then | ||
| 48 | mv data/*/*.jpg ../data/MEDS/img | 48 | mv data/*/*.jpg ../data/MEDS/img |
| 49 | rm -r data NIST_SD32_MEDS-II_face.zip | 49 | rm -r data NIST_SD32_MEDS-II_face.zip |
| 50 | fi | 50 | fi |
| 51 | + | ||
| 52 | +# KTH | ||
| 53 | +if [ ! -d ../data/KTH/vid ]; then | ||
| 54 | + echo "Downloading KTH..." | ||
| 55 | + mkdir ../data/KTH/vid | ||
| 56 | + for vidclass in {'boxing','handclapping','handwaving','jogging','running','walking'}; do | ||
| 57 | + if hash curl 2>/dev/null; then | ||
| 58 | + curl -OL http://www.nada.kth.se/cvap/actions/${vidclass}.zip | ||
| 59 | + else | ||
| 60 | + wget http://www.nada.kth.se/cvap/actions/${vidclass}.zip | ||
| 61 | + fi | ||
| 62 | + mkdir ../data/KTH/vid/${vidclass} | ||
| 63 | + unzip ${vidclass}.zip -d ../data/KTH/vid/${vidclass} | ||
| 64 | + rm ${vidclass}.zip | ||
| 65 | + done | ||
| 66 | +fi |