Commit 6a7d402fe8c18637c8d1a6e5a46fee75108f980c

Authored by Josh Klontz
1 parent 19a3de11

generalized evalClassification

openbr/core/classify.cpp
... ... @@ -14,8 +14,6 @@
14 14 * limitations under the License. *
15 15 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
16 16  
17   -#include <QDebug>
18   -#include <QHash>
19 17 #include <openbr/openbr_plugin.h>
20 18  
21 19 #include "classify.h"
... ... @@ -24,7 +22,7 @@
24 22 // Helper struct for statistics accumulation
25 23 struct Counter
26 24 {
27   - int truePositive, falsePositive, falseNegative;
  25 + float truePositive, falsePositive, falseNegative;
28 26 Counter()
29 27 {
30 28 truePositive = 0;
... ... @@ -41,35 +39,44 @@ void br::EvalClassification(const QString &amp;predictedInput, const QString &amp;truthI
41 39 TemplateList truth(TemplateList::fromGallery(truthInput));
42 40 if (predicted.size() != truth.size()) qFatal("Input size mismatch.");
43 41  
44   - QHash<int, Counter> counters;
  42 + QHash<QString, Counter> counters;
45 43 for (int i=0; i<predicted.size(); i++) {
46 44 if (predicted[i].file.name != truth[i].file.name)
47 45 qFatal("Input order mismatch.");
48 46  
49   - const int trueLabel = truth[i].file.label();
50   - const int predictedLabel = predicted[i].file.label();
51   - if (trueLabel == predictedLabel) {
52   - counters[trueLabel].truePositive++;
53   - } else {
54   - counters[trueLabel].falseNegative++;
55   - counters[predictedLabel].falsePositive++;
  47 + // Typically these lists will be of length one, but this generalization allows measuring multi-class labeling accuracy.
  48 + QStringList predictedSubjects = predicted[i].file.get<QStringList>("Subject");
  49 + QStringList trueSubjects = truth[i].file.get<QStringList>("Subject");
  50 + foreach (const QString &subject, trueSubjects.toVector() /* Hack to copy the list. */) {
  51 + if (predictedSubjects.contains(subject)) {
  52 + counters[subject].truePositive++;
  53 + trueSubjects.removeOne(subject);
  54 + predictedSubjects.removeOne(subject);
  55 + } else {
  56 + counters[subject].falseNegative++;
  57 + }
56 58 }
  59 +
  60 + for (int i=0; i<trueSubjects.size(); i++)
  61 + foreach (const QString &subject, predictedSubjects)
  62 + counters[subject].falsePositive += 1.f / predictedSubjects.size();
57 63 }
58 64  
59 65 QSharedPointer<Output> output(Output::make("", FileList() << "Subject" << "Count" << "Precision" << "Recall" << "F-score", FileList(counters.size())));
60 66  
61 67 int tpc = 0;
62 68 int fnc = 0;
  69 + const QStringList keys = counters.keys();
63 70 for (int i=0; i<counters.size(); i++) {
64   - int trueLabel = counters.keys()[i];
65   - const Counter &counter = counters[trueLabel];
  71 + const QString &subject = keys[i];
  72 + const Counter &counter = counters[subject];
66 73 tpc += counter.truePositive;
67 74 fnc += counter.falseNegative;
68 75 const int count = counter.truePositive + counter.falseNegative;
69 76 const float precision = counter.truePositive / (float)(counter.truePositive + counter.falsePositive);
70 77 const float recall = counter.truePositive / (float)(counter.truePositive + counter.falseNegative);
71 78 const float fscore = 2 * precision * recall / (precision + recall);
72   - output->setRelative(trueLabel, i, 0);
  79 + output->setRelative(File("", subject).label(), i, 0);
73 80 output->setRelative(count, i, 1);
74 81 output->setRelative(precision, i, 2);
75 82 output->setRelative(recall, i, 3);
... ...
openbr/openbr_plugin.cpp
... ... @@ -1025,7 +1025,7 @@ MatrixOutput *MatrixOutput::make(const FileList &amp;targetFiles, const FileList &amp;qu
1025 1025 /* MatrixOutput - protected methods */
1026 1026 QString MatrixOutput::toString(int row, int column) const
1027 1027 {
1028   - if (targetFiles[column] == "Label") {
  1028 + if (targetFiles[column] == "Subject") {
1029 1029 const int label = data.at<float>(row,column);
1030 1030 return Globals->subjects.key(label, QString::number(label));
1031 1031 }
... ...
openbr/plugins/algorithms.cpp
... ... @@ -50,7 +50,7 @@ class AlgorithmsInitializer : public Initializer
50 50 Globals->abbreviations.insert("SmallSIFT", "Open+LimitSize(512)+KeyPointDetector(SIFT)+KeyPointDescriptor(SIFT):KeyPointMatcher(BruteForce)");
51 51 Globals->abbreviations.insert("SmallSURF", "Open+LimitSize(512)+KeyPointDetector(SURF)+KeyPointDescriptor(SURF):KeyPointMatcher(BruteForce)");
52 52 Globals->abbreviations.insert("ColorHist", "Open+LimitSize(512)!EnsureChannels(3)+SplitChannels+Hist(256,0,8)+Cat+Normalize(L1):L2");
53   - Globals->abbreviations.insert("ImageClassification", "Open+CropSquare+LimitSize(256)+Cvt(Gray)+Gradient+Bin(0,360,9,true)+Merge+Integral+RecursiveIntegralSampler(4,2,8,Singleton(KMeans(1024)[fraction=0.5]))+Cat+CvtFloat+Hist(1024)+KNS(5,Dist(L1))");
  53 + Globals->abbreviations.insert("ImageClassification", "Open+CropSquare+LimitSize(256)+Cvt(Gray)+Gradient+Bin(0,360,9,true)+Merge+Integral+RecursiveIntegralSampler(4,2,8,Singleton(KMeans(1024)[fraction=0.5]))+Cat+CvtFloat+Hist(1024)+KNS(5,Dist(L1))+Rename(KNS,Subject)");
54 54  
55 55 // Hash
56 56 Globals->abbreviations.insert("FileName", "Name+Identity:Identical");
... ...
openbr/plugins/cluster.cpp
... ... @@ -106,7 +106,8 @@ class KNSTransform : public Transform
106 106 subjects.insert(gallery[sortedScores[i].second].file.subject());
107 107 i++;
108 108 }
109   - dst.file.set("KNS", QStringList(subjects.toList()).join(", "));
  109 + const QStringList subjectList = subjects.toList();
  110 + dst.file.set("KNS", subjects.size() > 1 ? "[" + subjectList.join(",") + "]" : subjectList.first());
110 111 }
111 112  
112 113 void store(QDataStream &stream) const
... ...