Commit ec4af7799c2f589d03adea56955c8a2faac9f6a4

Authored by Scott Klum
2 parents 4936fcc1 ff04e9ed

Merge branch 'master' of https://github.com/biometrics/openbr

.gitignore
... ... @@ -29,3 +29,6 @@ scripts/results
29 29  
30 30 ### OS X ###
31 31 *.DS_Store
  32 +
  33 +### vim ###
  34 +*.swp
... ...
app/br/br.cpp
... ... @@ -140,13 +140,6 @@ public:
140 140 } else if (!strcmp(fun, "evalRegression")) {
141 141 check(parc == 2, "Incorrect parameter count for 'evalRegression'.");
142 142 br_eval_regression(parv[0], parv[1]);
143   - } else if (!strcmp(fun, "confusion")) {
144   - check(parc == 2, "Incorrect parameter count for 'confusion'.");
145   - int true_positives, false_positives, true_negatives, false_negatives;
146   - br_confusion(parv[0], atof(parv[1]),
147   - &true_positives, &false_positives, &true_negatives, &false_negatives);
148   - printf("True Positives = %d\nFalse Positives = %d\nTrue Negatives = %d\nFalseNegatives = %d\n",
149   - true_positives, false_positives, true_negatives, false_negatives);
150 143 } else if (!strcmp(fun, "plotMetadata")) {
151 144 check(parc >= 2, "Incorrect parameter count for 'plotMetadata'.");
152 145 br_plot_metadata(parc-1, parv, parv[parc-1], true);
... ... @@ -223,7 +216,6 @@ private:
223 216 "-evalClustering <clusters> <gallery>\n"
224 217 "-evalDetection <predicted_gallery> <truth_gallery>\n"
225 218 "-evalRegression <predicted_gallery> <truth_gallery>\n"
226   - "-confusion <file> <score>\n"
227 219 "-plotMetadata <file> ... <file> <columns>\n"
228 220 "-getHeader <matrix>\n"
229 221 "-setHeader {<matrix>} <target_gallery> <query_gallery>\n"
... ...
data/KTH/README.md 0 → 100644
  1 +## KTH Human Action Database
  2 +Grayscale human action videos. Six actions performed by 25 subjects in four scenarios, for a total of 600 160x120 videos.
  3 +* [Website](http://www.nada.kth.se/cvap/actions/)
... ...
data/README.md
... ... @@ -9,6 +9,7 @@
9 9 * [MEDS](MEDS/README.md)
10 10 * [MNIST](MNIST/README.md)
11 11 * [PCSO](PCSO/README.md)
  12 +* [KTH](KTH/README.md)
12 13  
13 14 For both practical and legal reasons we only include images for some of the datasets in this repository.
14 15 Researchers should contact the respective owners of the other datasets in order to obtain a copy.
... ...
openbr/core/bee.cpp
... ... @@ -268,8 +268,8 @@ cv::Mat BEE::makeMask(const br::FileList &amp;targets, const br::FileList &amp;queries,
268 268 {
269 269 // Would like to use indexProperty for this, but didn't make a version of that for Filelist yet
270 270 // -cao
271   - QList<QString> targetLabels = targets.get<QString>("Subject", "-1");
272   - QList<QString> queryLabels = queries.get<QString>("Subject", "-1");
  271 + QList<QString> targetLabels = File::get<QString>(targets, "Subject", "-1");
  272 + QList<QString> queryLabels = File::get<QString>(queries, "Subject", "-1");
273 273 QList<int> targetPartitions = targets.crossValidationPartitions();
274 274 QList<int> queryPartitions = queries.crossValidationPartitions();
275 275  
... ...
openbr/core/classify.cpp deleted
1   -/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
2   - * Copyright 2012 The MITRE Corporation *
3   - * *
4   - * Licensed under the Apache License, Version 2.0 (the "License"); *
5   - * you may not use this file except in compliance with the License. *
6   - * You may obtain a copy of the License at *
7   - * *
8   - * http://www.apache.org/licenses/LICENSE-2.0 *
9   - * *
10   - * Unless required by applicable law or agreed to in writing, software *
11   - * distributed under the License is distributed on an "AS IS" BASIS, *
12   - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
13   - * See the License for the specific language governing permissions and *
14   - * limitations under the License. *
15   - * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
16   -
17   -#include <openbr/openbr_plugin.h>
18   -
19   -#include "classify.h"
20   -#include "openbr/core/qtutils.h"
21   -
22   -// Helper struct for statistics accumulation
23   -struct Counter
24   -{
25   - float truePositive, falsePositive, falseNegative;
26   - Counter()
27   - {
28   - truePositive = 0;
29   - falsePositive = 0;
30   - falseNegative = 0;
31   - }
32   -};
33   -
34   -void br::EvalClassification(const QString &predictedInput, const QString &truthInput)
35   -{
36   - qDebug("Evaluating classification of %s against %s", qPrintable(predictedInput), qPrintable(truthInput));
37   -
38   - TemplateList predicted(TemplateList::fromGallery(predictedInput));
39   - TemplateList truth(TemplateList::fromGallery(truthInput));
40   - if (predicted.size() != truth.size()) qFatal("Input size mismatch.");
41   -
42   - QHash<QString, Counter> counters;
43   - for (int i=0; i<predicted.size(); i++) {
44   - if (predicted[i].file.name != truth[i].file.name)
45   - qFatal("Input order mismatch.");
46   -
47   - // Typically these lists will be of length one, but this generalization allows measuring multi-class labeling accuracy.
48   - QString predictedSubject = predicted[i].file.get<QString>("Subject");
49   - QString trueSubject = truth[i].file.get<QString>("Subject");
50   -
51   - QStringList predictedSubjects(predictedSubject);
52   - QStringList trueSubjects(trueSubject);
53   -
54   - foreach (const QString &subject, trueSubjects.toVector() /* Hack to copy the list. */) {
55   - if (predictedSubjects.contains(subject)) {
56   - counters[subject].truePositive++;
57   - trueSubjects.removeOne(subject);
58   - predictedSubjects.removeOne(subject);
59   - } else {
60   - counters[subject].falseNegative++;
61   - }
62   - }
63   -
64   - for (int i=0; i<trueSubjects.size(); i++)
65   - foreach (const QString &subject, predictedSubjects)
66   - counters[subject].falsePositive += 1.f / predictedSubjects.size();
67   - }
68   -
69   - const QStringList keys = counters.keys();
70   - QSharedPointer<Output> output(Output::make("", FileList() << "Count" << "Precision" << "Recall" << "F-score", FileList(keys)));
71   -
72   - int tpc = 0;
73   - int fnc = 0;
74   -
75   - for (int i=0; i<counters.size(); i++) {
76   - const QString &subject = keys[i];
77   - const Counter &counter = counters[subject];
78   - tpc += counter.truePositive;
79   - fnc += counter.falseNegative;
80   - const int count = counter.truePositive + counter.falseNegative;
81   - const float precision = counter.truePositive / (float)(counter.truePositive + counter.falsePositive);
82   - const float recall = counter.truePositive / (float)(counter.truePositive + counter.falseNegative);
83   - const float fscore = 2 * precision * recall / (precision + recall);
84   - output->setRelative(count, i, 0);
85   - output->setRelative(precision, i, 1);
86   - output->setRelative(recall, i, 2);
87   - output->setRelative(fscore, i, 3);
88   - }
89   -
90   - qDebug("Overall Accuracy = %f", (float)tpc / (float)(tpc + fnc));
91   -}
92   -
93   -void br::EvalDetection(const QString &predictedInput, const QString &truthInput)
94   -{
95   - (void) predictedInput;
96   - (void) truthInput;
97   -}
98   -
99   -void br::EvalRegression(const QString &predictedInput, const QString &truthInput)
100   -{
101   - qDebug("Evaluating regression of %s against %s", qPrintable(predictedInput), qPrintable(truthInput));
102   -
103   - const TemplateList predicted(TemplateList::fromGallery(predictedInput));
104   - const TemplateList truth(TemplateList::fromGallery(truthInput));
105   - if (predicted.size() != truth.size()) qFatal("Input size mismatch.");
106   -
107   - float rmsError = 0;
108   - QStringList truthValues, predictedValues;
109   - for (int i=0; i<predicted.size(); i++) {
110   - if (predicted[i].file.name != truth[i].file.name)
111   - qFatal("Input order mismatch.");
112   - rmsError += pow(predicted[i].file.get<float>("Subject")-truth[i].file.get<float>("Subject"), 2.f);
113   - truthValues.append(QString::number(truth[i].file.get<float>("Subject")));
114   - predictedValues.append(QString::number(predicted[i].file.get<float>("Subject")));
115   - }
116   -
117   - QStringList rSource;
118   - rSource << "# Load libraries" << "library(ggplot2)" << "" << "# Set Data"
119   - << "Actual <- c(" + truthValues.join(",") + ")"
120   - << "Predicted <- c(" + predictedValues.join(",") + ")"
121   - << "data <- data.frame(Actual, Predicted)"
122   - << "" << "# Construct Plot" << "pdf(\"EvalRegression.pdf\")"
123   - << "print(qplot(Actual, Predicted, data=data, geom=\"jitter\", alpha=I(2/3)) + geom_abline(intercept=0, slope=1, color=\"forestgreen\", size=I(1)) + geom_smooth(size=I(1), color=\"mediumblue\") + theme_bw())"
124   - << "print(qplot(Actual, Predicted-Actual, data=data, geom=\"jitter\", alpha=I(2/3)) + geom_abline(intercept=0, slope=0, color=\"forestgreen\", size=I(1)) + geom_smooth(size=I(1), color=\"mediumblue\") + theme_bw())"
125   - << "dev.off()";
126   -
127   -
128   - QString rFile = "EvalRegression.R";
129   - QtUtils::writeFile(rFile, rSource);
130   - bool success = QtUtils::runRScript(rFile);
131   - if (success) QtUtils::showFile("EvalRegression.pdf");
132   -
133   - qDebug("RMS Error = %f", sqrt(rmsError/predicted.size()));
134   -}
openbr/core/cluster.cpp
... ... @@ -280,7 +280,7 @@ void br::EvalClustering(const QString &amp;csv, const QString &amp;input)
280 280  
281 281 // We assume clustering algorithms store assigned cluster labels as integers (since the clusters are
282 282 // not named).
283   - QList<int> labels = TemplateList::fromGallery(input).files().get<int>("Subject");
  283 + QList<int> labels = File::get<int>(TemplateList::fromGallery(input), "Subject");
284 284  
285 285 QHash<int, int> labelToIndex;
286 286 int nClusters = 0;
... ...
openbr/core/eval.cpp 0 → 100644
  1 +/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
  2 + * Copyright 2012 The MITRE Corporation *
  3 + * *
  4 + * Licensed under the Apache License, Version 2.0 (the "License"); *
  5 + * you may not use this file except in compliance with the License. *
  6 + * You may obtain a copy of the License at *
  7 + * *
  8 + * http://www.apache.org/licenses/LICENSE-2.0 *
  9 + * *
  10 + * Unless required by applicable law or agreed to in writing, software *
  11 + * distributed under the License is distributed on an "AS IS" BASIS, *
  12 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
  13 + * See the License for the specific language governing permissions and *
  14 + * limitations under the License. *
  15 + * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
  16 +
  17 +#include "bee.h"
  18 +#include "eval.h"
  19 +#include "openbr/core/qtutils.h"
  20 +
  21 +using namespace cv;
  22 +
  23 +namespace br
  24 +{
  25 +
  26 +struct Comparison
  27 +{
  28 + float score;
  29 + int target, query;
  30 + bool genuine;
  31 +
  32 + Comparison() {}
  33 + Comparison(float _score, int _target, int _query, bool _genuine)
  34 + : score(_score), target(_target), query(_query), genuine(_genuine) {}
  35 + inline bool operator<(const Comparison &other) const { return score > other.score; }
  36 +};
  37 +
  38 +#undef FAR // Windows preprecessor definition conflicts with variable name
  39 +struct OperatingPoint
  40 +{
  41 + float score, FAR, TAR;
  42 + OperatingPoint() {}
  43 + OperatingPoint(float _score, float _FAR, float _TAR)
  44 + : score(_score), FAR(_FAR), TAR(_TAR) {}
  45 +};
  46 +
  47 +static float getTAR(const QList<OperatingPoint> &operatingPoints, float FAR)
  48 +{
  49 + int index = 0;
  50 + while (operatingPoints[index].FAR < FAR) {
  51 + index++;
  52 + if (index == operatingPoints.size())
  53 + return 1;
  54 + }
  55 +
  56 + const float x1 = (index == 0 ? 0 : operatingPoints[index-1].FAR);
  57 + const float y1 = (index == 0 ? 0 : operatingPoints[index-1].TAR);
  58 + const float x2 = operatingPoints[index].FAR;
  59 + const float y2 = operatingPoints[index].TAR;
  60 + const float m = (y2 - y1) / (x2 - x1);
  61 + const float b = y1 - m*x1;
  62 + return m * FAR + b;
  63 +}
  64 +
  65 +float Evaluate(const cv::Mat &scores, const FileList &target, const FileList &query, const QString &csv, int partition)
  66 +{
  67 + return Evaluate(scores, BEE::makeMask(target, query, partition), csv);
  68 +}
  69 +
  70 +float Evaluate(const QString &simmat, const QString &mask, const QString &csv)
  71 +{
  72 + qDebug("Evaluating %s%s%s",
  73 + qPrintable(simmat),
  74 + mask.isEmpty() ? "" : qPrintable(" with " + mask),
  75 + csv.isEmpty() ? "" : qPrintable(" to " + csv));
  76 +
  77 + // Read similarity matrix
  78 + QString target, query;
  79 + const Mat scores = BEE::readSimmat(simmat, &target, &query);
  80 +
  81 + // Read mask matrix
  82 + Mat truth;
  83 + if (mask.isEmpty()) {
  84 + // Use the galleries specified in the similarity matrix
  85 + truth = BEE::makeMask(TemplateList::fromGallery(target).files(),
  86 + TemplateList::fromGallery(query).files());
  87 + } else {
  88 + File maskFile(mask);
  89 + maskFile.set("rows", scores.rows);
  90 + maskFile.set("columns", scores.cols);
  91 + truth = BEE::readMask(maskFile);
  92 + }
  93 +
  94 + return Evaluate(scores, truth, csv);
  95 +}
  96 +
  97 +float Evaluate(const Mat &simmat, const Mat &mask, const QString &csv)
  98 +{
  99 + if (simmat.size() != mask.size())
  100 + qFatal("Similarity matrix (%ix%i) differs in size from mask matrix (%ix%i).",
  101 + simmat.rows, simmat.cols, mask.rows, mask.cols);
  102 +
  103 + const int Max_Points = 500;
  104 + float result = -1;
  105 +
  106 + // Make comparisons
  107 + QList<Comparison> comparisons; comparisons.reserve(simmat.rows*simmat.cols);
  108 + int genuineCount = 0, impostorCount = 0, numNaNs = 0;
  109 + for (int i=0; i<simmat.rows; i++) {
  110 + for (int j=0; j<simmat.cols; j++) {
  111 + const BEE::Mask_t mask_val = mask.at<BEE::Mask_t>(i,j);
  112 + const BEE::Simmat_t simmat_val = simmat.at<BEE::Simmat_t>(i,j);
  113 + if (mask_val == BEE::DontCare) continue;
  114 + if (simmat_val != simmat_val) { numNaNs++; continue; }
  115 + comparisons.append(Comparison(simmat_val, j, i, mask_val == BEE::Match));
  116 + if (comparisons.last().genuine) genuineCount++;
  117 + else impostorCount++;
  118 + }
  119 + }
  120 +
  121 + if (numNaNs > 0) qWarning("Encountered %d NaN scores!", numNaNs);
  122 + if (genuineCount == 0) qFatal("No genuine scores!");
  123 + if (impostorCount == 0) qFatal("No impostor scores!");
  124 +
  125 + // Sort comparisons by simmat_val (score)
  126 + std::sort(comparisons.begin(), comparisons.end());
  127 +
  128 + QList<OperatingPoint> operatingPoints;
  129 + QList<float> genuines; genuines.reserve(sqrt((float)comparisons.size()));
  130 + QList<float> impostors; impostors.reserve(comparisons.size());
  131 + QVector<int> firstGenuineReturns(simmat.rows, 0);
  132 +
  133 + int falsePositives = 0, previousFalsePositives = 0;
  134 + int truePositives = 0, previousTruePositives = 0;
  135 + int index = 0;
  136 + float minGenuineScore = std::numeric_limits<float>::max();
  137 + float minImpostorScore = std::numeric_limits<float>::max();
  138 +
  139 + while (index < comparisons.size()) {
  140 + float thresh = comparisons[index].score;
  141 + // Compute genuine and imposter statistics at a threshold
  142 + while ((index < comparisons.size()) &&
  143 + (comparisons[index].score == thresh)) {
  144 + const Comparison &comparison = comparisons[index];
  145 + if (comparison.genuine) {
  146 + truePositives++;
  147 + genuines.append(comparison.score);
  148 + if (firstGenuineReturns[comparison.query] < 1)
  149 + firstGenuineReturns[comparison.query] = abs(firstGenuineReturns[comparison.query]) + 1;
  150 + if ((comparison.score != -std::numeric_limits<float>::max()) &&
  151 + (comparison.score < minGenuineScore))
  152 + minGenuineScore = comparison.score;
  153 + } else {
  154 + falsePositives++;
  155 + impostors.append(comparison.score);
  156 + if (firstGenuineReturns[comparison.query] < 1)
  157 + firstGenuineReturns[comparison.query]--;
  158 + if ((comparison.score != -std::numeric_limits<float>::max()) &&
  159 + (comparison.score < minImpostorScore))
  160 + minImpostorScore = comparison.score;
  161 + }
  162 + index++;
  163 + }
  164 +
  165 + if ((falsePositives > previousFalsePositives) &&
  166 + (truePositives > previousTruePositives)) {
  167 + // Restrict the extreme ends of the curve
  168 + if ((truePositives >= 10) && (falsePositives < impostorCount/2))
  169 + operatingPoints.append(OperatingPoint(thresh, float(falsePositives)/impostorCount, float(truePositives)/genuineCount));
  170 + previousFalsePositives = falsePositives;
  171 + previousTruePositives = truePositives;
  172 + }
  173 + }
  174 +
  175 + if (operatingPoints.size() == 0) operatingPoints.append(OperatingPoint(1, 1, 1));
  176 + if (operatingPoints.size() == 1) operatingPoints.prepend(OperatingPoint(0, 0, 0));
  177 + if (operatingPoints.size() > 2) operatingPoints.takeLast(); // Remove point (1,1)
  178 +
  179 + // Write Metadata table
  180 + QStringList lines;
  181 + lines.append("Plot,X,Y");
  182 + lines.append("Metadata,"+QString::number(simmat.cols)+",Gallery");
  183 + lines.append("Metadata,"+QString::number(simmat.rows)+",Probe");
  184 + lines.append("Metadata,"+QString::number(genuineCount)+",Genuine");
  185 + lines.append("Metadata,"+QString::number(impostorCount)+",Impostor");
  186 + lines.append("Metadata,"+QString::number(simmat.cols*simmat.rows-(genuineCount+impostorCount))+",Ignored");
  187 +
  188 + // Write Detection Error Tradeoff (DET), PRE, REC
  189 + int points = qMin(operatingPoints.size(), Max_Points);
  190 + for (int i=0; i<points; i++) {
  191 + const OperatingPoint &operatingPoint = operatingPoints[double(i) / double(points-1) * double(operatingPoints.size()-1)];
  192 + lines.append(QString("DET,%1,%2").arg(QString::number(operatingPoint.FAR),
  193 + QString::number(1-operatingPoint.TAR)));
  194 + lines.append(QString("FAR,%1,%2").arg(QString::number(operatingPoint.score),
  195 + QString::number(operatingPoint.FAR)));
  196 + lines.append(QString("FRR,%1,%2").arg(QString::number(operatingPoint.score),
  197 + QString::number(1-operatingPoint.TAR)));
  198 + }
  199 +
  200 + // Write FAR/TAR Bar Chart (BC)
  201 + lines.append(qPrintable(QString("BC,0.001,%1").arg(QString::number(getTAR(operatingPoints, 0.001), 'f', 3))));
  202 + lines.append(qPrintable(QString("BC,0.01,%1").arg(QString::number(result = getTAR(operatingPoints, 0.01), 'f', 3))));
  203 +
  204 + // Write SD & KDE
  205 + points = qMin(qMin(Max_Points, genuines.size()), impostors.size());
  206 + QList<double> sampledGenuineScores; sampledGenuineScores.reserve(points);
  207 + QList<double> sampledImpostorScores; sampledImpostorScores.reserve(points);
  208 +
  209 + if (points > 1) {
  210 + for (int i=0; i<points; i++) {
  211 + float genuineScore = genuines[double(i) / double(points-1) * double(genuines.size()-1)];
  212 + float impostorScore = impostors[double(i) / double(points-1) * double(impostors.size()-1)];
  213 + if (genuineScore == -std::numeric_limits<float>::max()) genuineScore = minGenuineScore;
  214 + if (impostorScore == -std::numeric_limits<float>::max()) impostorScore = minImpostorScore;
  215 + lines.append(QString("SD,%1,Genuine").arg(QString::number(genuineScore)));
  216 + lines.append(QString("SD,%1,Impostor").arg(QString::number(impostorScore)));
  217 + sampledGenuineScores.append(genuineScore);
  218 + sampledImpostorScores.append(impostorScore);
  219 + }
  220 + }
  221 +
  222 + // Write Cumulative Match Characteristic (CMC) curve
  223 + const int Max_Retrieval = 200;
  224 + const int Report_Retrieval = 5;
  225 +
  226 + float reportRetrievalRate = -1;
  227 + for (int i=1; i<=Max_Retrieval; i++) {
  228 + int realizedReturns = 0, possibleReturns = 0;
  229 + foreach (int firstGenuineReturn, firstGenuineReturns) {
  230 + if (firstGenuineReturn > 0) {
  231 + possibleReturns++;
  232 + if (firstGenuineReturn <= i) realizedReturns++;
  233 + }
  234 + }
  235 + const float retrievalRate = float(realizedReturns)/possibleReturns;
  236 + lines.append(qPrintable(QString("CMC,%1,%2").arg(QString::number(i), QString::number(retrievalRate))));
  237 + if (i == Report_Retrieval) reportRetrievalRate = retrievalRate;
  238 + }
  239 +
  240 + if (!csv.isEmpty()) QtUtils::writeFile(csv, lines);
  241 + qDebug("TAR @ FAR = 0.01: %.3f\nRetrieval Rate @ Rank = %d: %.3f", result, Report_Retrieval, reportRetrievalRate);
  242 + return result;
  243 +}
  244 +
  245 +// Helper struct for statistics accumulation
  246 +struct Counter
  247 +{
  248 + float truePositive, falsePositive, falseNegative;
  249 + Counter()
  250 + {
  251 + truePositive = 0;
  252 + falsePositive = 0;
  253 + falseNegative = 0;
  254 + }
  255 +};
  256 +
  257 +void EvalClassification(const QString &predictedInput, const QString &truthInput)
  258 +{
  259 + qDebug("Evaluating classification of %s against %s", qPrintable(predictedInput), qPrintable(truthInput));
  260 + TemplateList predicted(TemplateList::fromGallery(predictedInput));
  261 + TemplateList truth(TemplateList::fromGallery(truthInput));
  262 + if (predicted.size() != truth.size()) qFatal("Input size mismatch.");
  263 +
  264 + QHash<QString, Counter> counters;
  265 + for (int i=0; i<predicted.size(); i++) {
  266 + if (predicted[i].file.name != truth[i].file.name)
  267 + qFatal("Input order mismatch.");
  268 +
  269 + // Typically these lists will be of length one, but this generalization allows measuring multi-class labeling accuracy.
  270 + QString predictedSubject = predicted[i].file.get<QString>("Subject");
  271 + QString trueSubject = truth[i].file.get<QString>("Subject");
  272 +
  273 + QStringList predictedSubjects(predictedSubject);
  274 + QStringList trueSubjects(trueSubject);
  275 +
  276 + foreach (const QString &subject, trueSubjects.toVector() /* Hack to copy the list. */) {
  277 + if (predictedSubjects.contains(subject)) {
  278 + counters[subject].truePositive++;
  279 + trueSubjects.removeOne(subject);
  280 + predictedSubjects.removeOne(subject);
  281 + } else {
  282 + counters[subject].falseNegative++;
  283 + }
  284 + }
  285 +
  286 + for (int i=0; i<trueSubjects.size(); i++)
  287 + foreach (const QString &subject, predictedSubjects)
  288 + counters[subject].falsePositive += 1.f / predictedSubjects.size();
  289 + }
  290 +
  291 + const QStringList keys = counters.keys();
  292 + QSharedPointer<Output> output(Output::make("", FileList() << "Count" << "Precision" << "Recall" << "F-score", FileList(keys)));
  293 +
  294 + int tpc = 0;
  295 + int fnc = 0;
  296 +
  297 + for (int i=0; i<counters.size(); i++) {
  298 + const QString &subject = keys[i];
  299 + const Counter &counter = counters[subject];
  300 + tpc += counter.truePositive;
  301 + fnc += counter.falseNegative;
  302 + const int count = counter.truePositive + counter.falseNegative;
  303 + const float precision = counter.truePositive / (float)(counter.truePositive + counter.falsePositive);
  304 + const float recall = counter.truePositive / (float)(counter.truePositive + counter.falseNegative);
  305 + const float fscore = 2 * precision * recall / (precision + recall);
  306 + output->setRelative(count, i, 0);
  307 + output->setRelative(precision, i, 1);
  308 + output->setRelative(recall, i, 2);
  309 + output->setRelative(fscore, i, 3);
  310 + }
  311 +
  312 + qDebug("Overall Accuracy = %f", (float)tpc / (float)(tpc + fnc));
  313 +}
  314 +
  315 +struct Detection
  316 +{
  317 + QRectF boundingBox;
  318 + float confidence;
  319 +
  320 + Detection() {}
  321 + Detection(const QRectF &boundingBox_, float confidence_ = -1)
  322 + : boundingBox(boundingBox_), confidence(confidence_) {}
  323 +
  324 + float area() const { return boundingBox.width() * boundingBox.height(); }
  325 + float overlap(const Detection &other) const
  326 + {
  327 + const Detection intersection(boundingBox.intersected(other.boundingBox));
  328 + return intersection.area() / (area() + other.area() - 2*intersection.area());
  329 + }
  330 +};
  331 +
  332 +struct Detections
  333 +{
  334 + QList<Detection> predicted, truth;
  335 +};
  336 +
  337 +struct DetectionOperatingPoint
  338 +{
  339 + float confidence, overlap;
  340 + DetectionOperatingPoint() : confidence(-1), overlap(-1) {}
  341 + DetectionOperatingPoint(float confidence_, float overlap_) : confidence(confidence_), overlap(overlap_) {}
  342 + inline bool operator<(const DetectionOperatingPoint &other) const { return confidence > other.confidence; }
  343 +};
  344 +
  345 +float EvalDetection(const QString &predictedInput, const QString &truthInput, const QString &csv)
  346 +{
  347 + qDebug("Evaluating detection of %s against %s", qPrintable(predictedInput), qPrintable(truthInput));
  348 + const TemplateList predicted(TemplateList::fromGallery(predictedInput));
  349 + const TemplateList truth(TemplateList::fromGallery(truthInput));
  350 +
  351 + // Figure out which metadata field contains a bounding box
  352 + QString detectKey;
  353 + foreach (const QString &key, truth.first().file.localKeys())
  354 + if (!truth.first().file.get<QRectF>(key, QRectF()).isNull()) {
  355 + detectKey = key;
  356 + break;
  357 + }
  358 + if (detectKey.isNull()) qFatal("No suitable metadata key found.");
  359 + else qDebug("Using metadata key: %s", qPrintable(detectKey));
  360 +
  361 + QHash<QString, Detections> allDetections; // Organized by file
  362 + foreach (const Template &t, predicted)
  363 + allDetections[t.file.baseName()].predicted.append(Detection(t.file.get<QRectF>(detectKey), t.file.get<float>("Confidence", -1)));
  364 + foreach (const Template &t, truth)
  365 + allDetections[t.file.baseName()].truth.append(Detection(t.file.get<QRectF>(detectKey)));
  366 +
  367 + QList<DetectionOperatingPoint> points;
  368 + foreach (Detections detections, allDetections.values()) {
  369 + while (!detections.truth.isEmpty() && !detections.predicted.isEmpty()) {
  370 + Detection truth = detections.truth.takeFirst();
  371 + int bestIndex = -1;
  372 + float bestOverlap = -1;
  373 + for (int i=0; i<detections.predicted.size(); i++) {
  374 + const float overlap = truth.overlap(detections.predicted[i]);
  375 + if (overlap > bestOverlap) {
  376 + bestOverlap = overlap;
  377 + bestIndex = i;
  378 + }
  379 + }
  380 + Detection predicted = detections.predicted.takeAt(bestIndex);
  381 + points.append(DetectionOperatingPoint(predicted.confidence, bestOverlap));
  382 + }
  383 +
  384 + foreach (const Detection &detection, detections.predicted)
  385 + points.append(DetectionOperatingPoint(detection.confidence, 0));
  386 + for (int i=0; i<detections.truth.size(); i++)
  387 + points.append(DetectionOperatingPoint(-std::numeric_limits<float>::max(), 0));
  388 + }
  389 +
  390 + std::sort(points.begin(), points.end());
  391 +
  392 + QStringList lines;
  393 + lines.append("Plot, X, Y");
  394 +
  395 + // TODO: finish implementing
  396 +
  397 + (void) csv;
  398 + return 0;
  399 +}
  400 +
  401 +void EvalRegression(const QString &predictedInput, const QString &truthInput)
  402 +{
  403 + qDebug("Evaluating regression of %s against %s", qPrintable(predictedInput), qPrintable(truthInput));
  404 + const TemplateList predicted(TemplateList::fromGallery(predictedInput));
  405 + const TemplateList truth(TemplateList::fromGallery(truthInput));
  406 + if (predicted.size() != truth.size()) qFatal("Input size mismatch.");
  407 +
  408 + float rmsError = 0;
  409 + QStringList truthValues, predictedValues;
  410 + for (int i=0; i<predicted.size(); i++) {
  411 + if (predicted[i].file.name != truth[i].file.name)
  412 + qFatal("Input order mismatch.");
  413 + rmsError += pow(predicted[i].file.get<float>("Subject")-truth[i].file.get<float>("Subject"), 2.f);
  414 + truthValues.append(QString::number(truth[i].file.get<float>("Subject")));
  415 + predictedValues.append(QString::number(predicted[i].file.get<float>("Subject")));
  416 + }
  417 +
  418 + QStringList rSource;
  419 + rSource << "# Load libraries" << "library(ggplot2)" << "" << "# Set Data"
  420 + << "Actual <- c(" + truthValues.join(",") + ")"
  421 + << "Predicted <- c(" + predictedValues.join(",") + ")"
  422 + << "data <- data.frame(Actual, Predicted)"
  423 + << "" << "# Construct Plot" << "pdf(\"EvalRegression.pdf\")"
  424 + << "print(qplot(Actual, Predicted, data=data, geom=\"jitter\", alpha=I(2/3)) + geom_abline(intercept=0, slope=1, color=\"forestgreen\", size=I(1)) + geom_smooth(size=I(1), color=\"mediumblue\") + theme_bw())"
  425 + << "print(qplot(Actual, Predicted-Actual, data=data, geom=\"jitter\", alpha=I(2/3)) + geom_abline(intercept=0, slope=0, color=\"forestgreen\", size=I(1)) + geom_smooth(size=I(1), color=\"mediumblue\") + theme_bw())"
  426 + << "dev.off()";
  427 +
  428 +
  429 + QString rFile = "EvalRegression.R";
  430 + QtUtils::writeFile(rFile, rSource);
  431 + bool success = QtUtils::runRScript(rFile);
  432 + if (success) QtUtils::showFile("EvalRegression.pdf");
  433 +
  434 + qDebug("RMS Error = %f", sqrt(rmsError/predicted.size()));
  435 +}
  436 +
  437 +} // namespace br
... ...
openbr/core/classify.h renamed to openbr/core/eval.h
... ... @@ -14,18 +14,22 @@
14 14 * limitations under the License. *
15 15 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
16 16  
17   -#ifndef __CLASSIFY_H
18   -#define __CLASSIFY_H
  17 +#ifndef __EVAL_H
  18 +#define __EVAL_H
19 19  
20 20 #include <QList>
21 21 #include <QString>
  22 +#include "openbr/openbr_plugin.h"
22 23  
23 24 namespace br
24 25 {
  26 + float Evaluate(const QString &simmat, const QString &mask = "", const QString &csv = ""); // Returns TAR @ FAR = 0.001
  27 + float Evaluate(const cv::Mat &scores, const FileList &target, const FileList &query, const QString &csv = "", int parition = 0);
  28 + float Evaluate(const cv::Mat &scores, const cv::Mat &masks, const QString &csv = "");
25 29 void EvalClassification(const QString &predictedInput, const QString &truthInput);
26   - void EvalDetection(const QString &predictedInput, const QString &truthInput);
  30 + float EvalDetection(const QString &predictedInput, const QString &truthInput, const QString &csv = ""); // Return average overlap
27 31 void EvalRegression(const QString &predictedInput, const QString &truthInput);
28 32 }
29 33  
30   -#endif // __CLASSIFY_H
  34 +#endif // __EVAL_H
31 35  
... ...
openbr/core/plot.cpp
... ... @@ -14,57 +14,15 @@
14 14 * limitations under the License. *
15 15 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
16 16  
17   -#include <QDebug>
18   -#include <QDir>
19   -#include <QFile>
20   -#include <QFileInfo>
21   -#include <QFuture>
22   -#include <QList>
23   -#include <QPair>
24   -#include <QPointF>
25   -#include <QRegExp>
26   -#include <QSet>
27   -#include <QStringList>
28   -#include <QVector>
29   -#include <QtAlgorithms>
30   -#include <opencv2/core/core.hpp>
31   -#include <assert.h>
32   -
33 17 #include "plot.h"
34 18 #include "version.h"
35   -#include "openbr/core/bee.h"
36   -#include "openbr/core/common.h"
37   -#include "openbr/core/opencvutils.h"
38 19 #include "openbr/core/qtutils.h"
39 20  
40   -#undef FAR // Windows preprecessor definition
41   -
42 21 using namespace cv;
43 22  
44 23 namespace br
45 24 {
46 25  
47   -void Confusion(const QString &file, float score, int &true_positives, int &false_positives, int &true_negatives, int &false_negatives)
48   -{
49   - qDebug("Computing confusion matrix of %s at %f", qPrintable(file), score);
50   -
51   - QStringList lines = QtUtils::readLines(file);
52   - true_positives = false_positives = true_negatives = false_negatives = 0;
53   - foreach (const QString &line, lines) {
54   - if (!line.startsWith("SD")) continue;
55   - QStringList words = line.split(",");
56   - bool ok;
57   - float similarity = words[1].toFloat(&ok); assert(ok);
58   - if (words[2] == "Genuine") {
59   - if (similarity >= score) true_positives++;
60   - else false_negatives++;
61   - } else {
62   - if (similarity >= score) false_positives++;
63   - else true_negatives++;
64   - }
65   - }
66   -}
67   -
68 26 static QStringList getPivots(const QString &file, bool headers)
69 27 {
70 28 QString str;
... ... @@ -73,224 +31,6 @@ static QStringList getPivots(const QString &amp;file, bool headers)
73 31 return str.split("_");
74 32 }
75 33  
76   -struct Comparison
77   -{
78   - float score;
79   - int target, query;
80   - bool genuine;
81   -
82   - Comparison() {}
83   - Comparison(float _score, int _target, int _query, bool _genuine)
84   - : score(_score), target(_target), query(_query), genuine(_genuine) {}
85   - inline bool operator<(const Comparison &other) const { return score > other.score; }
86   -};
87   -
88   -struct OperatingPoint
89   -{
90   - float score, FAR, TAR;
91   - OperatingPoint() {}
92   - OperatingPoint(float _score, float _FAR, float _TAR)
93   - : score(_score), FAR(_FAR), TAR(_TAR) {}
94   -};
95   -
96   -static float getTAR(const QList<OperatingPoint> &operatingPoints, float FAR)
97   -{
98   - int index = 0;
99   - while (operatingPoints[index].FAR < FAR) {
100   - index++;
101   - if (index == operatingPoints.size())
102   - return 1;
103   - }
104   -
105   - const float x1 = (index == 0 ? 0 : operatingPoints[index-1].FAR);
106   - const float y1 = (index == 0 ? 0 : operatingPoints[index-1].TAR);
107   - const float x2 = operatingPoints[index].FAR;
108   - const float y2 = operatingPoints[index].TAR;
109   - const float m = (y2 - y1) / (x2 - x1);
110   - const float b = y1 - m*x1;
111   - return m * FAR + b;
112   -}
113   -
114   -float Evaluate(const cv::Mat &scores, const FileList &target, const FileList &query, const QString &csv, int partition)
115   -{
116   - return Evaluate(scores, BEE::makeMask(target, query, partition), csv);
117   -}
118   -
119   -float Evaluate(const QString &simmat, const QString &mask, const QString &csv)
120   -{
121   - qDebug("Evaluating %s%s%s",
122   - qPrintable(simmat),
123   - mask.isEmpty() ? "" : qPrintable(" with " + mask),
124   - csv.isEmpty() ? "" : qPrintable(" to " + csv));
125   -
126   - // Read similarity matrix
127   - QString target, query;
128   - const Mat scores = BEE::readSimmat(simmat, &target, &query);
129   -
130   - // Read mask matrix
131   - Mat truth;
132   - if (mask.isEmpty()) {
133   - // Use the galleries specified in the similarity matrix
134   - truth = BEE::makeMask(TemplateList::fromGallery(target).files(),
135   - TemplateList::fromGallery(query).files());
136   - } else {
137   - File maskFile(mask);
138   - maskFile.set("rows", scores.rows);
139   - maskFile.set("columns", scores.cols);
140   - truth = BEE::readMask(maskFile);
141   - }
142   -
143   - return Evaluate(scores, truth, csv);
144   -}
145   -
146   -float Evaluate(const Mat &simmat, const Mat &mask, const QString &csv)
147   -{
148   - if (simmat.size() != mask.size())
149   - qFatal("Similarity matrix (%ix%i) differs in size from mask matrix (%ix%i).",
150   - simmat.rows, simmat.cols, mask.rows, mask.cols);
151   -
152   - const int Max_Points = 500;
153   - float result = -1;
154   -
155   - // Make comparisons
156   - QList<Comparison> comparisons; comparisons.reserve(simmat.rows*simmat.cols);
157   - int genuineCount = 0, impostorCount = 0, numNaNs = 0;
158   - for (int i=0; i<simmat.rows; i++) {
159   - for (int j=0; j<simmat.cols; j++) {
160   - const BEE::Mask_t mask_val = mask.at<BEE::Mask_t>(i,j);
161   - const BEE::Simmat_t simmat_val = simmat.at<BEE::Simmat_t>(i,j);
162   - if (mask_val == BEE::DontCare) continue;
163   - if (simmat_val != simmat_val) { numNaNs++; continue; }
164   - comparisons.append(Comparison(simmat_val, j, i, mask_val == BEE::Match));
165   - if (comparisons.last().genuine) genuineCount++;
166   - else impostorCount++;
167   - }
168   - }
169   -
170   - if (numNaNs > 0) qWarning("Encountered %d NaN scores!", numNaNs);
171   - if (genuineCount == 0) qFatal("No genuine scores!");
172   - if (impostorCount == 0) qFatal("No impostor scores!");
173   -
174   - // Sort comparisons by simmat_val (score)
175   - std::sort(comparisons.begin(), comparisons.end());
176   -
177   - QList<OperatingPoint> operatingPoints;
178   - QList<float> genuines; genuines.reserve(sqrt((float)comparisons.size()));
179   - QList<float> impostors; impostors.reserve(comparisons.size());
180   - QVector<int> firstGenuineReturns(simmat.rows, 0);
181   -
182   - int falsePositives = 0, previousFalsePositives = 0;
183   - int truePositives = 0, previousTruePositives = 0;
184   - int index = 0;
185   - float minGenuineScore = std::numeric_limits<float>::max();
186   - float minImpostorScore = std::numeric_limits<float>::max();
187   -
188   - while (index < comparisons.size()) {
189   - float thresh = comparisons[index].score;
190   - // Compute genuine and imposter statistics at a threshold
191   - while ((index < comparisons.size()) &&
192   - (comparisons[index].score == thresh)) {
193   - const Comparison &comparison = comparisons[index];
194   - if (comparison.genuine) {
195   - truePositives++;
196   - genuines.append(comparison.score);
197   - if (firstGenuineReturns[comparison.query] < 1)
198   - firstGenuineReturns[comparison.query] = abs(firstGenuineReturns[comparison.query]) + 1;
199   - if ((comparison.score != -std::numeric_limits<float>::max()) &&
200   - (comparison.score < minGenuineScore))
201   - minGenuineScore = comparison.score;
202   - } else {
203   - falsePositives++;
204   - impostors.append(comparison.score);
205   - if (firstGenuineReturns[comparison.query] < 1)
206   - firstGenuineReturns[comparison.query]--;
207   - if ((comparison.score != -std::numeric_limits<float>::max()) &&
208   - (comparison.score < minImpostorScore))
209   - minImpostorScore = comparison.score;
210   - }
211   - index++;
212   - }
213   -
214   - if ((falsePositives > previousFalsePositives) &&
215   - (truePositives > previousTruePositives)) {
216   - // Restrict the extreme ends of the curve
217   - if ((truePositives >= 10) && (falsePositives < impostorCount/2))
218   - operatingPoints.append(OperatingPoint(thresh, float(falsePositives)/impostorCount, float(truePositives)/genuineCount));
219   - previousFalsePositives = falsePositives;
220   - previousTruePositives = truePositives;
221   - }
222   - }
223   -
224   - if (operatingPoints.size() == 0) operatingPoints.append(OperatingPoint(1, 1, 1));
225   - if (operatingPoints.size() == 1) operatingPoints.prepend(OperatingPoint(0, 0, 0));
226   - if (operatingPoints.size() > 2) operatingPoints.takeLast(); // Remove point (1,1)
227   -
228   - // Write Metadata table
229   - QStringList lines;
230   - lines.append("Plot,X,Y");
231   - lines.append("Metadata,"+QString::number(simmat.cols)+",Gallery");
232   - lines.append("Metadata,"+QString::number(simmat.rows)+",Probe");
233   - lines.append("Metadata,"+QString::number(genuineCount)+",Genuine");
234   - lines.append("Metadata,"+QString::number(impostorCount)+",Impostor");
235   - lines.append("Metadata,"+QString::number(simmat.cols*simmat.rows-(genuineCount+impostorCount))+",Ignored");
236   -
237   - // Write Detection Error Tradeoff (DET), PRE, REC
238   - int points = qMin(operatingPoints.size(), Max_Points);
239   - for (int i=0; i<points; i++) {
240   - const OperatingPoint &operatingPoint = operatingPoints[double(i) / double(points-1) * double(operatingPoints.size()-1)];
241   - lines.append(QString("DET,%1,%2").arg(QString::number(operatingPoint.FAR),
242   - QString::number(1-operatingPoint.TAR)));
243   - lines.append(QString("FAR,%1,%2").arg(QString::number(operatingPoint.score),
244   - QString::number(operatingPoint.FAR)));
245   - lines.append(QString("FRR,%1,%2").arg(QString::number(operatingPoint.score),
246   - QString::number(1-operatingPoint.TAR)));
247   - }
248   -
249   - // Write FAR/TAR Bar Chart (BC)
250   - lines.append(qPrintable(QString("BC,0.001,%1").arg(QString::number(getTAR(operatingPoints, 0.001), 'f', 3))));
251   - lines.append(qPrintable(QString("BC,0.01,%1").arg(QString::number(result = getTAR(operatingPoints, 0.01), 'f', 3))));
252   -
253   - // Write SD & KDE
254   - points = qMin(qMin(Max_Points, genuines.size()), impostors.size());
255   - QList<double> sampledGenuineScores; sampledGenuineScores.reserve(points);
256   - QList<double> sampledImpostorScores; sampledImpostorScores.reserve(points);
257   -
258   - if (points > 1) {
259   - for (int i=0; i<points; i++) {
260   - float genuineScore = genuines[double(i) / double(points-1) * double(genuines.size()-1)];
261   - float impostorScore = impostors[double(i) / double(points-1) * double(impostors.size()-1)];
262   - if (genuineScore == -std::numeric_limits<float>::max()) genuineScore = minGenuineScore;
263   - if (impostorScore == -std::numeric_limits<float>::max()) impostorScore = minImpostorScore;
264   - lines.append(QString("SD,%1,Genuine").arg(QString::number(genuineScore)));
265   - lines.append(QString("SD,%1,Impostor").arg(QString::number(impostorScore)));
266   - sampledGenuineScores.append(genuineScore);
267   - sampledImpostorScores.append(impostorScore);
268   - }
269   - }
270   -
271   - // Write Cumulative Match Characteristic (CMC) curve
272   - const int Max_Retrieval = 200;
273   - const int Report_Retrieval = 5;
274   -
275   - float reportRetrievalRate = -1;
276   - for (int i=1; i<=Max_Retrieval; i++) {
277   - int realizedReturns = 0, possibleReturns = 0;
278   - foreach (int firstGenuineReturn, firstGenuineReturns) {
279   - if (firstGenuineReturn > 0) {
280   - possibleReturns++;
281   - if (firstGenuineReturn <= i) realizedReturns++;
282   - }
283   - }
284   - const float retrievalRate = float(realizedReturns)/possibleReturns;
285   - lines.append(qPrintable(QString("CMC,%1,%2").arg(QString::number(i), QString::number(retrievalRate))));
286   - if (i == Report_Retrieval) reportRetrievalRate = retrievalRate;
287   - }
288   -
289   - if (!csv.isEmpty()) QtUtils::writeFile(csv, lines);
290   - qDebug("TAR @ FAR = 0.01: %.3f\nRetrieval Rate @ Rank = %d: %.3f", result, Report_Retrieval, reportRetrievalRate);
291   - return result;
292   -}
293   -
294 34 static QString getScale(const QString &mode, const QString &title, int vals)
295 35 {
296 36 if (vals > 12) return " + scale_"+mode+"_discrete(\""+title+"\")";
... ... @@ -474,7 +214,6 @@ struct RPlot
474 214 };
475 215  
476 216 // Does not work if dataset folder starts with a number
477   -
478 217 bool Plot(const QStringList &files, const br::File &destination, bool show)
479 218 {
480 219 qDebug("Plotting %d file(s) to %s", files.size(), qPrintable(destination));
... ...
openbr/core/plot.h
... ... @@ -24,14 +24,8 @@
24 24  
25 25 namespace br
26 26 {
27   -
28   -void Confusion(const QString &file, float score, int &true_positives, int &false_positives, int &true_negatives, int &false_negatives);
29   -float Evaluate(const QString &simmat, const QString &mask = "", const QString &csv = ""); // Returns TAR @ FAR = 0.001
30   -float Evaluate(const cv::Mat &scores, const FileList &target, const FileList &query, const QString &csv = "", int parition = 0);
31   -float Evaluate(const cv::Mat &scores, const cv::Mat &masks, const QString &csv = "");
32   -bool Plot(const QStringList &files, const br::File &destination, bool show = false);
33   -bool PlotMetadata(const QStringList &files, const QString &destination, bool show = false);
34   -
  27 + bool Plot(const QStringList &files, const br::File &destination, bool show = false);
  28 + bool PlotMetadata(const QStringList &files, const QString &destination, bool show = false);
35 29 }
36 30  
37 31 #endif // __PLOT_H
... ...
openbr/openbr.cpp
... ... @@ -17,8 +17,8 @@
17 17 #include <openbr/openbr_plugin.h>
18 18  
19 19 #include "core/bee.h"
20   -#include "core/classify.h"
21 20 #include "core/cluster.h"
  21 +#include "core/eval.h"
22 22 #include "core/fuse.h"
23 23 #include "core/plot.h"
24 24 #include "core/qtutils.h"
... ... @@ -51,11 +51,6 @@ void br_compare(const char *target_gallery, const char *query_gallery, const cha
51 51 Compare(File(target_gallery), File(query_gallery), File(output));
52 52 }
53 53  
54   -void br_confusion(const char *file, float score, int *true_positives, int *false_positives, int *true_negatives, int *false_negatives)
55   -{
56   - return Confusion(file, score, *true_positives, *false_positives, *true_negatives, *false_negatives);
57   -}
58   -
59 54 void br_convert(const char *file_type, const char *input_file, const char *output_file)
60 55 {
61 56 Convert(File(file_type), File(input_file), File(output_file));
... ...
openbr/openbr.h
... ... @@ -115,20 +115,6 @@ BR_EXPORT void br_combine_masks(int num_input_masks, const char *input_masks[],
115 115 BR_EXPORT void br_compare(const char *target_gallery, const char *query_gallery, const char *output = "");
116 116  
117 117 /*!
118   - * \brief Computes the confusion matrix for a dataset at a particular threshold.
119   - *
120   - * <a href="http://en.wikipedia.org/wiki/Confusion_matrix">Wikipedia Explanation</a>
121   - * \param file <tt>.csv</tt> file created using \ref br_eval.
122   - * \param score The similarity score to threshold at.
123   - * \param[out] true_positives The true positive count.
124   - * \param[out] false_positives The false positive count.
125   - * \param[out] true_negatives The true negative count.
126   - * \param[out] false_negatives The false negative count.
127   - */
128   -BR_EXPORT void br_confusion(const char *file, float score,
129   - int *true_positives, int *false_positives, int *true_negatives, int *false_negatives);
130   -
131   -/*!
132 118 * \brief Wraps br::Convert()
133 119 */
134 120 BR_EXPORT void br_convert(const char *file_type, const char *input_file, const char *output_file);
... ...
openbr/openbr_plugin.cpp
... ... @@ -436,7 +436,7 @@ TemplateList TemplateList::fromGallery(const br::File &amp;gallery)
436 436 // stores the index values in "Label" of the output template list
437 437 TemplateList TemplateList::relabel(const TemplateList &tl, const QString & propName)
438 438 {
439   - const QList<QString> originalLabels = tl.get<QString>(propName);
  439 + const QList<QString> originalLabels = File::get<QString>(tl, propName);
440 440 QHash<QString,int> labelTable;
441 441 foreach (const QString & label, originalLabels)
442 442 if (!labelTable.contains(label))
... ... @@ -464,7 +464,7 @@ QList&lt;int&gt; TemplateList::indexProperty(const QString &amp; propName, QHash&lt;QString,
464 464 valueMap.clear();
465 465 reverseLookup.clear();
466 466  
467   - const QList<QVariant> originalLabels = values(propName);
  467 + const QList<QVariant> originalLabels = File::values(*this, propName);
468 468 foreach (const QVariant & label, originalLabels) {
469 469 QString labelString = label.toString();
470 470 if (!valueMap.contains(labelString)) {
... ... @@ -481,9 +481,9 @@ QList&lt;int&gt; TemplateList::indexProperty(const QString &amp; propName, QHash&lt;QString,
481 481 }
482 482  
483 483 // uses -1 for missing values
484   -QList<int> TemplateList::applyIndex(const QString & propName, const QHash<QString, int> & valueMap) const
  484 +QList<int> TemplateList::applyIndex(const QString &propName, const QHash<QString, int> &valueMap) const
485 485 {
486   - const QList<QString> originalLabels = get<QString>(propName);
  486 + const QList<QString> originalLabels = File::get<QString>(*this, propName);
487 487  
488 488 QList<int> result;
489 489 for (int i=0; i<originalLabels.size(); i++) {
... ...
openbr/openbr_plugin.h
... ... @@ -228,7 +228,20 @@ struct BR_EXPORT File
228 228 return variant.value<T>();
229 229 }
230 230  
231   - /*!< \brief Returns a list of type T for the key, throwing an error if the key does not exist or if the value cannot be converted to the specified type. */
  231 + /*!< \brief Returns a value for the key, returning \em defaultValue if the key does not exist or can't be converted. */
  232 + template <typename T>
  233 + T get(const QString &key, const T &defaultValue) const
  234 + {
  235 + if (!contains(key)) return defaultValue;
  236 + QVariant variant = value(key);
  237 + if (!variant.canConvert<T>()) return defaultValue;
  238 + return variant.value<T>();
  239 + }
  240 +
  241 + /*!< \brief Specialization for boolean type. */
  242 + bool getBool(const QString &key, bool defaultValue = false) const;
  243 +
  244 + /*!< \brief Specialization for list type. Returns a list of type T for the key, throwing an error if the key does not exist or if the value cannot be converted to the specified type. */
232 245 template <typename T>
233 246 QList<T> getList(const QString &key) const
234 247 {
... ... @@ -241,17 +254,31 @@ struct BR_EXPORT File
241 254 return list;
242 255 }
243 256  
244   - /*!< \brief Specialization for boolean type. */
245   - bool getBool(const QString &key, bool defaultValue = false) const;
  257 + /*!< \brief Returns the value for the specified key for every file in the list. */
  258 + template<class U>
  259 + static QList<QVariant> values(const QList<U> &fileList, const QString &key)
  260 + {
  261 + QList<QVariant> values; values.reserve(fileList.size());
  262 + foreach (const U &f, fileList) values.append(((const File&)f).value(key));
  263 + return values;
  264 + }
246 265  
247   - /*!< \brief Returns a value for the key, returning \em defaultValue if the key does not exist or can't be converted. */
248   - template <typename T>
249   - T get(const QString &key, const T &defaultValue) const
  266 + /*!< \brief Returns a value for the key for every file in the list, throwing an error if the key does not exist. */
  267 + template<class T, class U>
  268 + static QList<T> get(const QList<U> &fileList, const QString &key)
250 269 {
251   - if (!contains(key)) return defaultValue;
252   - QVariant variant = value(key);
253   - if (!variant.canConvert<T>()) return defaultValue;
254   - return variant.value<T>();
  270 + QList<T> result; result.reserve(fileList.size());
  271 + foreach (const U &f, fileList) result.append(((const File&)f).get<T>(key));
  272 + return result;
  273 + }
  274 +
  275 + /*!< \brief Returns a value for the key for every file in the list, returning \em defaultValue if the key does not exist or can't be converted. */
  276 + template<class T, class U>
  277 + static QList<T> get(const QList<U> &fileList, const QString &key, const T &defaultValue)
  278 + {
  279 + QList<T> result; result.reserve(fileList.size());
  280 + foreach (const U &f, fileList) result.append(static_cast<const File&>(f).get<T>(key, defaultValue));
  281 + return result;
255 282 }
256 283  
257 284 inline bool failed() const { return getBool("FTE") || getBool("FTO"); } /*!< \brief Returns \c true if the file failed to open or enroll, \c false otherwise. */
... ... @@ -297,23 +324,6 @@ struct BR_EXPORT FileList : public QList&lt;File&gt;
297 324 QStringList flat() const; /*!< \brief Returns br::File::flat() for each file in the list. */
298 325 QStringList names() const; /*!< \brief Returns #br::File::name for each file in the list. */
299 326 void sort(const QString& key); /*!< \brief Sort the list based on metadata. */
300   - /*!< \brief Returns values associated with the input propName for each file in the list. */
301   - template<typename T>
302   - QList<T> get(const QString & propName) const
303   - {
304   - QList<T> values; values.reserve(size());
305   - foreach (const File &f, *this)
306   - values.append(f.get<T>(propName));
307   - return values;
308   - }
309   - template<typename T>
310   - QList<T> get(const QString & propName, T defaultValue) const
311   - {
312   - QList<T> values; values.reserve(size());
313   - foreach (const File &f, *this)
314   - values.append(f.contains(propName) ? f.get<T>(propName) : defaultValue);
315   - return values;
316   - }
317 327  
318 328 QList<int> crossValidationPartitions() const; /*!< \brief Returns the cross-validation partition (default=0) for each file in the list. */
319 329 int failures() const; /*!< \brief Returns the number of files with br::File::failed(). */
... ... @@ -344,6 +354,7 @@ struct Template : public QList&lt;cv::Mat&gt;
344 354 inline const cv::Mat &m() const { static const cv::Mat NullMatrix;
345 355 return isEmpty() ? qFatal("Empty template."), NullMatrix : last(); } /*!< \brief Idiom to treat the template as a matrix. */
346 356 inline cv::Mat &m() { return isEmpty() ? append(cv::Mat()), last() : last(); } /*!< \brief Idiom to treat the template as a matrix. */
  357 + inline const File &operator()() const { return file; }
347 358 inline cv::Mat &operator=(const cv::Mat &other) { return m() = other; } /*!< \brief Idiom to treat the template as a matrix. */
348 359 inline operator const cv::Mat&() const { return m(); } /*!< \brief Idiom to treat the template as a matrix. */
349 360 inline operator cv::Mat&() { return m(); } /*!< \brief Idiom to treat the template as a matrix. */
... ... @@ -406,7 +417,6 @@ struct TemplateList : public QList&lt;Template&gt;
406 417 QList<int> indexProperty(const QString & propName, QHash<QString, int> & valueMap, QHash<int, QVariant> & reverseLookup) const;
407 418 QList<int> applyIndex(const QString & propName, const QHash<QString, int> & valueMap) const;
408 419  
409   -
410 420 /*!
411 421 * \brief Returns the total number of bytes in all the templates.
412 422 */
... ... @@ -477,23 +487,6 @@ struct TemplateList : public QList&lt;Template&gt;
477 487 FileList operator()() const { return files(); }
478 488  
479 489 /*!
480   - * \brief Returns br::Template::label() for each template in the list.
481   - */
482   - template<typename T>
483   - QList<T> get(const QString & propName) const
484   - {
485   - QList<T> values; values.reserve(size());
486   - foreach (const Template &t, *this) values.append(t.file.get<T>(propName));
487   - return values;
488   - }
489   - QList<QVariant> values(const QString & propName) const
490   - {
491   - QList<QVariant> values; values.reserve(size());
492   - foreach (const Template &t, *this) values.append(t.file.value(propName));
493   - return values;
494   - }
495   -
496   - /*!
497 490 * \brief Returns the number of occurences for each label in the list.
498 491 */
499 492 template<typename T>
... ... @@ -814,7 +807,8 @@ struct Factory
814 807 //! [Factory make]
815 808 static T *make(const File &file)
816 809 {
817   - QString name = file.suffix();
  810 + QString name = file.get<QString>("plugin", "");
  811 + if (name.isEmpty()) name = file.suffix();
818 812 if (!names().contains(name)) {
819 813 if (names().contains("Empty") && name.isEmpty()) name = "Empty";
820 814 else if (names().contains("Default")) name = "Default";
... ... @@ -1024,7 +1018,7 @@ public:
1024 1018 virtual TemplateList readBlock(bool *done) = 0; /*!< \brief Retrieve a portion of the stored templates. */
1025 1019 void writeBlock(const TemplateList &templates); /*!< \brief Serialize a template list. */
1026 1020 virtual void write(const Template &t) = 0; /*!< \brief Serialize a template. */
1027   - static Gallery *make(const File &file); /*!< \brief Make a gallery from a file list. */
  1021 + static Gallery *make(const File &file); /*!< \brief Make a gallery to/from a file on disk. */
1028 1022  
1029 1023 private:
1030 1024 QSharedPointer<Gallery> next;
... ...
openbr/plugins/eigen3.cpp
... ... @@ -348,7 +348,7 @@ class LDATransform : public Transform
348 348  
349 349 // OpenBR ensures that class values range from 0 to numClasses-1.
350 350 // Label exists because we created it earlier with relabel
351   - QList<int> classes = trainingSet.get<int>("Label");
  351 + QList<int> classes = File::get<int>(trainingSet, "Label");
352 352 QMap<int, int> classCounts = trainingSet.countValues<int>("Label");
353 353 const int numClasses = classCounts.size();
354 354  
... ...
openbr/plugins/gallery.cpp
... ... @@ -881,6 +881,45 @@ class statGallery : public Gallery
881 881  
882 882 BR_REGISTER(Gallery, statGallery)
883 883  
  884 +/*!
  885 + * \ingroup galleries
  886 + * \brief Implements the FDDB detection format.
  887 + * \author Josh Klontz \cite jklontz
  888 + *
  889 + * http://vis-www.cs.umass.edu/fddb/README.txt
  890 + */
  891 +class FDDBGallery : public Gallery
  892 +{
  893 + Q_OBJECT
  894 +
  895 + TemplateList readBlock(bool *done)
  896 + {
  897 + *done = true;
  898 + QStringList lines = QtUtils::readLines(file);
  899 + TemplateList templates;
  900 + while (!lines.empty()) {
  901 + const QString fileName = lines.takeFirst();
  902 + int numDetects = lines.takeFirst().toInt();
  903 + for (int i=0; i<numDetects; i++) {
  904 + const QStringList detect = lines.takeFirst().split(' ');
  905 + Template t(fileName);
  906 + t.file.set("Face", QRectF(detect[0].toFloat(), detect[1].toFloat(), detect[2].toFloat(), detect[3].toFloat()));
  907 + t.file.set("Confidence", detect[4].toFloat());
  908 + templates.append(t);
  909 + }
  910 + }
  911 + return templates;
  912 + }
  913 +
  914 + void write(const Template &t)
  915 + {
  916 + (void) t;
  917 + qFatal("Not implemented.");
  918 + }
  919 +};
  920 +
  921 +BR_REGISTER(Gallery, FDDBGallery)
  922 +
884 923 } // namespace br
885 924  
886 925 #include "gallery.moc"
... ...
openbr/plugins/independent.cpp
... ... @@ -20,7 +20,7 @@ static TemplateList Downsample(const TemplateList &amp;templates, const Transform *t
20 20 const bool atLeast = transform->instances < 0;
21 21 const int instances = abs(transform->instances);
22 22  
23   - QList<QString> allLabels = templates.get<QString>("Subject");
  23 + QList<QString> allLabels = File::get<QString>(templates, "Subject");
24 24 QList<QString> uniqueLabels = allLabels.toSet().toList();
25 25 qSort(uniqueLabels);
26 26  
... ...
openbr/plugins/output.cpp
... ... @@ -35,9 +35,9 @@
35 35 #include "openbr_internal.h"
36 36  
37 37 #include "openbr/core/bee.h"
  38 +#include "openbr/core/eval.h"
38 39 #include "openbr/core/common.h"
39 40 #include "openbr/core/opencvutils.h"
40   -#include "openbr/core/plot.h"
41 41 #include "openbr/core/qtutils.h"
42 42  
43 43 namespace br
... ... @@ -146,8 +146,8 @@ class meltOutput : public MatrixOutput
146 146 QStringList lines;
147 147 if (file.baseName() != "terminal") lines.append(QString("Query,Target,Mask,Similarity%1").arg(keys));
148 148  
149   - QList<QString> queryLabels = queryFiles.get<QString>("Subject");
150   - QList<QString> targetLabels = targetFiles.get<QString>("Subject");
  149 + QList<QString> queryLabels = File::get<QString>(queryFiles, "Subject");
  150 + QList<QString> targetLabels = File::get<QString>(targetFiles, "Subject");
151 151  
152 152 for (int i=0; i<queryFiles.size(); i++) {
153 153 for (int j=(selfSimilar ? i+1 : 0); j<targetFiles.size(); j++) {
... ...
openbr/plugins/svm.cpp
... ... @@ -130,7 +130,7 @@ private:
130 130 Mat lab;
131 131 // If we are doing regression, assume subject has float values
132 132 if (type == EPS_SVR || type == NU_SVR) {
133   - lab = OpenCVUtils::toMat(_data.get<float>("Subject"));
  133 + lab = OpenCVUtils::toMat(File::get<float>(_data, "Subject"));
134 134 }
135 135 // If we are doing classification, assume subject has discrete values, map them
136 136 // and store the mapping data
... ...
scripts/downloadDatasets.sh
... ... @@ -48,3 +48,19 @@ if [ ! -d ../data/MEDS/img ]; then
48 48 mv data/*/*.jpg ../data/MEDS/img
49 49 rm -r data NIST_SD32_MEDS-II_face.zip
50 50 fi
  51 +
  52 +# KTH
  53 +if [ ! -d ../data/KTH/vid ]; then
  54 + echo "Downloading KTH..."
  55 + mkdir ../data/KTH/vid
  56 + for vidclass in {'boxing','handclapping','handwaving','jogging','running','walking'}; do
  57 + if hash curl 2>/dev/null; then
  58 + curl -OL http://www.nada.kth.se/cvap/actions/${vidclass}.zip
  59 + else
  60 + wget http://www.nada.kth.se/cvap/actions/${vidclass}.zip
  61 + fi
  62 + mkdir ../data/KTH/vid/${vidclass}
  63 + unzip ${vidclass}.zip -d ../data/KTH/vid/${vidclass}
  64 + rm ${vidclass}.zip
  65 + done
  66 +fi
... ...