Commit 6a7d402fe8c18637c8d1a6e5a46fee75108f980c
1 parent
19a3de11
generalized evalClassification
Showing
4 changed files
with
25 additions
and
17 deletions
openbr/core/classify.cpp
| @@ -14,8 +14,6 @@ | @@ -14,8 +14,6 @@ | ||
| 14 | * limitations under the License. * | 14 | * limitations under the License. * |
| 15 | * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ | 15 | * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ |
| 16 | 16 | ||
| 17 | -#include <QDebug> | ||
| 18 | -#include <QHash> | ||
| 19 | #include <openbr/openbr_plugin.h> | 17 | #include <openbr/openbr_plugin.h> |
| 20 | 18 | ||
| 21 | #include "classify.h" | 19 | #include "classify.h" |
| @@ -24,7 +22,7 @@ | @@ -24,7 +22,7 @@ | ||
| 24 | // Helper struct for statistics accumulation | 22 | // Helper struct for statistics accumulation |
| 25 | struct Counter | 23 | struct Counter |
| 26 | { | 24 | { |
| 27 | - int truePositive, falsePositive, falseNegative; | 25 | + float truePositive, falsePositive, falseNegative; |
| 28 | Counter() | 26 | Counter() |
| 29 | { | 27 | { |
| 30 | truePositive = 0; | 28 | truePositive = 0; |
| @@ -41,35 +39,44 @@ void br::EvalClassification(const QString &predictedInput, const QString &truthI | @@ -41,35 +39,44 @@ void br::EvalClassification(const QString &predictedInput, const QString &truthI | ||
| 41 | TemplateList truth(TemplateList::fromGallery(truthInput)); | 39 | TemplateList truth(TemplateList::fromGallery(truthInput)); |
| 42 | if (predicted.size() != truth.size()) qFatal("Input size mismatch."); | 40 | if (predicted.size() != truth.size()) qFatal("Input size mismatch."); |
| 43 | 41 | ||
| 44 | - QHash<int, Counter> counters; | 42 | + QHash<QString, Counter> counters; |
| 45 | for (int i=0; i<predicted.size(); i++) { | 43 | for (int i=0; i<predicted.size(); i++) { |
| 46 | if (predicted[i].file.name != truth[i].file.name) | 44 | if (predicted[i].file.name != truth[i].file.name) |
| 47 | qFatal("Input order mismatch."); | 45 | qFatal("Input order mismatch."); |
| 48 | 46 | ||
| 49 | - const int trueLabel = truth[i].file.label(); | ||
| 50 | - const int predictedLabel = predicted[i].file.label(); | ||
| 51 | - if (trueLabel == predictedLabel) { | ||
| 52 | - counters[trueLabel].truePositive++; | ||
| 53 | - } else { | ||
| 54 | - counters[trueLabel].falseNegative++; | ||
| 55 | - counters[predictedLabel].falsePositive++; | 47 | + // Typically these lists will be of length one, but this generalization allows measuring multi-class labeling accuracy. |
| 48 | + QStringList predictedSubjects = predicted[i].file.get<QStringList>("Subject"); | ||
| 49 | + QStringList trueSubjects = truth[i].file.get<QStringList>("Subject"); | ||
| 50 | + foreach (const QString &subject, trueSubjects.toVector() /* Hack to copy the list. */) { | ||
| 51 | + if (predictedSubjects.contains(subject)) { | ||
| 52 | + counters[subject].truePositive++; | ||
| 53 | + trueSubjects.removeOne(subject); | ||
| 54 | + predictedSubjects.removeOne(subject); | ||
| 55 | + } else { | ||
| 56 | + counters[subject].falseNegative++; | ||
| 57 | + } | ||
| 56 | } | 58 | } |
| 59 | + | ||
| 60 | + for (int i=0; i<trueSubjects.size(); i++) | ||
| 61 | + foreach (const QString &subject, predictedSubjects) | ||
| 62 | + counters[subject].falsePositive += 1.f / predictedSubjects.size(); | ||
| 57 | } | 63 | } |
| 58 | 64 | ||
| 59 | QSharedPointer<Output> output(Output::make("", FileList() << "Subject" << "Count" << "Precision" << "Recall" << "F-score", FileList(counters.size()))); | 65 | QSharedPointer<Output> output(Output::make("", FileList() << "Subject" << "Count" << "Precision" << "Recall" << "F-score", FileList(counters.size()))); |
| 60 | 66 | ||
| 61 | int tpc = 0; | 67 | int tpc = 0; |
| 62 | int fnc = 0; | 68 | int fnc = 0; |
| 69 | + const QStringList keys = counters.keys(); | ||
| 63 | for (int i=0; i<counters.size(); i++) { | 70 | for (int i=0; i<counters.size(); i++) { |
| 64 | - int trueLabel = counters.keys()[i]; | ||
| 65 | - const Counter &counter = counters[trueLabel]; | 71 | + const QString &subject = keys[i]; |
| 72 | + const Counter &counter = counters[subject]; | ||
| 66 | tpc += counter.truePositive; | 73 | tpc += counter.truePositive; |
| 67 | fnc += counter.falseNegative; | 74 | fnc += counter.falseNegative; |
| 68 | const int count = counter.truePositive + counter.falseNegative; | 75 | const int count = counter.truePositive + counter.falseNegative; |
| 69 | const float precision = counter.truePositive / (float)(counter.truePositive + counter.falsePositive); | 76 | const float precision = counter.truePositive / (float)(counter.truePositive + counter.falsePositive); |
| 70 | const float recall = counter.truePositive / (float)(counter.truePositive + counter.falseNegative); | 77 | const float recall = counter.truePositive / (float)(counter.truePositive + counter.falseNegative); |
| 71 | const float fscore = 2 * precision * recall / (precision + recall); | 78 | const float fscore = 2 * precision * recall / (precision + recall); |
| 72 | - output->setRelative(trueLabel, i, 0); | 79 | + output->setRelative(File("", subject).label(), i, 0); |
| 73 | output->setRelative(count, i, 1); | 80 | output->setRelative(count, i, 1); |
| 74 | output->setRelative(precision, i, 2); | 81 | output->setRelative(precision, i, 2); |
| 75 | output->setRelative(recall, i, 3); | 82 | output->setRelative(recall, i, 3); |
openbr/openbr_plugin.cpp
| @@ -1025,7 +1025,7 @@ MatrixOutput *MatrixOutput::make(const FileList &targetFiles, const FileList &qu | @@ -1025,7 +1025,7 @@ MatrixOutput *MatrixOutput::make(const FileList &targetFiles, const FileList &qu | ||
| 1025 | /* MatrixOutput - protected methods */ | 1025 | /* MatrixOutput - protected methods */ |
| 1026 | QString MatrixOutput::toString(int row, int column) const | 1026 | QString MatrixOutput::toString(int row, int column) const |
| 1027 | { | 1027 | { |
| 1028 | - if (targetFiles[column] == "Label") { | 1028 | + if (targetFiles[column] == "Subject") { |
| 1029 | const int label = data.at<float>(row,column); | 1029 | const int label = data.at<float>(row,column); |
| 1030 | return Globals->subjects.key(label, QString::number(label)); | 1030 | return Globals->subjects.key(label, QString::number(label)); |
| 1031 | } | 1031 | } |
openbr/plugins/algorithms.cpp
| @@ -50,7 +50,7 @@ class AlgorithmsInitializer : public Initializer | @@ -50,7 +50,7 @@ class AlgorithmsInitializer : public Initializer | ||
| 50 | Globals->abbreviations.insert("SmallSIFT", "Open+LimitSize(512)+KeyPointDetector(SIFT)+KeyPointDescriptor(SIFT):KeyPointMatcher(BruteForce)"); | 50 | Globals->abbreviations.insert("SmallSIFT", "Open+LimitSize(512)+KeyPointDetector(SIFT)+KeyPointDescriptor(SIFT):KeyPointMatcher(BruteForce)"); |
| 51 | Globals->abbreviations.insert("SmallSURF", "Open+LimitSize(512)+KeyPointDetector(SURF)+KeyPointDescriptor(SURF):KeyPointMatcher(BruteForce)"); | 51 | Globals->abbreviations.insert("SmallSURF", "Open+LimitSize(512)+KeyPointDetector(SURF)+KeyPointDescriptor(SURF):KeyPointMatcher(BruteForce)"); |
| 52 | Globals->abbreviations.insert("ColorHist", "Open+LimitSize(512)!EnsureChannels(3)+SplitChannels+Hist(256,0,8)+Cat+Normalize(L1):L2"); | 52 | Globals->abbreviations.insert("ColorHist", "Open+LimitSize(512)!EnsureChannels(3)+SplitChannels+Hist(256,0,8)+Cat+Normalize(L1):L2"); |
| 53 | - Globals->abbreviations.insert("ImageClassification", "Open+CropSquare+LimitSize(256)+Cvt(Gray)+Gradient+Bin(0,360,9,true)+Merge+Integral+RecursiveIntegralSampler(4,2,8,Singleton(KMeans(1024)[fraction=0.5]))+Cat+CvtFloat+Hist(1024)+KNS(5,Dist(L1))"); | 53 | + Globals->abbreviations.insert("ImageClassification", "Open+CropSquare+LimitSize(256)+Cvt(Gray)+Gradient+Bin(0,360,9,true)+Merge+Integral+RecursiveIntegralSampler(4,2,8,Singleton(KMeans(1024)[fraction=0.5]))+Cat+CvtFloat+Hist(1024)+KNS(5,Dist(L1))+Rename(KNS,Subject)"); |
| 54 | 54 | ||
| 55 | // Hash | 55 | // Hash |
| 56 | Globals->abbreviations.insert("FileName", "Name+Identity:Identical"); | 56 | Globals->abbreviations.insert("FileName", "Name+Identity:Identical"); |
openbr/plugins/cluster.cpp
| @@ -106,7 +106,8 @@ class KNSTransform : public Transform | @@ -106,7 +106,8 @@ class KNSTransform : public Transform | ||
| 106 | subjects.insert(gallery[sortedScores[i].second].file.subject()); | 106 | subjects.insert(gallery[sortedScores[i].second].file.subject()); |
| 107 | i++; | 107 | i++; |
| 108 | } | 108 | } |
| 109 | - dst.file.set("KNS", QStringList(subjects.toList()).join(", ")); | 109 | + const QStringList subjectList = subjects.toList(); |
| 110 | + dst.file.set("KNS", subjects.size() > 1 ? "[" + subjectList.join(",") + "]" : subjectList.first()); | ||
| 110 | } | 111 | } |
| 111 | 112 | ||
| 112 | void store(QDataStream &stream) const | 113 | void store(QDataStream &stream) const |