classify.cpp
5.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
* Copyright 2012 The MITRE Corporation *
* *
* Licensed under the Apache License, Version 2.0 (the "License"); *
* you may not use this file except in compliance with the License. *
* You may obtain a copy of the License at *
* *
* http://www.apache.org/licenses/LICENSE-2.0 *
* *
* Unless required by applicable law or agreed to in writing, software *
* distributed under the License is distributed on an "AS IS" BASIS, *
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
* See the License for the specific language governing permissions and *
* limitations under the License. *
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
#include <QDebug>
#include <QHash>
#include <openbr_plugin.h>
#include "classify.h"
#include "core/qtutils.h"
// Per-label tallies accumulated by EvalClassification; every count starts at zero.
struct Counter
{
    int truePositive;
    int falsePositive;
    int falseNegative;

    Counter()
        : truePositive(0), falsePositive(0), falseNegative(0)
    {
    }
};
void br::EvalClassification(const QString &predictedInput, const QString &truthInput)
{
qDebug("Evaluating classification of %s against %s", qPrintable(predictedInput), qPrintable(truthInput));
TemplateList predicted(TemplateList::fromGallery(predictedInput));
TemplateList truth(TemplateList::fromGallery(truthInput));
if (predicted.size() != truth.size()) qFatal("Input size mismatch.");
QHash<int, Counter> counters;
for (int i=0; i<predicted.size(); i++) {
if (predicted[i].file.name != truth[i].file.name)
qFatal("Input order mismatch.");
const int trueLabel = truth[i].file.label();
const int predictedLabel = predicted[i].file.label();
if (trueLabel == predictedLabel) {
counters[trueLabel].truePositive++;
} else {
counters[trueLabel].falseNegative++;
counters[predictedLabel].falsePositive++;
}
}
QSharedPointer<Output> output(Output::make("", FileList() << "Label" << "Count" << "Precision" << "Recall" << "F-score", FileList(counters.size())));
int tpc = 0;
int fnc = 0;
for (int i=0; i<counters.size(); i++) {
int trueLabel = counters.keys()[i];
const Counter &counter = counters[trueLabel];
tpc += counter.truePositive;
fnc += counter.falseNegative;
const int count = counter.truePositive + counter.falseNegative;
const float precision = counter.truePositive / (float)(counter.truePositive + counter.falsePositive);
const float recall = counter.truePositive / (float)(counter.truePositive + counter.falseNegative);
const float fscore = 2 * precision * recall / (precision + recall);
output->setRelative(trueLabel, i, 0);
output->setRelative(count, i, 1);
output->setRelative(precision, i, 2);
output->setRelative(recall, i, 3);
output->setRelative(fscore, i, 4);
}
qDebug("Overall Accuracy = %f", (float)tpc / (float)(tpc + fnc));
}
void br::EvalRegression(const QString &predictedInput, const QString &truthInput)
{
qDebug("Evaluating regression of %s against %s", qPrintable(predictedInput), qPrintable(truthInput));
const TemplateList predicted(TemplateList::fromGallery(predictedInput));
const TemplateList truth(TemplateList::fromGallery(truthInput));
if (predicted.size() != truth.size()) qFatal("Input size mismatch.");
float rmsError = 0;
QStringList truthValues, predictedValues;
for (int i=0; i<predicted.size(); i++) {
if (predicted[i].file.name != truth[i].file.name)
qFatal("Input order mismatch.");
rmsError += pow(predicted[i].file.label()-truth[i].file.label(), 2.f);
truthValues.append(QString::number(truth[i].file.label()));
predictedValues.append(QString::number(predicted[i].file.label()));
}
QStringList rSource;
rSource << "# Load libraries" << "library(ggplot2)" << "" << "# Set Data"
<< "Actual <- c(" + truthValues.join(",") + ")"
<< "Predicted <- c(" + predictedValues.join(",") + ")"
<< "data <- data.frame(Actual, Predicted)"
<< "" << "# Construct Plot" << "pdf(\"EvalRegression.pdf\")"
<< "print(qplot(Actual, Predicted, data=data, geom=\"jitter\", alpha=I(2/3)) + geom_abline(intercept=0, slope=1, color=\"forestgreen\", size=I(1)) + geom_smooth(size=I(1), color=\"mediumblue\") + theme_bw())"
<< "print(qplot(Actual, Predicted-Actual, data=data, geom=\"jitter\", alpha=I(2/3)) + geom_abline(intercept=0, slope=0, color=\"forestgreen\", size=I(1)) + geom_smooth(size=I(1), color=\"mediumblue\") + theme_bw())"
<< "dev.off()";
QString rFile = "EvalRegression.R";
QtUtils::writeFile(rFile, rSource);
bool success = QtUtils::runRScript(rFile);
if (success) QtUtils::showFile("EvalRegression.pdf");
qDebug("RMS Error = %f", sqrt(rmsError/predicted.size()));
}