Commit 6a7d402fe8c18637c8d1a6e5a46fee75108f980c

Authored by Josh Klontz
1 parent 19a3de11

generalized evalClassification

openbr/core/classify.cpp
@@ -14,8 +14,6 @@ @@ -14,8 +14,6 @@
14 * limitations under the License. * 14 * limitations under the License. *
15 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 15 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
16 16
17 -#include <QDebug>  
18 -#include <QHash>  
19 #include <openbr/openbr_plugin.h> 17 #include <openbr/openbr_plugin.h>
20 18
21 #include "classify.h" 19 #include "classify.h"
@@ -24,7 +22,7 @@ @@ -24,7 +22,7 @@
24 // Helper struct for statistics accumulation 22 // Helper struct for statistics accumulation
25 struct Counter 23 struct Counter
26 { 24 {
27 - int truePositive, falsePositive, falseNegative; 25 + float truePositive, falsePositive, falseNegative;
28 Counter() 26 Counter()
29 { 27 {
30 truePositive = 0; 28 truePositive = 0;
@@ -41,35 +39,44 @@ void br::EvalClassification(const QString &amp;predictedInput, const QString &amp;truthI @@ -41,35 +39,44 @@ void br::EvalClassification(const QString &amp;predictedInput, const QString &amp;truthI
41 TemplateList truth(TemplateList::fromGallery(truthInput)); 39 TemplateList truth(TemplateList::fromGallery(truthInput));
42 if (predicted.size() != truth.size()) qFatal("Input size mismatch."); 40 if (predicted.size() != truth.size()) qFatal("Input size mismatch.");
43 41
44 - QHash<int, Counter> counters; 42 + QHash<QString, Counter> counters;
45 for (int i=0; i<predicted.size(); i++) { 43 for (int i=0; i<predicted.size(); i++) {
46 if (predicted[i].file.name != truth[i].file.name) 44 if (predicted[i].file.name != truth[i].file.name)
47 qFatal("Input order mismatch."); 45 qFatal("Input order mismatch.");
48 46
49 - const int trueLabel = truth[i].file.label();  
50 - const int predictedLabel = predicted[i].file.label();  
51 - if (trueLabel == predictedLabel) {  
52 - counters[trueLabel].truePositive++;  
53 - } else {  
54 - counters[trueLabel].falseNegative++;  
55 - counters[predictedLabel].falsePositive++; 47 + // Typically these lists will be of length one, but this generalization allows measuring multi-class labeling accuracy.
  48 + QStringList predictedSubjects = predicted[i].file.get<QStringList>("Subject");
  49 + QStringList trueSubjects = truth[i].file.get<QStringList>("Subject");
  50 + foreach (const QString &subject, trueSubjects.toVector() /* Hack to copy the list. */) {
  51 + if (predictedSubjects.contains(subject)) {
  52 + counters[subject].truePositive++;
  53 + trueSubjects.removeOne(subject);
  54 + predictedSubjects.removeOne(subject);
  55 + } else {
  56 + counters[subject].falseNegative++;
  57 + }
56 } 58 }
  59 +
  60 + for (int i=0; i<trueSubjects.size(); i++)
  61 + foreach (const QString &subject, predictedSubjects)
  62 + counters[subject].falsePositive += 1.f / predictedSubjects.size();
57 } 63 }
58 64
59 QSharedPointer<Output> output(Output::make("", FileList() << "Subject" << "Count" << "Precision" << "Recall" << "F-score", FileList(counters.size()))); 65 QSharedPointer<Output> output(Output::make("", FileList() << "Subject" << "Count" << "Precision" << "Recall" << "F-score", FileList(counters.size())));
60 66
61 int tpc = 0; 67 int tpc = 0;
62 int fnc = 0; 68 int fnc = 0;
  69 + const QStringList keys = counters.keys();
63 for (int i=0; i<counters.size(); i++) { 70 for (int i=0; i<counters.size(); i++) {
64 - int trueLabel = counters.keys()[i];  
65 - const Counter &counter = counters[trueLabel]; 71 + const QString &subject = keys[i];
  72 + const Counter &counter = counters[subject];
66 tpc += counter.truePositive; 73 tpc += counter.truePositive;
67 fnc += counter.falseNegative; 74 fnc += counter.falseNegative;
68 const int count = counter.truePositive + counter.falseNegative; 75 const int count = counter.truePositive + counter.falseNegative;
69 const float precision = counter.truePositive / (float)(counter.truePositive + counter.falsePositive); 76 const float precision = counter.truePositive / (float)(counter.truePositive + counter.falsePositive);
70 const float recall = counter.truePositive / (float)(counter.truePositive + counter.falseNegative); 77 const float recall = counter.truePositive / (float)(counter.truePositive + counter.falseNegative);
71 const float fscore = 2 * precision * recall / (precision + recall); 78 const float fscore = 2 * precision * recall / (precision + recall);
72 - output->setRelative(trueLabel, i, 0); 79 + output->setRelative(File("", subject).label(), i, 0);
73 output->setRelative(count, i, 1); 80 output->setRelative(count, i, 1);
74 output->setRelative(precision, i, 2); 81 output->setRelative(precision, i, 2);
75 output->setRelative(recall, i, 3); 82 output->setRelative(recall, i, 3);
openbr/openbr_plugin.cpp
@@ -1025,7 +1025,7 @@ MatrixOutput *MatrixOutput::make(const FileList &amp;targetFiles, const FileList &amp;qu @@ -1025,7 +1025,7 @@ MatrixOutput *MatrixOutput::make(const FileList &amp;targetFiles, const FileList &amp;qu
1025 /* MatrixOutput - protected methods */ 1025 /* MatrixOutput - protected methods */
1026 QString MatrixOutput::toString(int row, int column) const 1026 QString MatrixOutput::toString(int row, int column) const
1027 { 1027 {
1028 - if (targetFiles[column] == "Label") { 1028 + if (targetFiles[column] == "Subject") {
1029 const int label = data.at<float>(row,column); 1029 const int label = data.at<float>(row,column);
1030 return Globals->subjects.key(label, QString::number(label)); 1030 return Globals->subjects.key(label, QString::number(label));
1031 } 1031 }
openbr/plugins/algorithms.cpp
@@ -50,7 +50,7 @@ class AlgorithmsInitializer : public Initializer @@ -50,7 +50,7 @@ class AlgorithmsInitializer : public Initializer
50 Globals->abbreviations.insert("SmallSIFT", "Open+LimitSize(512)+KeyPointDetector(SIFT)+KeyPointDescriptor(SIFT):KeyPointMatcher(BruteForce)"); 50 Globals->abbreviations.insert("SmallSIFT", "Open+LimitSize(512)+KeyPointDetector(SIFT)+KeyPointDescriptor(SIFT):KeyPointMatcher(BruteForce)");
51 Globals->abbreviations.insert("SmallSURF", "Open+LimitSize(512)+KeyPointDetector(SURF)+KeyPointDescriptor(SURF):KeyPointMatcher(BruteForce)"); 51 Globals->abbreviations.insert("SmallSURF", "Open+LimitSize(512)+KeyPointDetector(SURF)+KeyPointDescriptor(SURF):KeyPointMatcher(BruteForce)");
52 Globals->abbreviations.insert("ColorHist", "Open+LimitSize(512)!EnsureChannels(3)+SplitChannels+Hist(256,0,8)+Cat+Normalize(L1):L2"); 52 Globals->abbreviations.insert("ColorHist", "Open+LimitSize(512)!EnsureChannels(3)+SplitChannels+Hist(256,0,8)+Cat+Normalize(L1):L2");
53 - Globals->abbreviations.insert("ImageClassification", "Open+CropSquare+LimitSize(256)+Cvt(Gray)+Gradient+Bin(0,360,9,true)+Merge+Integral+RecursiveIntegralSampler(4,2,8,Singleton(KMeans(1024)[fraction=0.5]))+Cat+CvtFloat+Hist(1024)+KNS(5,Dist(L1))"); 53 + Globals->abbreviations.insert("ImageClassification", "Open+CropSquare+LimitSize(256)+Cvt(Gray)+Gradient+Bin(0,360,9,true)+Merge+Integral+RecursiveIntegralSampler(4,2,8,Singleton(KMeans(1024)[fraction=0.5]))+Cat+CvtFloat+Hist(1024)+KNS(5,Dist(L1))+Rename(KNS,Subject)");
54 54
55 // Hash 55 // Hash
56 Globals->abbreviations.insert("FileName", "Name+Identity:Identical"); 56 Globals->abbreviations.insert("FileName", "Name+Identity:Identical");
openbr/plugins/cluster.cpp
@@ -106,7 +106,8 @@ class KNSTransform : public Transform @@ -106,7 +106,8 @@ class KNSTransform : public Transform
106 subjects.insert(gallery[sortedScores[i].second].file.subject()); 106 subjects.insert(gallery[sortedScores[i].second].file.subject());
107 i++; 107 i++;
108 } 108 }
109 - dst.file.set("KNS", QStringList(subjects.toList()).join(", ")); 109 + const QStringList subjectList = subjects.toList();
  110 + dst.file.set("KNS", subjects.size() > 1 ? "[" + subjectList.join(",") + "]" : subjectList.first());
110 } 111 }
111 112
112 void store(QDataStream &stream) const 113 void store(QDataStream &stream) const