From 6a7d402fe8c18637c8d1a6e5a46fee75108f980c Mon Sep 17 00:00:00 2001 From: Josh Klontz Date: Thu, 11 Apr 2013 17:48:35 -0400 Subject: [PATCH] generalized evalClassification --- openbr/core/classify.cpp | 35 +++++++++++++++++++++-------------- openbr/openbr_plugin.cpp | 2 +- openbr/plugins/algorithms.cpp | 2 +- openbr/plugins/cluster.cpp | 3 ++- 4 files changed, 25 insertions(+), 17 deletions(-) diff --git a/openbr/core/classify.cpp b/openbr/core/classify.cpp index fd4cfae..ad2937e 100644 --- a/openbr/core/classify.cpp +++ b/openbr/core/classify.cpp @@ -14,8 +14,6 @@ * limitations under the License. * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ -#include -#include #include #include "classify.h" @@ -24,7 +22,7 @@ // Helper struct for statistics accumulation struct Counter { - int truePositive, falsePositive, falseNegative; + float truePositive, falsePositive, falseNegative; Counter() { truePositive = 0; @@ -41,35 +39,44 @@ void br::EvalClassification(const QString &predictedInput, const QString &truthI TemplateList truth(TemplateList::fromGallery(truthInput)); if (predicted.size() != truth.size()) qFatal("Input size mismatch."); - QHash counters; + QHash counters; for (int i=0; i("Subject"); + QStringList trueSubjects = truth[i].file.get("Subject"); + foreach (const QString &subject, trueSubjects.toVector() /* Hack to copy the list. */) { + if (predictedSubjects.contains(subject)) { + counters[subject].truePositive++; + trueSubjects.removeOne(subject); + predictedSubjects.removeOne(subject); + } else { + counters[subject].falseNegative++; + } } + + for (int i=0; i output(Output::make("", FileList() << "Subject" << "Count" << "Precision" << "Recall" << "F-score", FileList(counters.size()))); int tpc = 0; int fnc = 0; + const QStringList keys = counters.keys(); for (int i=0; isetRelative(trueLabel, i, 0); + output->setRelative(File("", subject).label(), i, 0); output->setRelative(count, i, 1); output->setRelative(precision, i, 2); output->setRelative(recall, i, 3); diff --git a/openbr/openbr_plugin.cpp b/openbr/openbr_plugin.cpp index 602f001..99a2845 100644 --- a/openbr/openbr_plugin.cpp +++ b/openbr/openbr_plugin.cpp @@ -1025,7 +1025,7 @@ MatrixOutput *MatrixOutput::make(const FileList &targetFiles, const FileList &qu /* MatrixOutput - protected methods */ QString MatrixOutput::toString(int row, int column) const { - if (targetFiles[column] == "Label") { + if (targetFiles[column] == "Subject") { const int label = data.at(row,column); return Globals->subjects.key(label, QString::number(label)); } diff --git a/openbr/plugins/algorithms.cpp b/openbr/plugins/algorithms.cpp index 113e59d..98d14d9 100644 --- a/openbr/plugins/algorithms.cpp +++ b/openbr/plugins/algorithms.cpp @@ -50,7 +50,7 @@ class AlgorithmsInitializer : public Initializer Globals->abbreviations.insert("SmallSIFT", "Open+LimitSize(512)+KeyPointDetector(SIFT)+KeyPointDescriptor(SIFT):KeyPointMatcher(BruteForce)"); Globals->abbreviations.insert("SmallSURF", "Open+LimitSize(512)+KeyPointDetector(SURF)+KeyPointDescriptor(SURF):KeyPointMatcher(BruteForce)"); Globals->abbreviations.insert("ColorHist", "Open+LimitSize(512)!EnsureChannels(3)+SplitChannels+Hist(256,0,8)+Cat+Normalize(L1):L2"); - Globals->abbreviations.insert("ImageClassification", "Open+CropSquare+LimitSize(256)+Cvt(Gray)+Gradient+Bin(0,360,9,true)+Merge+Integral+RecursiveIntegralSampler(4,2,8,Singleton(KMeans(1024)[fraction=0.5]))+Cat+CvtFloat+Hist(1024)+KNS(5,Dist(L1))"); + Globals->abbreviations.insert("ImageClassification", "Open+CropSquare+LimitSize(256)+Cvt(Gray)+Gradient+Bin(0,360,9,true)+Merge+Integral+RecursiveIntegralSampler(4,2,8,Singleton(KMeans(1024)[fraction=0.5]))+Cat+CvtFloat+Hist(1024)+KNS(5,Dist(L1))+Rename(KNS,Subject)"); // Hash Globals->abbreviations.insert("FileName", "Name+Identity:Identical"); diff --git a/openbr/plugins/cluster.cpp b/openbr/plugins/cluster.cpp index a21c34a..b5499d4 100644 --- a/openbr/plugins/cluster.cpp +++ b/openbr/plugins/cluster.cpp @@ -106,7 +106,8 @@ class KNSTransform : public Transform subjects.insert(gallery[sortedScores[i].second].file.subject()); i++; } - dst.file.set("KNS", QStringList(subjects.toList()).join(", ")); + const QStringList subjectList = subjects.toList(); + dst.file.set("KNS", subjects.size() > 1 ? "[" + subjectList.join(",") + "]" : subjectList.first()); } void store(QDataStream &stream) const -- libgit2 0.21.4