Commit 6a7d402fe8c18637c8d1a6e5a46fee75108f980c
1 parent
19a3de11
generalized evalClassification
Showing
4 changed files
with
25 additions
and
17 deletions
openbr/core/classify.cpp
| ... | ... | @@ -14,8 +14,6 @@ |
| 14 | 14 | * limitations under the License. * |
| 15 | 15 | * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ |
| 16 | 16 | |
| 17 | -#include <QDebug> | |
| 18 | -#include <QHash> | |
| 19 | 17 | #include <openbr/openbr_plugin.h> |
| 20 | 18 | |
| 21 | 19 | #include "classify.h" |
| ... | ... | @@ -24,7 +22,7 @@ |
| 24 | 22 | // Helper struct for statistics accumulation |
| 25 | 23 | struct Counter |
| 26 | 24 | { |
| 27 | - int truePositive, falsePositive, falseNegative; | |
| 25 | + float truePositive, falsePositive, falseNegative; | |
| 28 | 26 | Counter() |
| 29 | 27 | { |
| 30 | 28 | truePositive = 0; |
| ... | ... | @@ -41,35 +39,44 @@ void br::EvalClassification(const QString &predictedInput, const QString &truthI |
| 41 | 39 | TemplateList truth(TemplateList::fromGallery(truthInput)); |
| 42 | 40 | if (predicted.size() != truth.size()) qFatal("Input size mismatch."); |
| 43 | 41 | |
| 44 | - QHash<int, Counter> counters; | |
| 42 | + QHash<QString, Counter> counters; | |
| 45 | 43 | for (int i=0; i<predicted.size(); i++) { |
| 46 | 44 | if (predicted[i].file.name != truth[i].file.name) |
| 47 | 45 | qFatal("Input order mismatch."); |
| 48 | 46 | |
| 49 | - const int trueLabel = truth[i].file.label(); | |
| 50 | - const int predictedLabel = predicted[i].file.label(); | |
| 51 | - if (trueLabel == predictedLabel) { | |
| 52 | - counters[trueLabel].truePositive++; | |
| 53 | - } else { | |
| 54 | - counters[trueLabel].falseNegative++; | |
| 55 | - counters[predictedLabel].falsePositive++; | |
| 47 | + // Typically these lists will be of length one, but this generalization allows measuring multi-class labeling accuracy. | |
| 48 | + QStringList predictedSubjects = predicted[i].file.get<QStringList>("Subject"); | |
| 49 | + QStringList trueSubjects = truth[i].file.get<QStringList>("Subject"); | |
| 50 | + foreach (const QString &subject, trueSubjects.toVector() /* Hack to copy the list. */) { | |
| 51 | + if (predictedSubjects.contains(subject)) { | |
| 52 | + counters[subject].truePositive++; | |
| 53 | + trueSubjects.removeOne(subject); | |
| 54 | + predictedSubjects.removeOne(subject); | |
| 55 | + } else { | |
| 56 | + counters[subject].falseNegative++; | |
| 57 | + } | |
| 56 | 58 | } |
| 59 | + | |
| 60 | + for (int i=0; i<trueSubjects.size(); i++) | |
| 61 | + foreach (const QString &subject, predictedSubjects) | |
| 62 | + counters[subject].falsePositive += 1.f / predictedSubjects.size(); | |
| 57 | 63 | } |
| 58 | 64 | |
| 59 | 65 | QSharedPointer<Output> output(Output::make("", FileList() << "Subject" << "Count" << "Precision" << "Recall" << "F-score", FileList(counters.size()))); |
| 60 | 66 | |
| 61 | 67 | int tpc = 0; |
| 62 | 68 | int fnc = 0; |
| 69 | + const QStringList keys = counters.keys(); | |
| 63 | 70 | for (int i=0; i<counters.size(); i++) { |
| 64 | - int trueLabel = counters.keys()[i]; | |
| 65 | - const Counter &counter = counters[trueLabel]; | |
| 71 | + const QString &subject = keys[i]; | |
| 72 | + const Counter &counter = counters[subject]; | |
| 66 | 73 | tpc += counter.truePositive; |
| 67 | 74 | fnc += counter.falseNegative; |
| 68 | 75 | const int count = counter.truePositive + counter.falseNegative; |
| 69 | 76 | const float precision = counter.truePositive / (float)(counter.truePositive + counter.falsePositive); |
| 70 | 77 | const float recall = counter.truePositive / (float)(counter.truePositive + counter.falseNegative); |
| 71 | 78 | const float fscore = 2 * precision * recall / (precision + recall); |
| 72 | - output->setRelative(trueLabel, i, 0); | |
| 79 | + output->setRelative(File("", subject).label(), i, 0); | |
| 73 | 80 | output->setRelative(count, i, 1); |
| 74 | 81 | output->setRelative(precision, i, 2); |
| 75 | 82 | output->setRelative(recall, i, 3); | ... | ... |
openbr/openbr_plugin.cpp
| ... | ... | @@ -1025,7 +1025,7 @@ MatrixOutput *MatrixOutput::make(const FileList &targetFiles, const FileList &qu |
| 1025 | 1025 | /* MatrixOutput - protected methods */ |
| 1026 | 1026 | QString MatrixOutput::toString(int row, int column) const |
| 1027 | 1027 | { |
| 1028 | - if (targetFiles[column] == "Label") { | |
| 1028 | + if (targetFiles[column] == "Subject") { | |
| 1029 | 1029 | const int label = data.at<float>(row,column); |
| 1030 | 1030 | return Globals->subjects.key(label, QString::number(label)); |
| 1031 | 1031 | } | ... | ... |
openbr/plugins/algorithms.cpp
| ... | ... | @@ -50,7 +50,7 @@ class AlgorithmsInitializer : public Initializer |
| 50 | 50 | Globals->abbreviations.insert("SmallSIFT", "Open+LimitSize(512)+KeyPointDetector(SIFT)+KeyPointDescriptor(SIFT):KeyPointMatcher(BruteForce)"); |
| 51 | 51 | Globals->abbreviations.insert("SmallSURF", "Open+LimitSize(512)+KeyPointDetector(SURF)+KeyPointDescriptor(SURF):KeyPointMatcher(BruteForce)"); |
| 52 | 52 | Globals->abbreviations.insert("ColorHist", "Open+LimitSize(512)!EnsureChannels(3)+SplitChannels+Hist(256,0,8)+Cat+Normalize(L1):L2"); |
| 53 | - Globals->abbreviations.insert("ImageClassification", "Open+CropSquare+LimitSize(256)+Cvt(Gray)+Gradient+Bin(0,360,9,true)+Merge+Integral+RecursiveIntegralSampler(4,2,8,Singleton(KMeans(1024)[fraction=0.5]))+Cat+CvtFloat+Hist(1024)+KNS(5,Dist(L1))"); | |
| 53 | + Globals->abbreviations.insert("ImageClassification", "Open+CropSquare+LimitSize(256)+Cvt(Gray)+Gradient+Bin(0,360,9,true)+Merge+Integral+RecursiveIntegralSampler(4,2,8,Singleton(KMeans(1024)[fraction=0.5]))+Cat+CvtFloat+Hist(1024)+KNS(5,Dist(L1))+Rename(KNS,Subject)"); | |
| 54 | 54 | |
| 55 | 55 | // Hash |
| 56 | 56 | Globals->abbreviations.insert("FileName", "Name+Identity:Identical"); | ... | ... |
openbr/plugins/cluster.cpp
| ... | ... | @@ -106,7 +106,8 @@ class KNSTransform : public Transform |
| 106 | 106 | subjects.insert(gallery[sortedScores[i].second].file.subject()); |
| 107 | 107 | i++; |
| 108 | 108 | } |
| 109 | - dst.file.set("KNS", QStringList(subjects.toList()).join(", ")); | |
| 109 | + const QStringList subjectList = subjects.toList(); | |
| 110 | + dst.file.set("KNS", subjects.size() > 1 ? "[" + subjectList.join(",") + "]" : subjectList.first()); | |
| 110 | 111 | } |
| 111 | 112 | |
| 112 | 113 | void store(QDataStream &stream) const | ... | ... |