/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * Copyright 2012 The MITRE Corporation * * * * Licensed under the Apache License, Version 2.0 (the "License"); * * you may not use this file except in compliance with the License. * * You may obtain a copy of the License at * * * * http://www.apache.org/licenses/LICENSE-2.0 * * * * Unless required by applicable law or agreed to in writing, software * * distributed under the License is distributed on an "AS IS" BASIS, * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * * See the License for the specific language governing permissions and * * limitations under the License. * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ #include #include #include #include "classify.h" #include "core/qtutils.h" /* NOTE(review): this text appears to have lost its line breaks and everything that was written between angle brackets (the targets of the first three include directives, the QHash template arguments, and the condition and body of both for loops below). Restore the file from version control before making code changes; the comments added here describe only what is still visible. */ // Helper struct for statistics accumulation /* Counter holds three integer tallies (truePositive, falsePositive, falseNegative); its constructor sets all three to 0. */ struct Counter { int truePositive, falsePositive, falseNegative; Counter() { truePositive = 0; falsePositive = 0; falseNegative = 0; } }; /* EvalClassification: logs both input names with qDebug, loads each input through TemplateList::fromGallery, and calls qFatal("Input size mismatch.") when the two lists differ in size. It then builds an Output whose first FileList carries the headers Label, Count, Precision, Recall and F-score and whose second FileList is sized by counters.size(), writes trueLabel, count, precision, recall and fscore for index i into positions 0 through 4 via output->setRelative, and finally logs "Overall Accuracy" as tpc / (tpc + fnc) computed in float. How counters, tpc, fnc and the per-index statistics are filled in is not visible in this text. NOTE(review): tpc and fnc both start at 0, so if neither is incremented the accuracy expression is 0/0 in float - confirm whether empty input needs a guard. */ void br::EvalClassification(const QString &predictedInput, const QString &truthInput) { qDebug("Evaluating classification of %s against %s", qPrintable(predictedInput), qPrintable(truthInput)); TemplateList predicted(TemplateList::fromGallery(predictedInput)); TemplateList truth(TemplateList::fromGallery(truthInput)); if (predicted.size() != truth.size()) qFatal("Input size mismatch."); QHash counters; for (int i=0; i output(Output::make("", FileList() << "Label" << "Count" << "Precision" << "Recall" << "F-score", FileList(counters.size()))); int tpc = 0; int fnc = 0; for (int i=0; isetRelative(trueLabel, i, 0); output->setRelative(count, i, 1); output->setRelative(precision, i, 2); output->setRelative(recall, i, 3); output->setRelative(fscore, i, 4); } qDebug("Overall Accuracy = %f", (float)tpc / (float)(tpc + fnc)); } void br::EvalRegression(const QString &predictedInput, const QString &truthInput) { qDebug("Evaluating regression of %s against %s", 
qPrintable(predictedInput), qPrintable(truthInput)); const TemplateList predicted(TemplateList::fromGallery(predictedInput)); const TemplateList truth(TemplateList::fromGallery(truthInput)); if (predicted.size() != truth.size()) qFatal("Input size mismatch."); float rmsError = 0; QStringList truthValues, predictedValues; for (int i=0; i