Commit ec4af7799c2f589d03adea56955c8a2faac9f6a4

Authored by Scott Klum
2 parents 4936fcc1 ff04e9ed

Merge branch 'master' of https://github.com/biometrics/openbr

.gitignore
@@ -29,3 +29,6 @@ scripts/results @@ -29,3 +29,6 @@ scripts/results
29 29
30 ### OS X ### 30 ### OS X ###
31 *.DS_Store 31 *.DS_Store
  32 +
  33 +### vim ###
  34 +*.swp
app/br/br.cpp
@@ -140,13 +140,6 @@ public: @@ -140,13 +140,6 @@ public:
140 } else if (!strcmp(fun, "evalRegression")) { 140 } else if (!strcmp(fun, "evalRegression")) {
141 check(parc == 2, "Incorrect parameter count for 'evalRegression'."); 141 check(parc == 2, "Incorrect parameter count for 'evalRegression'.");
142 br_eval_regression(parv[0], parv[1]); 142 br_eval_regression(parv[0], parv[1]);
143 - } else if (!strcmp(fun, "confusion")) {  
144 - check(parc == 2, "Incorrect parameter count for 'confusion'.");  
145 - int true_positives, false_positives, true_negatives, false_negatives;  
146 - br_confusion(parv[0], atof(parv[1]),  
147 - &true_positives, &false_positives, &true_negatives, &false_negatives);  
148 - printf("True Positives = %d\nFalse Positives = %d\nTrue Negatives = %d\nFalseNegatives = %d\n",  
149 - true_positives, false_positives, true_negatives, false_negatives);  
150 } else if (!strcmp(fun, "plotMetadata")) { 143 } else if (!strcmp(fun, "plotMetadata")) {
151 check(parc >= 2, "Incorrect parameter count for 'plotMetadata'."); 144 check(parc >= 2, "Incorrect parameter count for 'plotMetadata'.");
152 br_plot_metadata(parc-1, parv, parv[parc-1], true); 145 br_plot_metadata(parc-1, parv, parv[parc-1], true);
@@ -223,7 +216,6 @@ private: @@ -223,7 +216,6 @@ private:
223 "-evalClustering <clusters> <gallery>\n" 216 "-evalClustering <clusters> <gallery>\n"
224 "-evalDetection <predicted_gallery> <truth_gallery>\n" 217 "-evalDetection <predicted_gallery> <truth_gallery>\n"
225 "-evalRegression <predicted_gallery> <truth_gallery>\n" 218 "-evalRegression <predicted_gallery> <truth_gallery>\n"
226 - "-confusion <file> <score>\n"  
227 "-plotMetadata <file> ... <file> <columns>\n" 219 "-plotMetadata <file> ... <file> <columns>\n"
228 "-getHeader <matrix>\n" 220 "-getHeader <matrix>\n"
229 "-setHeader {<matrix>} <target_gallery> <query_gallery>\n" 221 "-setHeader {<matrix>} <target_gallery> <query_gallery>\n"
data/KTH/README.md 0 → 100644
  1 +## KTH Human Action Database
  2 +Grayscale human action videos. Six actions performed by 25 subjects in four scenarios, for a total of 600 160x120 videos.
  3 +* [Website](http://www.nada.kth.se/cvap/actions/)
data/README.md
@@ -9,6 +9,7 @@ @@ -9,6 +9,7 @@
9 * [MEDS](MEDS/README.md) 9 * [MEDS](MEDS/README.md)
10 * [MNIST](MNIST/README.md) 10 * [MNIST](MNIST/README.md)
11 * [PCSO](PCSO/README.md) 11 * [PCSO](PCSO/README.md)
  12 +* [KTH](KTH/README.md)
12 13
13 For both practical and legal reasons we only include images for some of the datasets in this repository. 14 For both practical and legal reasons we only include images for some of the datasets in this repository.
14 Researchers should contact the respective owners of the other datasets in order to obtain a copy. 15 Researchers should contact the respective owners of the other datasets in order to obtain a copy.
openbr/core/bee.cpp
@@ -268,8 +268,8 @@ cv::Mat BEE::makeMask(const br::FileList &amp;targets, const br::FileList &amp;queries, @@ -268,8 +268,8 @@ cv::Mat BEE::makeMask(const br::FileList &amp;targets, const br::FileList &amp;queries,
268 { 268 {
269 // Would like to use indexProperty for this, but didn't make a version of that for Filelist yet 269 // Would like to use indexProperty for this, but didn't make a version of that for Filelist yet
270 // -cao 270 // -cao
271 - QList<QString> targetLabels = targets.get<QString>("Subject", "-1");  
272 - QList<QString> queryLabels = queries.get<QString>("Subject", "-1"); 271 + QList<QString> targetLabels = File::get<QString>(targets, "Subject", "-1");
  272 + QList<QString> queryLabels = File::get<QString>(queries, "Subject", "-1");
273 QList<int> targetPartitions = targets.crossValidationPartitions(); 273 QList<int> targetPartitions = targets.crossValidationPartitions();
274 QList<int> queryPartitions = queries.crossValidationPartitions(); 274 QList<int> queryPartitions = queries.crossValidationPartitions();
275 275
openbr/core/classify.cpp deleted
1 -/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *  
2 - * Copyright 2012 The MITRE Corporation *  
3 - * *  
4 - * Licensed under the Apache License, Version 2.0 (the "License"); *  
5 - * you may not use this file except in compliance with the License. *  
6 - * You may obtain a copy of the License at *  
7 - * *  
8 - * http://www.apache.org/licenses/LICENSE-2.0 *  
9 - * *  
10 - * Unless required by applicable law or agreed to in writing, software *  
11 - * distributed under the License is distributed on an "AS IS" BASIS, *  
12 - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *  
13 - * See the License for the specific language governing permissions and *  
14 - * limitations under the License. *  
15 - * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */  
16 -  
17 -#include <openbr/openbr_plugin.h>  
18 -  
19 -#include "classify.h"  
20 -#include "openbr/core/qtutils.h"  
21 -  
22 -// Helper struct for statistics accumulation  
23 -struct Counter  
24 -{  
25 - float truePositive, falsePositive, falseNegative;  
26 - Counter()  
27 - {  
28 - truePositive = 0;  
29 - falsePositive = 0;  
30 - falseNegative = 0;  
31 - }  
32 -};  
33 -  
34 -void br::EvalClassification(const QString &predictedInput, const QString &truthInput)  
35 -{  
36 - qDebug("Evaluating classification of %s against %s", qPrintable(predictedInput), qPrintable(truthInput));  
37 -  
38 - TemplateList predicted(TemplateList::fromGallery(predictedInput));  
39 - TemplateList truth(TemplateList::fromGallery(truthInput));  
40 - if (predicted.size() != truth.size()) qFatal("Input size mismatch.");  
41 -  
42 - QHash<QString, Counter> counters;  
43 - for (int i=0; i<predicted.size(); i++) {  
44 - if (predicted[i].file.name != truth[i].file.name)  
45 - qFatal("Input order mismatch.");  
46 -  
47 - // Typically these lists will be of length one, but this generalization allows measuring multi-class labeling accuracy.  
48 - QString predictedSubject = predicted[i].file.get<QString>("Subject");  
49 - QString trueSubject = truth[i].file.get<QString>("Subject");  
50 -  
51 - QStringList predictedSubjects(predictedSubject);  
52 - QStringList trueSubjects(trueSubject);  
53 -  
54 - foreach (const QString &subject, trueSubjects.toVector() /* Hack to copy the list. */) {  
55 - if (predictedSubjects.contains(subject)) {  
56 - counters[subject].truePositive++;  
57 - trueSubjects.removeOne(subject);  
58 - predictedSubjects.removeOne(subject);  
59 - } else {  
60 - counters[subject].falseNegative++;  
61 - }  
62 - }  
63 -  
64 - for (int i=0; i<trueSubjects.size(); i++)  
65 - foreach (const QString &subject, predictedSubjects)  
66 - counters[subject].falsePositive += 1.f / predictedSubjects.size();  
67 - }  
68 -  
69 - const QStringList keys = counters.keys();  
70 - QSharedPointer<Output> output(Output::make("", FileList() << "Count" << "Precision" << "Recall" << "F-score", FileList(keys)));  
71 -  
72 - int tpc = 0;  
73 - int fnc = 0;  
74 -  
75 - for (int i=0; i<counters.size(); i++) {  
76 - const QString &subject = keys[i];  
77 - const Counter &counter = counters[subject];  
78 - tpc += counter.truePositive;  
79 - fnc += counter.falseNegative;  
80 - const int count = counter.truePositive + counter.falseNegative;  
81 - const float precision = counter.truePositive / (float)(counter.truePositive + counter.falsePositive);  
82 - const float recall = counter.truePositive / (float)(counter.truePositive + counter.falseNegative);  
83 - const float fscore = 2 * precision * recall / (precision + recall);  
84 - output->setRelative(count, i, 0);  
85 - output->setRelative(precision, i, 1);  
86 - output->setRelative(recall, i, 2);  
87 - output->setRelative(fscore, i, 3);  
88 - }  
89 -  
90 - qDebug("Overall Accuracy = %f", (float)tpc / (float)(tpc + fnc));  
91 -}  
92 -  
93 -void br::EvalDetection(const QString &predictedInput, const QString &truthInput)  
94 -{  
95 - (void) predictedInput;  
96 - (void) truthInput;  
97 -}  
98 -  
99 -void br::EvalRegression(const QString &predictedInput, const QString &truthInput)  
100 -{  
101 - qDebug("Evaluating regression of %s against %s", qPrintable(predictedInput), qPrintable(truthInput));  
102 -  
103 - const TemplateList predicted(TemplateList::fromGallery(predictedInput));  
104 - const TemplateList truth(TemplateList::fromGallery(truthInput));  
105 - if (predicted.size() != truth.size()) qFatal("Input size mismatch.");  
106 -  
107 - float rmsError = 0;  
108 - QStringList truthValues, predictedValues;  
109 - for (int i=0; i<predicted.size(); i++) {  
110 - if (predicted[i].file.name != truth[i].file.name)  
111 - qFatal("Input order mismatch.");  
112 - rmsError += pow(predicted[i].file.get<float>("Subject")-truth[i].file.get<float>("Subject"), 2.f);  
113 - truthValues.append(QString::number(truth[i].file.get<float>("Subject")));  
114 - predictedValues.append(QString::number(predicted[i].file.get<float>("Subject")));  
115 - }  
116 -  
117 - QStringList rSource;  
118 - rSource << "# Load libraries" << "library(ggplot2)" << "" << "# Set Data"  
119 - << "Actual <- c(" + truthValues.join(",") + ")"  
120 - << "Predicted <- c(" + predictedValues.join(",") + ")"  
121 - << "data <- data.frame(Actual, Predicted)"  
122 - << "" << "# Construct Plot" << "pdf(\"EvalRegression.pdf\")"  
123 - << "print(qplot(Actual, Predicted, data=data, geom=\"jitter\", alpha=I(2/3)) + geom_abline(intercept=0, slope=1, color=\"forestgreen\", size=I(1)) + geom_smooth(size=I(1), color=\"mediumblue\") + theme_bw())"  
124 - << "print(qplot(Actual, Predicted-Actual, data=data, geom=\"jitter\", alpha=I(2/3)) + geom_abline(intercept=0, slope=0, color=\"forestgreen\", size=I(1)) + geom_smooth(size=I(1), color=\"mediumblue\") + theme_bw())"  
125 - << "dev.off()";  
126 -  
127 -  
128 - QString rFile = "EvalRegression.R";  
129 - QtUtils::writeFile(rFile, rSource);  
130 - bool success = QtUtils::runRScript(rFile);  
131 - if (success) QtUtils::showFile("EvalRegression.pdf");  
132 -  
133 - qDebug("RMS Error = %f", sqrt(rmsError/predicted.size()));  
134 -}  
openbr/core/cluster.cpp
@@ -280,7 +280,7 @@ void br::EvalClustering(const QString &amp;csv, const QString &amp;input) @@ -280,7 +280,7 @@ void br::EvalClustering(const QString &amp;csv, const QString &amp;input)
280 280
281 // We assume clustering algorithms store assigned cluster labels as integers (since the clusters are 281 // We assume clustering algorithms store assigned cluster labels as integers (since the clusters are
282 // not named). 282 // not named).
283 - QList<int> labels = TemplateList::fromGallery(input).files().get<int>("Subject"); 283 + QList<int> labels = File::get<int>(TemplateList::fromGallery(input), "Subject");
284 284
285 QHash<int, int> labelToIndex; 285 QHash<int, int> labelToIndex;
286 int nClusters = 0; 286 int nClusters = 0;
openbr/core/eval.cpp 0 → 100644
  1 +/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
  2 + * Copyright 2012 The MITRE Corporation *
  3 + * *
  4 + * Licensed under the Apache License, Version 2.0 (the "License"); *
  5 + * you may not use this file except in compliance with the License. *
  6 + * You may obtain a copy of the License at *
  7 + * *
  8 + * http://www.apache.org/licenses/LICENSE-2.0 *
  9 + * *
  10 + * Unless required by applicable law or agreed to in writing, software *
  11 + * distributed under the License is distributed on an "AS IS" BASIS, *
  12 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
  13 + * See the License for the specific language governing permissions and *
  14 + * limitations under the License. *
  15 + * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
  16 +
  17 +#include "bee.h"
  18 +#include "eval.h"
  19 +#include "openbr/core/qtutils.h"
  20 +
  21 +using namespace cv;
  22 +
  23 +namespace br
  24 +{
  25 +
  26 +struct Comparison
  27 +{
  28 + float score;
  29 + int target, query;
  30 + bool genuine;
  31 +
  32 + Comparison() {}
  33 + Comparison(float _score, int _target, int _query, bool _genuine)
  34 + : score(_score), target(_target), query(_query), genuine(_genuine) {}
  35 + inline bool operator<(const Comparison &other) const { return score > other.score; }
  36 +};
  37 +
  38 +#undef FAR // Windows preprecessor definition conflicts with variable name
  39 +struct OperatingPoint
  40 +{
  41 + float score, FAR, TAR;
  42 + OperatingPoint() {}
  43 + OperatingPoint(float _score, float _FAR, float _TAR)
  44 + : score(_score), FAR(_FAR), TAR(_TAR) {}
  45 +};
  46 +
  47 +static float getTAR(const QList<OperatingPoint> &operatingPoints, float FAR)
  48 +{
  49 + int index = 0;
  50 + while (operatingPoints[index].FAR < FAR) {
  51 + index++;
  52 + if (index == operatingPoints.size())
  53 + return 1;
  54 + }
  55 +
  56 + const float x1 = (index == 0 ? 0 : operatingPoints[index-1].FAR);
  57 + const float y1 = (index == 0 ? 0 : operatingPoints[index-1].TAR);
  58 + const float x2 = operatingPoints[index].FAR;
  59 + const float y2 = operatingPoints[index].TAR;
  60 + const float m = (y2 - y1) / (x2 - x1);
  61 + const float b = y1 - m*x1;
  62 + return m * FAR + b;
  63 +}
  64 +
  65 +float Evaluate(const cv::Mat &scores, const FileList &target, const FileList &query, const QString &csv, int partition)
  66 +{
  67 + return Evaluate(scores, BEE::makeMask(target, query, partition), csv);
  68 +}
  69 +
  70 +float Evaluate(const QString &simmat, const QString &mask, const QString &csv)
  71 +{
  72 + qDebug("Evaluating %s%s%s",
  73 + qPrintable(simmat),
  74 + mask.isEmpty() ? "" : qPrintable(" with " + mask),
  75 + csv.isEmpty() ? "" : qPrintable(" to " + csv));
  76 +
  77 + // Read similarity matrix
  78 + QString target, query;
  79 + const Mat scores = BEE::readSimmat(simmat, &target, &query);
  80 +
  81 + // Read mask matrix
  82 + Mat truth;
  83 + if (mask.isEmpty()) {
  84 + // Use the galleries specified in the similarity matrix
  85 + truth = BEE::makeMask(TemplateList::fromGallery(target).files(),
  86 + TemplateList::fromGallery(query).files());
  87 + } else {
  88 + File maskFile(mask);
  89 + maskFile.set("rows", scores.rows);
  90 + maskFile.set("columns", scores.cols);
  91 + truth = BEE::readMask(maskFile);
  92 + }
  93 +
  94 + return Evaluate(scores, truth, csv);
  95 +}
  96 +
  97 +float Evaluate(const Mat &simmat, const Mat &mask, const QString &csv)
  98 +{
  99 + if (simmat.size() != mask.size())
  100 + qFatal("Similarity matrix (%ix%i) differs in size from mask matrix (%ix%i).",
  101 + simmat.rows, simmat.cols, mask.rows, mask.cols);
  102 +
  103 + const int Max_Points = 500;
  104 + float result = -1;
  105 +
  106 + // Make comparisons
  107 + QList<Comparison> comparisons; comparisons.reserve(simmat.rows*simmat.cols);
  108 + int genuineCount = 0, impostorCount = 0, numNaNs = 0;
  109 + for (int i=0; i<simmat.rows; i++) {
  110 + for (int j=0; j<simmat.cols; j++) {
  111 + const BEE::Mask_t mask_val = mask.at<BEE::Mask_t>(i,j);
  112 + const BEE::Simmat_t simmat_val = simmat.at<BEE::Simmat_t>(i,j);
  113 + if (mask_val == BEE::DontCare) continue;
  114 + if (simmat_val != simmat_val) { numNaNs++; continue; }
  115 + comparisons.append(Comparison(simmat_val, j, i, mask_val == BEE::Match));
  116 + if (comparisons.last().genuine) genuineCount++;
  117 + else impostorCount++;
  118 + }
  119 + }
  120 +
  121 + if (numNaNs > 0) qWarning("Encountered %d NaN scores!", numNaNs);
  122 + if (genuineCount == 0) qFatal("No genuine scores!");
  123 + if (impostorCount == 0) qFatal("No impostor scores!");
  124 +
  125 + // Sort comparisons by simmat_val (score)
  126 + std::sort(comparisons.begin(), comparisons.end());
  127 +
  128 + QList<OperatingPoint> operatingPoints;
  129 + QList<float> genuines; genuines.reserve(sqrt((float)comparisons.size()));
  130 + QList<float> impostors; impostors.reserve(comparisons.size());
  131 + QVector<int> firstGenuineReturns(simmat.rows, 0);
  132 +
  133 + int falsePositives = 0, previousFalsePositives = 0;
  134 + int truePositives = 0, previousTruePositives = 0;
  135 + int index = 0;
  136 + float minGenuineScore = std::numeric_limits<float>::max();
  137 + float minImpostorScore = std::numeric_limits<float>::max();
  138 +
  139 + while (index < comparisons.size()) {
  140 + float thresh = comparisons[index].score;
  141 + // Compute genuine and imposter statistics at a threshold
  142 + while ((index < comparisons.size()) &&
  143 + (comparisons[index].score == thresh)) {
  144 + const Comparison &comparison = comparisons[index];
  145 + if (comparison.genuine) {
  146 + truePositives++;
  147 + genuines.append(comparison.score);
  148 + if (firstGenuineReturns[comparison.query] < 1)
  149 + firstGenuineReturns[comparison.query] = abs(firstGenuineReturns[comparison.query]) + 1;
  150 + if ((comparison.score != -std::numeric_limits<float>::max()) &&
  151 + (comparison.score < minGenuineScore))
  152 + minGenuineScore = comparison.score;
  153 + } else {
  154 + falsePositives++;
  155 + impostors.append(comparison.score);
  156 + if (firstGenuineReturns[comparison.query] < 1)
  157 + firstGenuineReturns[comparison.query]--;
  158 + if ((comparison.score != -std::numeric_limits<float>::max()) &&
  159 + (comparison.score < minImpostorScore))
  160 + minImpostorScore = comparison.score;
  161 + }
  162 + index++;
  163 + }
  164 +
  165 + if ((falsePositives > previousFalsePositives) &&
  166 + (truePositives > previousTruePositives)) {
  167 + // Restrict the extreme ends of the curve
  168 + if ((truePositives >= 10) && (falsePositives < impostorCount/2))
  169 + operatingPoints.append(OperatingPoint(thresh, float(falsePositives)/impostorCount, float(truePositives)/genuineCount));
  170 + previousFalsePositives = falsePositives;
  171 + previousTruePositives = truePositives;
  172 + }
  173 + }
  174 +
  175 + if (operatingPoints.size() == 0) operatingPoints.append(OperatingPoint(1, 1, 1));
  176 + if (operatingPoints.size() == 1) operatingPoints.prepend(OperatingPoint(0, 0, 0));
  177 + if (operatingPoints.size() > 2) operatingPoints.takeLast(); // Remove point (1,1)
  178 +
  179 + // Write Metadata table
  180 + QStringList lines;
  181 + lines.append("Plot,X,Y");
  182 + lines.append("Metadata,"+QString::number(simmat.cols)+",Gallery");
  183 + lines.append("Metadata,"+QString::number(simmat.rows)+",Probe");
  184 + lines.append("Metadata,"+QString::number(genuineCount)+",Genuine");
  185 + lines.append("Metadata,"+QString::number(impostorCount)+",Impostor");
  186 + lines.append("Metadata,"+QString::number(simmat.cols*simmat.rows-(genuineCount+impostorCount))+",Ignored");
  187 +
  188 + // Write Detection Error Tradeoff (DET), PRE, REC
  189 + int points = qMin(operatingPoints.size(), Max_Points);
  190 + for (int i=0; i<points; i++) {
  191 + const OperatingPoint &operatingPoint = operatingPoints[double(i) / double(points-1) * double(operatingPoints.size()-1)];
  192 + lines.append(QString("DET,%1,%2").arg(QString::number(operatingPoint.FAR),
  193 + QString::number(1-operatingPoint.TAR)));
  194 + lines.append(QString("FAR,%1,%2").arg(QString::number(operatingPoint.score),
  195 + QString::number(operatingPoint.FAR)));
  196 + lines.append(QString("FRR,%1,%2").arg(QString::number(operatingPoint.score),
  197 + QString::number(1-operatingPoint.TAR)));
  198 + }
  199 +
  200 + // Write FAR/TAR Bar Chart (BC)
  201 + lines.append(qPrintable(QString("BC,0.001,%1").arg(QString::number(getTAR(operatingPoints, 0.001), 'f', 3))));
  202 + lines.append(qPrintable(QString("BC,0.01,%1").arg(QString::number(result = getTAR(operatingPoints, 0.01), 'f', 3))));
  203 +
  204 + // Write SD & KDE
  205 + points = qMin(qMin(Max_Points, genuines.size()), impostors.size());
  206 + QList<double> sampledGenuineScores; sampledGenuineScores.reserve(points);
  207 + QList<double> sampledImpostorScores; sampledImpostorScores.reserve(points);
  208 +
  209 + if (points > 1) {
  210 + for (int i=0; i<points; i++) {
  211 + float genuineScore = genuines[double(i) / double(points-1) * double(genuines.size()-1)];
  212 + float impostorScore = impostors[double(i) / double(points-1) * double(impostors.size()-1)];
  213 + if (genuineScore == -std::numeric_limits<float>::max()) genuineScore = minGenuineScore;
  214 + if (impostorScore == -std::numeric_limits<float>::max()) impostorScore = minImpostorScore;
  215 + lines.append(QString("SD,%1,Genuine").arg(QString::number(genuineScore)));
  216 + lines.append(QString("SD,%1,Impostor").arg(QString::number(impostorScore)));
  217 + sampledGenuineScores.append(genuineScore);
  218 + sampledImpostorScores.append(impostorScore);
  219 + }
  220 + }
  221 +
  222 + // Write Cumulative Match Characteristic (CMC) curve
  223 + const int Max_Retrieval = 200;
  224 + const int Report_Retrieval = 5;
  225 +
  226 + float reportRetrievalRate = -1;
  227 + for (int i=1; i<=Max_Retrieval; i++) {
  228 + int realizedReturns = 0, possibleReturns = 0;
  229 + foreach (int firstGenuineReturn, firstGenuineReturns) {
  230 + if (firstGenuineReturn > 0) {
  231 + possibleReturns++;
  232 + if (firstGenuineReturn <= i) realizedReturns++;
  233 + }
  234 + }
  235 + const float retrievalRate = float(realizedReturns)/possibleReturns;
  236 + lines.append(qPrintable(QString("CMC,%1,%2").arg(QString::number(i), QString::number(retrievalRate))));
  237 + if (i == Report_Retrieval) reportRetrievalRate = retrievalRate;
  238 + }
  239 +
  240 + if (!csv.isEmpty()) QtUtils::writeFile(csv, lines);
  241 + qDebug("TAR @ FAR = 0.01: %.3f\nRetrieval Rate @ Rank = %d: %.3f", result, Report_Retrieval, reportRetrievalRate);
  242 + return result;
  243 +}
  244 +
  245 +// Helper struct for statistics accumulation
  246 +struct Counter
  247 +{
  248 + float truePositive, falsePositive, falseNegative;
  249 + Counter()
  250 + {
  251 + truePositive = 0;
  252 + falsePositive = 0;
  253 + falseNegative = 0;
  254 + }
  255 +};
  256 +
  257 +void EvalClassification(const QString &predictedInput, const QString &truthInput)
  258 +{
  259 + qDebug("Evaluating classification of %s against %s", qPrintable(predictedInput), qPrintable(truthInput));
  260 + TemplateList predicted(TemplateList::fromGallery(predictedInput));
  261 + TemplateList truth(TemplateList::fromGallery(truthInput));
  262 + if (predicted.size() != truth.size()) qFatal("Input size mismatch.");
  263 +
  264 + QHash<QString, Counter> counters;
  265 + for (int i=0; i<predicted.size(); i++) {
  266 + if (predicted[i].file.name != truth[i].file.name)
  267 + qFatal("Input order mismatch.");
  268 +
  269 + // Typically these lists will be of length one, but this generalization allows measuring multi-class labeling accuracy.
  270 + QString predictedSubject = predicted[i].file.get<QString>("Subject");
  271 + QString trueSubject = truth[i].file.get<QString>("Subject");
  272 +
  273 + QStringList predictedSubjects(predictedSubject);
  274 + QStringList trueSubjects(trueSubject);
  275 +
  276 + foreach (const QString &subject, trueSubjects.toVector() /* Hack to copy the list. */) {
  277 + if (predictedSubjects.contains(subject)) {
  278 + counters[subject].truePositive++;
  279 + trueSubjects.removeOne(subject);
  280 + predictedSubjects.removeOne(subject);
  281 + } else {
  282 + counters[subject].falseNegative++;
  283 + }
  284 + }
  285 +
  286 + for (int i=0; i<trueSubjects.size(); i++)
  287 + foreach (const QString &subject, predictedSubjects)
  288 + counters[subject].falsePositive += 1.f / predictedSubjects.size();
  289 + }
  290 +
  291 + const QStringList keys = counters.keys();
  292 + QSharedPointer<Output> output(Output::make("", FileList() << "Count" << "Precision" << "Recall" << "F-score", FileList(keys)));
  293 +
  294 + int tpc = 0;
  295 + int fnc = 0;
  296 +
  297 + for (int i=0; i<counters.size(); i++) {
  298 + const QString &subject = keys[i];
  299 + const Counter &counter = counters[subject];
  300 + tpc += counter.truePositive;
  301 + fnc += counter.falseNegative;
  302 + const int count = counter.truePositive + counter.falseNegative;
  303 + const float precision = counter.truePositive / (float)(counter.truePositive + counter.falsePositive);
  304 + const float recall = counter.truePositive / (float)(counter.truePositive + counter.falseNegative);
  305 + const float fscore = 2 * precision * recall / (precision + recall);
  306 + output->setRelative(count, i, 0);
  307 + output->setRelative(precision, i, 1);
  308 + output->setRelative(recall, i, 2);
  309 + output->setRelative(fscore, i, 3);
  310 + }
  311 +
  312 + qDebug("Overall Accuracy = %f", (float)tpc / (float)(tpc + fnc));
  313 +}
  314 +
  315 +struct Detection
  316 +{
  317 + QRectF boundingBox;
  318 + float confidence;
  319 +
  320 + Detection() {}
  321 + Detection(const QRectF &boundingBox_, float confidence_ = -1)
  322 + : boundingBox(boundingBox_), confidence(confidence_) {}
  323 +
  324 + float area() const { return boundingBox.width() * boundingBox.height(); }
  325 + float overlap(const Detection &other) const
  326 + {
  327 + const Detection intersection(boundingBox.intersected(other.boundingBox));
  328 + return intersection.area() / (area() + other.area() - 2*intersection.area());
  329 + }
  330 +};
  331 +
  332 +struct Detections
  333 +{
  334 + QList<Detection> predicted, truth;
  335 +};
  336 +
  337 +struct DetectionOperatingPoint
  338 +{
  339 + float confidence, overlap;
  340 + DetectionOperatingPoint() : confidence(-1), overlap(-1) {}
  341 + DetectionOperatingPoint(float confidence_, float overlap_) : confidence(confidence_), overlap(overlap_) {}
  342 + inline bool operator<(const DetectionOperatingPoint &other) const { return confidence > other.confidence; }
  343 +};
  344 +
  345 +float EvalDetection(const QString &predictedInput, const QString &truthInput, const QString &csv)
  346 +{
  347 + qDebug("Evaluating detection of %s against %s", qPrintable(predictedInput), qPrintable(truthInput));
  348 + const TemplateList predicted(TemplateList::fromGallery(predictedInput));
  349 + const TemplateList truth(TemplateList::fromGallery(truthInput));
  350 +
  351 + // Figure out which metadata field contains a bounding box
  352 + QString detectKey;
  353 + foreach (const QString &key, truth.first().file.localKeys())
  354 + if (!truth.first().file.get<QRectF>(key, QRectF()).isNull()) {
  355 + detectKey = key;
  356 + break;
  357 + }
  358 + if (detectKey.isNull()) qFatal("No suitable metadata key found.");
  359 + else qDebug("Using metadata key: %s", qPrintable(detectKey));
  360 +
  361 + QHash<QString, Detections> allDetections; // Organized by file
  362 + foreach (const Template &t, predicted)
  363 + allDetections[t.file.baseName()].predicted.append(Detection(t.file.get<QRectF>(detectKey), t.file.get<float>("Confidence", -1)));
  364 + foreach (const Template &t, truth)
  365 + allDetections[t.file.baseName()].truth.append(Detection(t.file.get<QRectF>(detectKey)));
  366 +
  367 + QList<DetectionOperatingPoint> points;
  368 + foreach (Detections detections, allDetections.values()) {
  369 + while (!detections.truth.isEmpty() && !detections.predicted.isEmpty()) {
  370 + Detection truth = detections.truth.takeFirst();
  371 + int bestIndex = -1;
  372 + float bestOverlap = -1;
  373 + for (int i=0; i<detections.predicted.size(); i++) {
  374 + const float overlap = truth.overlap(detections.predicted[i]);
  375 + if (overlap > bestOverlap) {
  376 + bestOverlap = overlap;
  377 + bestIndex = i;
  378 + }
  379 + }
  380 + Detection predicted = detections.predicted.takeAt(bestIndex);
  381 + points.append(DetectionOperatingPoint(predicted.confidence, bestOverlap));
  382 + }
  383 +
  384 + foreach (const Detection &detection, detections.predicted)
  385 + points.append(DetectionOperatingPoint(detection.confidence, 0));
  386 + for (int i=0; i<detections.truth.size(); i++)
  387 + points.append(DetectionOperatingPoint(-std::numeric_limits<float>::max(), 0));
  388 + }
  389 +
  390 + std::sort(points.begin(), points.end());
  391 +
  392 + QStringList lines;
  393 + lines.append("Plot, X, Y");
  394 +
  395 + // TODO: finish implementing
  396 +
  397 + (void) csv;
  398 + return 0;
  399 +}
  400 +
  401 +void EvalRegression(const QString &predictedInput, const QString &truthInput)
  402 +{
  403 + qDebug("Evaluating regression of %s against %s", qPrintable(predictedInput), qPrintable(truthInput));
  404 + const TemplateList predicted(TemplateList::fromGallery(predictedInput));
  405 + const TemplateList truth(TemplateList::fromGallery(truthInput));
  406 + if (predicted.size() != truth.size()) qFatal("Input size mismatch.");
  407 +
  408 + float rmsError = 0;
  409 + QStringList truthValues, predictedValues;
  410 + for (int i=0; i<predicted.size(); i++) {
  411 + if (predicted[i].file.name != truth[i].file.name)
  412 + qFatal("Input order mismatch.");
  413 + rmsError += pow(predicted[i].file.get<float>("Subject")-truth[i].file.get<float>("Subject"), 2.f);
  414 + truthValues.append(QString::number(truth[i].file.get<float>("Subject")));
  415 + predictedValues.append(QString::number(predicted[i].file.get<float>("Subject")));
  416 + }
  417 +
  418 + QStringList rSource;
  419 + rSource << "# Load libraries" << "library(ggplot2)" << "" << "# Set Data"
  420 + << "Actual <- c(" + truthValues.join(",") + ")"
  421 + << "Predicted <- c(" + predictedValues.join(",") + ")"
  422 + << "data <- data.frame(Actual, Predicted)"
  423 + << "" << "# Construct Plot" << "pdf(\"EvalRegression.pdf\")"
  424 + << "print(qplot(Actual, Predicted, data=data, geom=\"jitter\", alpha=I(2/3)) + geom_abline(intercept=0, slope=1, color=\"forestgreen\", size=I(1)) + geom_smooth(size=I(1), color=\"mediumblue\") + theme_bw())"
  425 + << "print(qplot(Actual, Predicted-Actual, data=data, geom=\"jitter\", alpha=I(2/3)) + geom_abline(intercept=0, slope=0, color=\"forestgreen\", size=I(1)) + geom_smooth(size=I(1), color=\"mediumblue\") + theme_bw())"
  426 + << "dev.off()";
  427 +
  428 +
  429 + QString rFile = "EvalRegression.R";
  430 + QtUtils::writeFile(rFile, rSource);
  431 + bool success = QtUtils::runRScript(rFile);
  432 + if (success) QtUtils::showFile("EvalRegression.pdf");
  433 +
  434 + qDebug("RMS Error = %f", sqrt(rmsError/predicted.size()));
  435 +}
  436 +
  437 +} // namespace br
openbr/core/classify.h renamed to openbr/core/eval.h
@@ -14,18 +14,22 @@ @@ -14,18 +14,22 @@
14 * limitations under the License. * 14 * limitations under the License. *
15 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 15 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
16 16
17 -#ifndef __CLASSIFY_H  
18 -#define __CLASSIFY_H 17 +#ifndef __EVAL_H
  18 +#define __EVAL_H
19 19
20 #include <QList> 20 #include <QList>
21 #include <QString> 21 #include <QString>
  22 +#include "openbr/openbr_plugin.h"
22 23
23 namespace br 24 namespace br
24 { 25 {
  26 + float Evaluate(const QString &simmat, const QString &mask = "", const QString &csv = ""); // Returns TAR @ FAR = 0.001
  27 + float Evaluate(const cv::Mat &scores, const FileList &target, const FileList &query, const QString &csv = "", int parition = 0);
  28 + float Evaluate(const cv::Mat &scores, const cv::Mat &masks, const QString &csv = "");
25 void EvalClassification(const QString &predictedInput, const QString &truthInput); 29 void EvalClassification(const QString &predictedInput, const QString &truthInput);
26 - void EvalDetection(const QString &predictedInput, const QString &truthInput); 30 + float EvalDetection(const QString &predictedInput, const QString &truthInput, const QString &csv = ""); // Return average overlap
27 void EvalRegression(const QString &predictedInput, const QString &truthInput); 31 void EvalRegression(const QString &predictedInput, const QString &truthInput);
28 } 32 }
29 33
30 -#endif // __CLASSIFY_H 34 +#endif // __EVAL_H
31 35
openbr/core/plot.cpp
@@ -14,57 +14,15 @@ @@ -14,57 +14,15 @@
14 * limitations under the License. * 14 * limitations under the License. *
15 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 15 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
16 16
17 -#include <QDebug>  
18 -#include <QDir>  
19 -#include <QFile>  
20 -#include <QFileInfo>  
21 -#include <QFuture>  
22 -#include <QList>  
23 -#include <QPair>  
24 -#include <QPointF>  
25 -#include <QRegExp>  
26 -#include <QSet>  
27 -#include <QStringList>  
28 -#include <QVector>  
29 -#include <QtAlgorithms>  
30 -#include <opencv2/core/core.hpp>  
31 -#include <assert.h>  
32 -  
33 #include "plot.h" 17 #include "plot.h"
34 #include "version.h" 18 #include "version.h"
35 -#include "openbr/core/bee.h"  
36 -#include "openbr/core/common.h"  
37 -#include "openbr/core/opencvutils.h"  
38 #include "openbr/core/qtutils.h" 19 #include "openbr/core/qtutils.h"
39 20
40 -#undef FAR // Windows preprecessor definition  
41 -  
42 using namespace cv; 21 using namespace cv;
43 22
44 namespace br 23 namespace br
45 { 24 {
46 25
47 -void Confusion(const QString &file, float score, int &true_positives, int &false_positives, int &true_negatives, int &false_negatives)  
48 -{  
49 - qDebug("Computing confusion matrix of %s at %f", qPrintable(file), score);  
50 -  
51 - QStringList lines = QtUtils::readLines(file);  
52 - true_positives = false_positives = true_negatives = false_negatives = 0;  
53 - foreach (const QString &line, lines) {  
54 - if (!line.startsWith("SD")) continue;  
55 - QStringList words = line.split(",");  
56 - bool ok;  
57 - float similarity = words[1].toFloat(&ok); assert(ok);  
58 - if (words[2] == "Genuine") {  
59 - if (similarity >= score) true_positives++;  
60 - else false_negatives++;  
61 - } else {  
62 - if (similarity >= score) false_positives++;  
63 - else true_negatives++;  
64 - }  
65 - }  
66 -}  
67 -  
68 static QStringList getPivots(const QString &file, bool headers) 26 static QStringList getPivots(const QString &file, bool headers)
69 { 27 {
70 QString str; 28 QString str;
@@ -73,224 +31,6 @@ static QStringList getPivots(const QString &amp;file, bool headers) @@ -73,224 +31,6 @@ static QStringList getPivots(const QString &amp;file, bool headers)
73 return str.split("_"); 31 return str.split("_");
74 } 32 }
75 33
76 -struct Comparison  
77 -{  
78 - float score;  
79 - int target, query;  
80 - bool genuine;  
81 -  
82 - Comparison() {}  
83 - Comparison(float _score, int _target, int _query, bool _genuine)  
84 - : score(_score), target(_target), query(_query), genuine(_genuine) {}  
85 - inline bool operator<(const Comparison &other) const { return score > other.score; }  
86 -};  
87 -  
88 -struct OperatingPoint  
89 -{  
90 - float score, FAR, TAR;  
91 - OperatingPoint() {}  
92 - OperatingPoint(float _score, float _FAR, float _TAR)  
93 - : score(_score), FAR(_FAR), TAR(_TAR) {}  
94 -};  
95 -  
96 -static float getTAR(const QList<OperatingPoint> &operatingPoints, float FAR)  
97 -{  
98 - int index = 0;  
99 - while (operatingPoints[index].FAR < FAR) {  
100 - index++;  
101 - if (index == operatingPoints.size())  
102 - return 1;  
103 - }  
104 -  
105 - const float x1 = (index == 0 ? 0 : operatingPoints[index-1].FAR);  
106 - const float y1 = (index == 0 ? 0 : operatingPoints[index-1].TAR);  
107 - const float x2 = operatingPoints[index].FAR;  
108 - const float y2 = operatingPoints[index].TAR;  
109 - const float m = (y2 - y1) / (x2 - x1);  
110 - const float b = y1 - m*x1;  
111 - return m * FAR + b;  
112 -}  
113 -  
114 -float Evaluate(const cv::Mat &scores, const FileList &target, const FileList &query, const QString &csv, int partition)  
115 -{  
116 - return Evaluate(scores, BEE::makeMask(target, query, partition), csv);  
117 -}  
118 -  
119 -float Evaluate(const QString &simmat, const QString &mask, const QString &csv)  
120 -{  
121 - qDebug("Evaluating %s%s%s",  
122 - qPrintable(simmat),  
123 - mask.isEmpty() ? "" : qPrintable(" with " + mask),  
124 - csv.isEmpty() ? "" : qPrintable(" to " + csv));  
125 -  
126 - // Read similarity matrix  
127 - QString target, query;  
128 - const Mat scores = BEE::readSimmat(simmat, &target, &query);  
129 -  
130 - // Read mask matrix  
131 - Mat truth;  
132 - if (mask.isEmpty()) {  
133 - // Use the galleries specified in the similarity matrix  
134 - truth = BEE::makeMask(TemplateList::fromGallery(target).files(),  
135 - TemplateList::fromGallery(query).files());  
136 - } else {  
137 - File maskFile(mask);  
138 - maskFile.set("rows", scores.rows);  
139 - maskFile.set("columns", scores.cols);  
140 - truth = BEE::readMask(maskFile);  
141 - }  
142 -  
143 - return Evaluate(scores, truth, csv);  
144 -}  
145 -  
146 -float Evaluate(const Mat &simmat, const Mat &mask, const QString &csv)  
147 -{  
148 - if (simmat.size() != mask.size())  
149 - qFatal("Similarity matrix (%ix%i) differs in size from mask matrix (%ix%i).",  
150 - simmat.rows, simmat.cols, mask.rows, mask.cols);  
151 -  
152 - const int Max_Points = 500;  
153 - float result = -1;  
154 -  
155 - // Make comparisons  
156 - QList<Comparison> comparisons; comparisons.reserve(simmat.rows*simmat.cols);  
157 - int genuineCount = 0, impostorCount = 0, numNaNs = 0;  
158 - for (int i=0; i<simmat.rows; i++) {  
159 - for (int j=0; j<simmat.cols; j++) {  
160 - const BEE::Mask_t mask_val = mask.at<BEE::Mask_t>(i,j);  
161 - const BEE::Simmat_t simmat_val = simmat.at<BEE::Simmat_t>(i,j);  
162 - if (mask_val == BEE::DontCare) continue;  
163 - if (simmat_val != simmat_val) { numNaNs++; continue; }  
164 - comparisons.append(Comparison(simmat_val, j, i, mask_val == BEE::Match));  
165 - if (comparisons.last().genuine) genuineCount++;  
166 - else impostorCount++;  
167 - }  
168 - }  
169 -  
170 - if (numNaNs > 0) qWarning("Encountered %d NaN scores!", numNaNs);  
171 - if (genuineCount == 0) qFatal("No genuine scores!");  
172 - if (impostorCount == 0) qFatal("No impostor scores!");  
173 -  
174 - // Sort comparisons by simmat_val (score)  
175 - std::sort(comparisons.begin(), comparisons.end());  
176 -  
177 - QList<OperatingPoint> operatingPoints;  
178 - QList<float> genuines; genuines.reserve(sqrt((float)comparisons.size()));  
179 - QList<float> impostors; impostors.reserve(comparisons.size());  
180 - QVector<int> firstGenuineReturns(simmat.rows, 0);  
181 -  
182 - int falsePositives = 0, previousFalsePositives = 0;  
183 - int truePositives = 0, previousTruePositives = 0;  
184 - int index = 0;  
185 - float minGenuineScore = std::numeric_limits<float>::max();  
186 - float minImpostorScore = std::numeric_limits<float>::max();  
187 -  
188 - while (index < comparisons.size()) {  
189 - float thresh = comparisons[index].score;  
190 - // Compute genuine and imposter statistics at a threshold  
191 - while ((index < comparisons.size()) &&  
192 - (comparisons[index].score == thresh)) {  
193 - const Comparison &comparison = comparisons[index];  
194 - if (comparison.genuine) {  
195 - truePositives++;  
196 - genuines.append(comparison.score);  
197 - if (firstGenuineReturns[comparison.query] < 1)  
198 - firstGenuineReturns[comparison.query] = abs(firstGenuineReturns[comparison.query]) + 1;  
199 - if ((comparison.score != -std::numeric_limits<float>::max()) &&  
200 - (comparison.score < minGenuineScore))  
201 - minGenuineScore = comparison.score;  
202 - } else {  
203 - falsePositives++;  
204 - impostors.append(comparison.score);  
205 - if (firstGenuineReturns[comparison.query] < 1)  
206 - firstGenuineReturns[comparison.query]--;  
207 - if ((comparison.score != -std::numeric_limits<float>::max()) &&  
208 - (comparison.score < minImpostorScore))  
209 - minImpostorScore = comparison.score;  
210 - }  
211 - index++;  
212 - }  
213 -  
214 - if ((falsePositives > previousFalsePositives) &&  
215 - (truePositives > previousTruePositives)) {  
216 - // Restrict the extreme ends of the curve  
217 - if ((truePositives >= 10) && (falsePositives < impostorCount/2))  
218 - operatingPoints.append(OperatingPoint(thresh, float(falsePositives)/impostorCount, float(truePositives)/genuineCount));  
219 - previousFalsePositives = falsePositives;  
220 - previousTruePositives = truePositives;  
221 - }  
222 - }  
223 -  
224 - if (operatingPoints.size() == 0) operatingPoints.append(OperatingPoint(1, 1, 1));  
225 - if (operatingPoints.size() == 1) operatingPoints.prepend(OperatingPoint(0, 0, 0));  
226 - if (operatingPoints.size() > 2) operatingPoints.takeLast(); // Remove point (1,1)  
227 -  
228 - // Write Metadata table  
229 - QStringList lines;  
230 - lines.append("Plot,X,Y");  
231 - lines.append("Metadata,"+QString::number(simmat.cols)+",Gallery");  
232 - lines.append("Metadata,"+QString::number(simmat.rows)+",Probe");  
233 - lines.append("Metadata,"+QString::number(genuineCount)+",Genuine");  
234 - lines.append("Metadata,"+QString::number(impostorCount)+",Impostor");  
235 - lines.append("Metadata,"+QString::number(simmat.cols*simmat.rows-(genuineCount+impostorCount))+",Ignored");  
236 -  
237 - // Write Detection Error Tradeoff (DET), PRE, REC  
238 - int points = qMin(operatingPoints.size(), Max_Points);  
239 - for (int i=0; i<points; i++) {  
240 - const OperatingPoint &operatingPoint = operatingPoints[double(i) / double(points-1) * double(operatingPoints.size()-1)];  
241 - lines.append(QString("DET,%1,%2").arg(QString::number(operatingPoint.FAR),  
242 - QString::number(1-operatingPoint.TAR)));  
243 - lines.append(QString("FAR,%1,%2").arg(QString::number(operatingPoint.score),  
244 - QString::number(operatingPoint.FAR)));  
245 - lines.append(QString("FRR,%1,%2").arg(QString::number(operatingPoint.score),  
246 - QString::number(1-operatingPoint.TAR)));  
247 - }  
248 -  
249 - // Write FAR/TAR Bar Chart (BC)  
250 - lines.append(qPrintable(QString("BC,0.001,%1").arg(QString::number(getTAR(operatingPoints, 0.001), 'f', 3))));  
251 - lines.append(qPrintable(QString("BC,0.01,%1").arg(QString::number(result = getTAR(operatingPoints, 0.01), 'f', 3))));  
252 -  
253 - // Write SD & KDE  
254 - points = qMin(qMin(Max_Points, genuines.size()), impostors.size());  
255 - QList<double> sampledGenuineScores; sampledGenuineScores.reserve(points);  
256 - QList<double> sampledImpostorScores; sampledImpostorScores.reserve(points);  
257 -  
258 - if (points > 1) {  
259 - for (int i=0; i<points; i++) {  
260 - float genuineScore = genuines[double(i) / double(points-1) * double(genuines.size()-1)];  
261 - float impostorScore = impostors[double(i) / double(points-1) * double(impostors.size()-1)];  
262 - if (genuineScore == -std::numeric_limits<float>::max()) genuineScore = minGenuineScore;  
263 - if (impostorScore == -std::numeric_limits<float>::max()) impostorScore = minImpostorScore;  
264 - lines.append(QString("SD,%1,Genuine").arg(QString::number(genuineScore)));  
265 - lines.append(QString("SD,%1,Impostor").arg(QString::number(impostorScore)));  
266 - sampledGenuineScores.append(genuineScore);  
267 - sampledImpostorScores.append(impostorScore);  
268 - }  
269 - }  
270 -  
271 - // Write Cumulative Match Characteristic (CMC) curve  
272 - const int Max_Retrieval = 200;  
273 - const int Report_Retrieval = 5;  
274 -  
275 - float reportRetrievalRate = -1;  
276 - for (int i=1; i<=Max_Retrieval; i++) {  
277 - int realizedReturns = 0, possibleReturns = 0;  
278 - foreach (int firstGenuineReturn, firstGenuineReturns) {  
279 - if (firstGenuineReturn > 0) {  
280 - possibleReturns++;  
281 - if (firstGenuineReturn <= i) realizedReturns++;  
282 - }  
283 - }  
284 - const float retrievalRate = float(realizedReturns)/possibleReturns;  
285 - lines.append(qPrintable(QString("CMC,%1,%2").arg(QString::number(i), QString::number(retrievalRate))));  
286 - if (i == Report_Retrieval) reportRetrievalRate = retrievalRate;  
287 - }  
288 -  
289 - if (!csv.isEmpty()) QtUtils::writeFile(csv, lines);  
290 - qDebug("TAR @ FAR = 0.01: %.3f\nRetrieval Rate @ Rank = %d: %.3f", result, Report_Retrieval, reportRetrievalRate);  
291 - return result;  
292 -}  
293 -  
294 static QString getScale(const QString &mode, const QString &title, int vals) 34 static QString getScale(const QString &mode, const QString &title, int vals)
295 { 35 {
296 if (vals > 12) return " + scale_"+mode+"_discrete(\""+title+"\")"; 36 if (vals > 12) return " + scale_"+mode+"_discrete(\""+title+"\")";
@@ -474,7 +214,6 @@ struct RPlot @@ -474,7 +214,6 @@ struct RPlot
474 }; 214 };
475 215
476 // Does not work if dataset folder starts with a number 216 // Does not work if dataset folder starts with a number
477 -  
478 bool Plot(const QStringList &files, const br::File &destination, bool show) 217 bool Plot(const QStringList &files, const br::File &destination, bool show)
479 { 218 {
480 qDebug("Plotting %d file(s) to %s", files.size(), qPrintable(destination)); 219 qDebug("Plotting %d file(s) to %s", files.size(), qPrintable(destination));
openbr/core/plot.h
@@ -24,14 +24,8 @@ @@ -24,14 +24,8 @@
24 24
25 namespace br 25 namespace br
26 { 26 {
27 -  
28 -void Confusion(const QString &file, float score, int &true_positives, int &false_positives, int &true_negatives, int &false_negatives);  
29 -float Evaluate(const QString &simmat, const QString &mask = "", const QString &csv = ""); // Returns TAR @ FAR = 0.001  
30 -float Evaluate(const cv::Mat &scores, const FileList &target, const FileList &query, const QString &csv = "", int parition = 0);  
31 -float Evaluate(const cv::Mat &scores, const cv::Mat &masks, const QString &csv = "");  
32 -bool Plot(const QStringList &files, const br::File &destination, bool show = false);  
33 -bool PlotMetadata(const QStringList &files, const QString &destination, bool show = false);  
34 - 27 + bool Plot(const QStringList &files, const br::File &destination, bool show = false);
  28 + bool PlotMetadata(const QStringList &files, const QString &destination, bool show = false);
35 } 29 }
36 30
37 #endif // __PLOT_H 31 #endif // __PLOT_H
openbr/openbr.cpp
@@ -17,8 +17,8 @@ @@ -17,8 +17,8 @@
17 #include <openbr/openbr_plugin.h> 17 #include <openbr/openbr_plugin.h>
18 18
19 #include "core/bee.h" 19 #include "core/bee.h"
20 -#include "core/classify.h"  
21 #include "core/cluster.h" 20 #include "core/cluster.h"
  21 +#include "core/eval.h"
22 #include "core/fuse.h" 22 #include "core/fuse.h"
23 #include "core/plot.h" 23 #include "core/plot.h"
24 #include "core/qtutils.h" 24 #include "core/qtutils.h"
@@ -51,11 +51,6 @@ void br_compare(const char *target_gallery, const char *query_gallery, const cha @@ -51,11 +51,6 @@ void br_compare(const char *target_gallery, const char *query_gallery, const cha
51 Compare(File(target_gallery), File(query_gallery), File(output)); 51 Compare(File(target_gallery), File(query_gallery), File(output));
52 } 52 }
53 53
54 -void br_confusion(const char *file, float score, int *true_positives, int *false_positives, int *true_negatives, int *false_negatives)  
55 -{  
56 - return Confusion(file, score, *true_positives, *false_positives, *true_negatives, *false_negatives);  
57 -}  
58 -  
59 void br_convert(const char *file_type, const char *input_file, const char *output_file) 54 void br_convert(const char *file_type, const char *input_file, const char *output_file)
60 { 55 {
61 Convert(File(file_type), File(input_file), File(output_file)); 56 Convert(File(file_type), File(input_file), File(output_file));
openbr/openbr.h
@@ -115,20 +115,6 @@ BR_EXPORT void br_combine_masks(int num_input_masks, const char *input_masks[], @@ -115,20 +115,6 @@ BR_EXPORT void br_combine_masks(int num_input_masks, const char *input_masks[],
115 BR_EXPORT void br_compare(const char *target_gallery, const char *query_gallery, const char *output = ""); 115 BR_EXPORT void br_compare(const char *target_gallery, const char *query_gallery, const char *output = "");
116 116
117 /*! 117 /*!
118 - * \brief Computes the confusion matrix for a dataset at a particular threshold.  
119 - *  
120 - * <a href="http://en.wikipedia.org/wiki/Confusion_matrix">Wikipedia Explanation</a>  
121 - * \param file <tt>.csv</tt> file created using \ref br_eval.  
122 - * \param score The similarity score to threshold at.  
123 - * \param[out] true_positives The true positive count.  
124 - * \param[out] false_positives The false positive count.  
125 - * \param[out] true_negatives The true negative count.  
126 - * \param[out] false_negatives The false negative count.  
127 - */  
128 -BR_EXPORT void br_confusion(const char *file, float score,  
129 - int *true_positives, int *false_positives, int *true_negatives, int *false_negatives);  
130 -  
131 -/*!  
132 * \brief Wraps br::Convert() 118 * \brief Wraps br::Convert()
133 */ 119 */
134 BR_EXPORT void br_convert(const char *file_type, const char *input_file, const char *output_file); 120 BR_EXPORT void br_convert(const char *file_type, const char *input_file, const char *output_file);
openbr/openbr_plugin.cpp
@@ -436,7 +436,7 @@ TemplateList TemplateList::fromGallery(const br::File &amp;gallery) @@ -436,7 +436,7 @@ TemplateList TemplateList::fromGallery(const br::File &amp;gallery)
436 // stores the index values in "Label" of the output template list 436 // stores the index values in "Label" of the output template list
437 TemplateList TemplateList::relabel(const TemplateList &tl, const QString & propName) 437 TemplateList TemplateList::relabel(const TemplateList &tl, const QString & propName)
438 { 438 {
439 - const QList<QString> originalLabels = tl.get<QString>(propName); 439 + const QList<QString> originalLabels = File::get<QString>(tl, propName);
440 QHash<QString,int> labelTable; 440 QHash<QString,int> labelTable;
441 foreach (const QString & label, originalLabels) 441 foreach (const QString & label, originalLabels)
442 if (!labelTable.contains(label)) 442 if (!labelTable.contains(label))
@@ -464,7 +464,7 @@ QList&lt;int&gt; TemplateList::indexProperty(const QString &amp; propName, QHash&lt;QString, @@ -464,7 +464,7 @@ QList&lt;int&gt; TemplateList::indexProperty(const QString &amp; propName, QHash&lt;QString,
464 valueMap.clear(); 464 valueMap.clear();
465 reverseLookup.clear(); 465 reverseLookup.clear();
466 466
467 - const QList<QVariant> originalLabels = values(propName); 467 + const QList<QVariant> originalLabels = File::values(*this, propName);
468 foreach (const QVariant & label, originalLabels) { 468 foreach (const QVariant & label, originalLabels) {
469 QString labelString = label.toString(); 469 QString labelString = label.toString();
470 if (!valueMap.contains(labelString)) { 470 if (!valueMap.contains(labelString)) {
@@ -481,9 +481,9 @@ QList&lt;int&gt; TemplateList::indexProperty(const QString &amp; propName, QHash&lt;QString, @@ -481,9 +481,9 @@ QList&lt;int&gt; TemplateList::indexProperty(const QString &amp; propName, QHash&lt;QString,
481 } 481 }
482 482
483 // uses -1 for missing values 483 // uses -1 for missing values
484 -QList<int> TemplateList::applyIndex(const QString & propName, const QHash<QString, int> & valueMap) const 484 +QList<int> TemplateList::applyIndex(const QString &propName, const QHash<QString, int> &valueMap) const
485 { 485 {
486 - const QList<QString> originalLabels = get<QString>(propName); 486 + const QList<QString> originalLabels = File::get<QString>(*this, propName);
487 487
488 QList<int> result; 488 QList<int> result;
489 for (int i=0; i<originalLabels.size(); i++) { 489 for (int i=0; i<originalLabels.size(); i++) {
openbr/openbr_plugin.h
@@ -228,7 +228,20 @@ struct BR_EXPORT File @@ -228,7 +228,20 @@ struct BR_EXPORT File
228 return variant.value<T>(); 228 return variant.value<T>();
229 } 229 }
230 230
231 - /*!< \brief Returns a list of type T for the key, throwing an error if the key does not exist or if the value cannot be converted to the specified type. */ 231 + /*!< \brief Returns a value for the key, returning \em defaultValue if the key does not exist or can't be converted. */
  232 + template <typename T>
  233 + T get(const QString &key, const T &defaultValue) const
  234 + {
  235 + if (!contains(key)) return defaultValue;
  236 + QVariant variant = value(key);
  237 + if (!variant.canConvert<T>()) return defaultValue;
  238 + return variant.value<T>();
  239 + }
  240 +
  241 + /*!< \brief Specialization for boolean type. */
  242 + bool getBool(const QString &key, bool defaultValue = false) const;
  243 +
  244 + /*!< \brief Specialization for list type. Returns a list of type T for the key, throwing an error if the key does not exist or if the value cannot be converted to the specified type. */
232 template <typename T> 245 template <typename T>
233 QList<T> getList(const QString &key) const 246 QList<T> getList(const QString &key) const
234 { 247 {
@@ -241,17 +254,31 @@ struct BR_EXPORT File @@ -241,17 +254,31 @@ struct BR_EXPORT File
241 return list; 254 return list;
242 } 255 }
243 256
244 - /*!< \brief Specialization for boolean type. */  
245 - bool getBool(const QString &key, bool defaultValue = false) const; 257 + /*!< \brief Returns the value for the specified key for every file in the list. */
  258 + template<class U>
  259 + static QList<QVariant> values(const QList<U> &fileList, const QString &key)
  260 + {
  261 + QList<QVariant> values; values.reserve(fileList.size());
  262 + foreach (const U &f, fileList) values.append(((const File&)f).value(key));
  263 + return values;
  264 + }
246 265
247 - /*!< \brief Returns a value for the key, returning \em defaultValue if the key does not exist or can't be converted. */  
248 - template <typename T>  
249 - T get(const QString &key, const T &defaultValue) const 266 + /*!< \brief Returns a value for the key for every file in the list, throwing an error if the key does not exist. */
  267 + template<class T, class U>
  268 + static QList<T> get(const QList<U> &fileList, const QString &key)
250 { 269 {
251 - if (!contains(key)) return defaultValue;  
252 - QVariant variant = value(key);  
253 - if (!variant.canConvert<T>()) return defaultValue;  
254 - return variant.value<T>(); 270 + QList<T> result; result.reserve(fileList.size());
  271 + foreach (const U &f, fileList) result.append(((const File&)f).get<T>(key));
  272 + return result;
  273 + }
  274 +
  275 + /*!< \brief Returns a value for the key for every file in the list, returning \em defaultValue if the key does not exist or can't be converted. */
  276 + template<class T, class U>
  277 + static QList<T> get(const QList<U> &fileList, const QString &key, const T &defaultValue)
  278 + {
  279 + QList<T> result; result.reserve(fileList.size());
  280 + foreach (const U &f, fileList) result.append(static_cast<const File&>(f).get<T>(key, defaultValue));
  281 + return result;
255 } 282 }
256 283
257 inline bool failed() const { return getBool("FTE") || getBool("FTO"); } /*!< \brief Returns \c true if the file failed to open or enroll, \c false otherwise. */ 284 inline bool failed() const { return getBool("FTE") || getBool("FTO"); } /*!< \brief Returns \c true if the file failed to open or enroll, \c false otherwise. */
@@ -297,23 +324,6 @@ struct BR_EXPORT FileList : public QList&lt;File&gt; @@ -297,23 +324,6 @@ struct BR_EXPORT FileList : public QList&lt;File&gt;
297 QStringList flat() const; /*!< \brief Returns br::File::flat() for each file in the list. */ 324 QStringList flat() const; /*!< \brief Returns br::File::flat() for each file in the list. */
298 QStringList names() const; /*!< \brief Returns #br::File::name for each file in the list. */ 325 QStringList names() const; /*!< \brief Returns #br::File::name for each file in the list. */
299 void sort(const QString& key); /*!< \brief Sort the list based on metadata. */ 326 void sort(const QString& key); /*!< \brief Sort the list based on metadata. */
300 - /*!< \brief Returns values associated with the input propName for each file in the list. */  
301 - template<typename T>  
302 - QList<T> get(const QString & propName) const  
303 - {  
304 - QList<T> values; values.reserve(size());  
305 - foreach (const File &f, *this)  
306 - values.append(f.get<T>(propName));  
307 - return values;  
308 - }  
309 - template<typename T>  
310 - QList<T> get(const QString & propName, T defaultValue) const  
311 - {  
312 - QList<T> values; values.reserve(size());  
313 - foreach (const File &f, *this)  
314 - values.append(f.contains(propName) ? f.get<T>(propName) : defaultValue);  
315 - return values;  
316 - }  
317 327
318 QList<int> crossValidationPartitions() const; /*!< \brief Returns the cross-validation partition (default=0) for each file in the list. */ 328 QList<int> crossValidationPartitions() const; /*!< \brief Returns the cross-validation partition (default=0) for each file in the list. */
319 int failures() const; /*!< \brief Returns the number of files with br::File::failed(). */ 329 int failures() const; /*!< \brief Returns the number of files with br::File::failed(). */
@@ -344,6 +354,7 @@ struct Template : public QList&lt;cv::Mat&gt; @@ -344,6 +354,7 @@ struct Template : public QList&lt;cv::Mat&gt;
344 inline const cv::Mat &m() const { static const cv::Mat NullMatrix; 354 inline const cv::Mat &m() const { static const cv::Mat NullMatrix;
345 return isEmpty() ? qFatal("Empty template."), NullMatrix : last(); } /*!< \brief Idiom to treat the template as a matrix. */ 355 return isEmpty() ? qFatal("Empty template."), NullMatrix : last(); } /*!< \brief Idiom to treat the template as a matrix. */
346 inline cv::Mat &m() { return isEmpty() ? append(cv::Mat()), last() : last(); } /*!< \brief Idiom to treat the template as a matrix. */ 356 inline cv::Mat &m() { return isEmpty() ? append(cv::Mat()), last() : last(); } /*!< \brief Idiom to treat the template as a matrix. */
  357 + inline const File &operator()() const { return file; }
347 inline cv::Mat &operator=(const cv::Mat &other) { return m() = other; } /*!< \brief Idiom to treat the template as a matrix. */ 358 inline cv::Mat &operator=(const cv::Mat &other) { return m() = other; } /*!< \brief Idiom to treat the template as a matrix. */
348 inline operator const cv::Mat&() const { return m(); } /*!< \brief Idiom to treat the template as a matrix. */ 359 inline operator const cv::Mat&() const { return m(); } /*!< \brief Idiom to treat the template as a matrix. */
349 inline operator cv::Mat&() { return m(); } /*!< \brief Idiom to treat the template as a matrix. */ 360 inline operator cv::Mat&() { return m(); } /*!< \brief Idiom to treat the template as a matrix. */
@@ -406,7 +417,6 @@ struct TemplateList : public QList&lt;Template&gt; @@ -406,7 +417,6 @@ struct TemplateList : public QList&lt;Template&gt;
406 QList<int> indexProperty(const QString & propName, QHash<QString, int> & valueMap, QHash<int, QVariant> & reverseLookup) const; 417 QList<int> indexProperty(const QString & propName, QHash<QString, int> & valueMap, QHash<int, QVariant> & reverseLookup) const;
407 QList<int> applyIndex(const QString & propName, const QHash<QString, int> & valueMap) const; 418 QList<int> applyIndex(const QString & propName, const QHash<QString, int> & valueMap) const;
408 419
409 -  
410 /*! 420 /*!
411 * \brief Returns the total number of bytes in all the templates. 421 * \brief Returns the total number of bytes in all the templates.
412 */ 422 */
@@ -477,23 +487,6 @@ struct TemplateList : public QList&lt;Template&gt; @@ -477,23 +487,6 @@ struct TemplateList : public QList&lt;Template&gt;
477 FileList operator()() const { return files(); } 487 FileList operator()() const { return files(); }
478 488
479 /*! 489 /*!
480 - * \brief Returns br::Template::label() for each template in the list.  
481 - */  
482 - template<typename T>  
483 - QList<T> get(const QString & propName) const  
484 - {  
485 - QList<T> values; values.reserve(size());  
486 - foreach (const Template &t, *this) values.append(t.file.get<T>(propName));  
487 - return values;  
488 - }  
489 - QList<QVariant> values(const QString & propName) const  
490 - {  
491 - QList<QVariant> values; values.reserve(size());  
492 - foreach (const Template &t, *this) values.append(t.file.value(propName));  
493 - return values;  
494 - }  
495 -  
496 - /*!  
497 * \brief Returns the number of occurences for each label in the list. 490 * \brief Returns the number of occurences for each label in the list.
498 */ 491 */
499 template<typename T> 492 template<typename T>
@@ -814,7 +807,8 @@ struct Factory @@ -814,7 +807,8 @@ struct Factory
814 //! [Factory make] 807 //! [Factory make]
815 static T *make(const File &file) 808 static T *make(const File &file)
816 { 809 {
817 - QString name = file.suffix(); 810 + QString name = file.get<QString>("plugin", "");
  811 + if (name.isEmpty()) name = file.suffix();
818 if (!names().contains(name)) { 812 if (!names().contains(name)) {
819 if (names().contains("Empty") && name.isEmpty()) name = "Empty"; 813 if (names().contains("Empty") && name.isEmpty()) name = "Empty";
820 else if (names().contains("Default")) name = "Default"; 814 else if (names().contains("Default")) name = "Default";
@@ -1024,7 +1018,7 @@ public: @@ -1024,7 +1018,7 @@ public:
1024 virtual TemplateList readBlock(bool *done) = 0; /*!< \brief Retrieve a portion of the stored templates. */ 1018 virtual TemplateList readBlock(bool *done) = 0; /*!< \brief Retrieve a portion of the stored templates. */
1025 void writeBlock(const TemplateList &templates); /*!< \brief Serialize a template list. */ 1019 void writeBlock(const TemplateList &templates); /*!< \brief Serialize a template list. */
1026 virtual void write(const Template &t) = 0; /*!< \brief Serialize a template. */ 1020 virtual void write(const Template &t) = 0; /*!< \brief Serialize a template. */
1027 - static Gallery *make(const File &file); /*!< \brief Make a gallery from a file list. */ 1021 + static Gallery *make(const File &file); /*!< \brief Make a gallery to/from a file on disk. */
1028 1022
1029 private: 1023 private:
1030 QSharedPointer<Gallery> next; 1024 QSharedPointer<Gallery> next;
openbr/plugins/eigen3.cpp
@@ -348,7 +348,7 @@ class LDATransform : public Transform @@ -348,7 +348,7 @@ class LDATransform : public Transform
348 348
349 // OpenBR ensures that class values range from 0 to numClasses-1. 349 // OpenBR ensures that class values range from 0 to numClasses-1.
350 // Label exists because we created it earlier with relabel 350 // Label exists because we created it earlier with relabel
351 - QList<int> classes = trainingSet.get<int>("Label"); 351 + QList<int> classes = File::get<int>(trainingSet, "Label");
352 QMap<int, int> classCounts = trainingSet.countValues<int>("Label"); 352 QMap<int, int> classCounts = trainingSet.countValues<int>("Label");
353 const int numClasses = classCounts.size(); 353 const int numClasses = classCounts.size();
354 354
openbr/plugins/gallery.cpp
@@ -881,6 +881,45 @@ class statGallery : public Gallery @@ -881,6 +881,45 @@ class statGallery : public Gallery
881 881
882 BR_REGISTER(Gallery, statGallery) 882 BR_REGISTER(Gallery, statGallery)
883 883
  884 +/*!
  885 + * \ingroup galleries
  886 + * \brief Implements the FDDB detection format.
  887 + * \author Josh Klontz \cite jklontz
  888 + *
  889 + * http://vis-www.cs.umass.edu/fddb/README.txt
  890 + */
  891 +class FDDBGallery : public Gallery
  892 +{
  893 + Q_OBJECT
  894 +
  895 + TemplateList readBlock(bool *done)
  896 + {
  897 + *done = true;
  898 + QStringList lines = QtUtils::readLines(file);
  899 + TemplateList templates;
  900 + while (!lines.empty()) {
  901 + const QString fileName = lines.takeFirst();
  902 + int numDetects = lines.takeFirst().toInt();
  903 + for (int i=0; i<numDetects; i++) {
  904 + const QStringList detect = lines.takeFirst().split(' ');
  905 + Template t(fileName);
  906 + t.file.set("Face", QRectF(detect[0].toFloat(), detect[1].toFloat(), detect[2].toFloat(), detect[3].toFloat()));
  907 + t.file.set("Confidence", detect[4].toFloat());
  908 + templates.append(t);
  909 + }
  910 + }
  911 + return templates;
  912 + }
  913 +
  914 + void write(const Template &t)
  915 + {
  916 + (void) t;
  917 + qFatal("Not implemented.");
  918 + }
  919 +};
  920 +
  921 +BR_REGISTER(Gallery, FDDBGallery)
  922 +
884 } // namespace br 923 } // namespace br
885 924
886 #include "gallery.moc" 925 #include "gallery.moc"
openbr/plugins/independent.cpp
@@ -20,7 +20,7 @@ static TemplateList Downsample(const TemplateList &amp;templates, const Transform *t @@ -20,7 +20,7 @@ static TemplateList Downsample(const TemplateList &amp;templates, const Transform *t
20 const bool atLeast = transform->instances < 0; 20 const bool atLeast = transform->instances < 0;
21 const int instances = abs(transform->instances); 21 const int instances = abs(transform->instances);
22 22
23 - QList<QString> allLabels = templates.get<QString>("Subject"); 23 + QList<QString> allLabels = File::get<QString>(templates, "Subject");
24 QList<QString> uniqueLabels = allLabels.toSet().toList(); 24 QList<QString> uniqueLabels = allLabels.toSet().toList();
25 qSort(uniqueLabels); 25 qSort(uniqueLabels);
26 26
openbr/plugins/output.cpp
@@ -35,9 +35,9 @@ @@ -35,9 +35,9 @@
35 #include "openbr_internal.h" 35 #include "openbr_internal.h"
36 36
37 #include "openbr/core/bee.h" 37 #include "openbr/core/bee.h"
  38 +#include "openbr/core/eval.h"
38 #include "openbr/core/common.h" 39 #include "openbr/core/common.h"
39 #include "openbr/core/opencvutils.h" 40 #include "openbr/core/opencvutils.h"
40 -#include "openbr/core/plot.h"  
41 #include "openbr/core/qtutils.h" 41 #include "openbr/core/qtutils.h"
42 42
43 namespace br 43 namespace br
@@ -146,8 +146,8 @@ class meltOutput : public MatrixOutput @@ -146,8 +146,8 @@ class meltOutput : public MatrixOutput
146 QStringList lines; 146 QStringList lines;
147 if (file.baseName() != "terminal") lines.append(QString("Query,Target,Mask,Similarity%1").arg(keys)); 147 if (file.baseName() != "terminal") lines.append(QString("Query,Target,Mask,Similarity%1").arg(keys));
148 148
149 - QList<QString> queryLabels = queryFiles.get<QString>("Subject");  
150 - QList<QString> targetLabels = targetFiles.get<QString>("Subject"); 149 + QList<QString> queryLabels = File::get<QString>(queryFiles, "Subject");
  150 + QList<QString> targetLabels = File::get<QString>(targetFiles, "Subject");
151 151
152 for (int i=0; i<queryFiles.size(); i++) { 152 for (int i=0; i<queryFiles.size(); i++) {
153 for (int j=(selfSimilar ? i+1 : 0); j<targetFiles.size(); j++) { 153 for (int j=(selfSimilar ? i+1 : 0); j<targetFiles.size(); j++) {
openbr/plugins/svm.cpp
@@ -130,7 +130,7 @@ private: @@ -130,7 +130,7 @@ private:
130 Mat lab; 130 Mat lab;
131 // If we are doing regression, assume subject has float values 131 // If we are doing regression, assume subject has float values
132 if (type == EPS_SVR || type == NU_SVR) { 132 if (type == EPS_SVR || type == NU_SVR) {
133 - lab = OpenCVUtils::toMat(_data.get<float>("Subject")); 133 + lab = OpenCVUtils::toMat(File::get<float>(_data, "Subject"));
134 } 134 }
135 // If we are doing classification, assume subject has discrete values, map them 135 // If we are doing classification, assume subject has discrete values, map them
136 // and store the mapping data 136 // and store the mapping data
scripts/downloadDatasets.sh
@@ -48,3 +48,19 @@ if [ ! -d ../data/MEDS/img ]; then @@ -48,3 +48,19 @@ if [ ! -d ../data/MEDS/img ]; then
48 mv data/*/*.jpg ../data/MEDS/img 48 mv data/*/*.jpg ../data/MEDS/img
49 rm -r data NIST_SD32_MEDS-II_face.zip 49 rm -r data NIST_SD32_MEDS-II_face.zip
50 fi 50 fi
  51 +
  52 +# KTH
  53 +if [ ! -d ../data/KTH/vid ]; then
  54 + echo "Downloading KTH..."
  55 + mkdir ../data/KTH/vid
  56 + for vidclass in {'boxing','handclapping','handwaving','jogging','running','walking'}; do
  57 + if hash curl 2>/dev/null; then
  58 + curl -OL http://www.nada.kth.se/cvap/actions/${vidclass}.zip
  59 + else
  60 + wget http://www.nada.kth.se/cvap/actions/${vidclass}.zip
  61 + fi
  62 + mkdir ../data/KTH/vid/${vidclass}
  63 + unzip ${vidclass}.zip -d ../data/KTH/vid/${vidclass}
  64 + rm ${vidclass}.zip
  65 + done
  66 +fi