Commit da345e9d0139a8693792404f2dfecb6df892b774
Merge pull request #53 from biometrics/local_remap
Changes to subject/label
Showing
25 changed files
with
195 additions
and
421 deletions
app/examples/age_estimation.cpp
| @@ -29,7 +29,7 @@ | @@ -29,7 +29,7 @@ | ||
| 29 | 29 | ||
| 30 | static void printTemplate(const br::Template &t) | 30 | static void printTemplate(const br::Template &t) |
| 31 | { | 31 | { |
| 32 | - printf("%s age: %d\n", qPrintable(t.file.fileName()), int(t.file.label())); | 32 | + printf("%s age: %d\n", qPrintable(t.file.fileName()), int(t.file.get<float>("Subject"))); |
| 33 | } | 33 | } |
| 34 | 34 | ||
| 35 | int main(int argc, char *argv[]) | 35 | int main(int argc, char *argv[]) |
app/examples/gender_estimation.cpp
| @@ -29,7 +29,7 @@ | @@ -29,7 +29,7 @@ | ||
| 29 | 29 | ||
| 30 | static void printTemplate(const br::Template &t) | 30 | static void printTemplate(const br::Template &t) |
| 31 | { | 31 | { |
| 32 | - printf("%s gender: %s\n", qPrintable(t.file.fileName()), t.file.label() == 1 ? "Female" : "Male"); | 32 | + printf("%s gender: %s\n", qPrintable(t.file.fileName()), qPrintable(t.file.get<QString>("Subject"))); |
| 33 | } | 33 | } |
| 34 | 34 | ||
| 35 | int main(int argc, char *argv[]) | 35 | int main(int argc, char *argv[]) |
openbr/core/bee.cpp
| @@ -96,7 +96,7 @@ void BEE::writeSigset(const QString &sigset, const br::FileList &files, bool ign | @@ -96,7 +96,7 @@ void BEE::writeSigset(const QString &sigset, const br::FileList &files, bool ign | ||
| 96 | if ((key == "Index") || (key == "Subject")) continue; | 96 | if ((key == "Index") || (key == "Subject")) continue; |
| 97 | metadata.append(key+"=\""+QtUtils::toString(file.value(key))+"\""); | 97 | metadata.append(key+"=\""+QtUtils::toString(file.value(key))+"\""); |
| 98 | } | 98 | } |
| 99 | - lines.append("\t<biometric-signature name=\"" + file.subject() +"\">"); | 99 | + lines.append("\t<biometric-signature name=\"" + file.get<QString>("Subject") +"\">"); |
| 100 | lines.append("\t\t<presentation file-name=\"" + file.name + "\" " + metadata.join(" ") + "/>"); | 100 | lines.append("\t\t<presentation file-name=\"" + file.name + "\" " + metadata.join(" ") + "/>"); |
| 101 | lines.append("\t</biometric-signature>"); | 101 | lines.append("\t</biometric-signature>"); |
| 102 | } | 102 | } |
| @@ -260,26 +260,28 @@ void BEE::makeMask(const QString &targetInput, const QString &queryInput, const | @@ -260,26 +260,28 @@ void BEE::makeMask(const QString &targetInput, const QString &queryInput, const | ||
| 260 | 260 | ||
| 261 | cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries, int partition) | 261 | cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries, int partition) |
| 262 | { | 262 | { |
| 263 | - QList<float> targetLabels = targets.labels(); | ||
| 264 | - QList<float> queryLabels = queries.labels(); | 263 | + // Would like to use indexProperty for this, but didn't make a version of that for Filelist yet |
| 264 | + // -cao | ||
| 265 | + QList<QString> targetLabels = targets.get<QString>("Subject", "-1"); | ||
| 266 | + QList<QString> queryLabels = queries.get<QString>("Subject", "-1"); | ||
| 265 | QList<int> targetPartitions = targets.crossValidationPartitions(); | 267 | QList<int> targetPartitions = targets.crossValidationPartitions(); |
| 266 | QList<int> queryPartitions = queries.crossValidationPartitions(); | 268 | QList<int> queryPartitions = queries.crossValidationPartitions(); |
| 267 | 269 | ||
| 268 | Mat mask(queries.size(), targets.size(), CV_8UC1); | 270 | Mat mask(queries.size(), targets.size(), CV_8UC1); |
| 269 | for (int i=0; i<queries.size(); i++) { | 271 | for (int i=0; i<queries.size(); i++) { |
| 270 | const QString &fileA = queries[i]; | 272 | const QString &fileA = queries[i]; |
| 271 | - const int labelA = queryLabels[i]; | 273 | + const QString labelA = queryLabels[i]; |
| 272 | const int partitionA = queryPartitions[i]; | 274 | const int partitionA = queryPartitions[i]; |
| 273 | 275 | ||
| 274 | for (int j=0; j<targets.size(); j++) { | 276 | for (int j=0; j<targets.size(); j++) { |
| 275 | const QString &fileB = targets[j]; | 277 | const QString &fileB = targets[j]; |
| 276 | - const int labelB = targetLabels[j]; | 278 | + const QString labelB = targetLabels[j]; |
| 277 | const int partitionB = targetPartitions[j]; | 279 | const int partitionB = targetPartitions[j]; |
| 278 | 280 | ||
| 279 | Mask_t val; | 281 | Mask_t val; |
| 280 | if (fileA == fileB) val = DontCare; | 282 | if (fileA == fileB) val = DontCare; |
| 281 | - else if (labelA == -1) val = DontCare; | ||
| 282 | - else if (labelB == -1) val = DontCare; | 283 | + else if (labelA == "-1") val = DontCare; |
| 284 | + else if (labelB == "-1") val = DontCare; | ||
| 283 | else if (partitionA != partition) val = DontCare; | 285 | else if (partitionA != partition) val = DontCare; |
| 284 | else if (partitionB == -1) val = NonMatch; | 286 | else if (partitionB == -1) val = NonMatch; |
| 285 | else if (partitionB != partition) val = DontCare; | 287 | else if (partitionB != partition) val = DontCare; |
openbr/core/classify.cpp
| @@ -45,8 +45,8 @@ void br::EvalClassification(const QString &predictedInput, const QString &truthI | @@ -45,8 +45,8 @@ void br::EvalClassification(const QString &predictedInput, const QString &truthI | ||
| 45 | qFatal("Input order mismatch."); | 45 | qFatal("Input order mismatch."); |
| 46 | 46 | ||
| 47 | // Typically these lists will be of length one, but this generalization allows measuring multi-class labeling accuracy. | 47 | // Typically these lists will be of length one, but this generalization allows measuring multi-class labeling accuracy. |
| 48 | - QString predictedSubject = predicted[i].file.subject(); | ||
| 49 | - QString trueSubject = truth[i].file.subject(); | 48 | + QString predictedSubject = predicted[i].file.get<QString>("Subject"); |
| 49 | + QString trueSubject = truth[i].file.get<QString>("Subject"); | ||
| 50 | 50 | ||
| 51 | QStringList predictedSubjects(predictedSubject); | 51 | QStringList predictedSubjects(predictedSubject); |
| 52 | QStringList trueSubjects(trueSubject); | 52 | QStringList trueSubjects(trueSubject); |
| @@ -66,11 +66,12 @@ void br::EvalClassification(const QString &predictedInput, const QString &truthI | @@ -66,11 +66,12 @@ void br::EvalClassification(const QString &predictedInput, const QString &truthI | ||
| 66 | counters[subject].falsePositive += 1.f / predictedSubjects.size(); | 66 | counters[subject].falsePositive += 1.f / predictedSubjects.size(); |
| 67 | } | 67 | } |
| 68 | 68 | ||
| 69 | - QSharedPointer<Output> output(Output::make("", FileList() << "Subject" << "Count" << "Precision" << "Recall" << "F-score", FileList(counters.size()))); | 69 | + const QStringList keys = counters.keys(); |
| 70 | + QSharedPointer<Output> output(Output::make("", FileList() << "Count" << "Precision" << "Recall" << "F-score", FileList(keys))); | ||
| 70 | 71 | ||
| 71 | int tpc = 0; | 72 | int tpc = 0; |
| 72 | int fnc = 0; | 73 | int fnc = 0; |
| 73 | - const QStringList keys = counters.keys(); | 74 | + |
| 74 | for (int i=0; i<counters.size(); i++) { | 75 | for (int i=0; i<counters.size(); i++) { |
| 75 | const QString &subject = keys[i]; | 76 | const QString &subject = keys[i]; |
| 76 | const Counter &counter = counters[subject]; | 77 | const Counter &counter = counters[subject]; |
| @@ -80,11 +81,10 @@ void br::EvalClassification(const QString &predictedInput, const QString &truthI | @@ -80,11 +81,10 @@ void br::EvalClassification(const QString &predictedInput, const QString &truthI | ||
| 80 | const float precision = counter.truePositive / (float)(counter.truePositive + counter.falsePositive); | 81 | const float precision = counter.truePositive / (float)(counter.truePositive + counter.falsePositive); |
| 81 | const float recall = counter.truePositive / (float)(counter.truePositive + counter.falseNegative); | 82 | const float recall = counter.truePositive / (float)(counter.truePositive + counter.falseNegative); |
| 82 | const float fscore = 2 * precision * recall / (precision + recall); | 83 | const float fscore = 2 * precision * recall / (precision + recall); |
| 83 | - output->setRelative(File("", subject).label(), i, 0); | ||
| 84 | - output->setRelative(count, i, 1); | ||
| 85 | - output->setRelative(precision, i, 2); | ||
| 86 | - output->setRelative(recall, i, 3); | ||
| 87 | - output->setRelative(fscore, i, 4); | 84 | + output->setRelative(count, i, 0); |
| 85 | + output->setRelative(precision, i, 1); | ||
| 86 | + output->setRelative(recall, i, 2); | ||
| 87 | + output->setRelative(fscore, i, 3); | ||
| 88 | } | 88 | } |
| 89 | 89 | ||
| 90 | qDebug("Overall Accuracy = %f", (float)tpc / (float)(tpc + fnc)); | 90 | qDebug("Overall Accuracy = %f", (float)tpc / (float)(tpc + fnc)); |
| @@ -103,9 +103,9 @@ void br::EvalRegression(const QString &predictedInput, const QString &truthInput | @@ -103,9 +103,9 @@ void br::EvalRegression(const QString &predictedInput, const QString &truthInput | ||
| 103 | for (int i=0; i<predicted.size(); i++) { | 103 | for (int i=0; i<predicted.size(); i++) { |
| 104 | if (predicted[i].file.name != truth[i].file.name) | 104 | if (predicted[i].file.name != truth[i].file.name) |
| 105 | qFatal("Input order mismatch."); | 105 | qFatal("Input order mismatch."); |
| 106 | - rmsError += pow(predicted[i].file.label()-truth[i].file.label(), 2.f); | ||
| 107 | - truthValues.append(QString::number(truth[i].file.label())); | ||
| 108 | - predictedValues.append(QString::number(predicted[i].file.label())); | 106 | + rmsError += pow(predicted[i].file.get<float>("Subject")-truth[i].file.get<float>("Subject"), 2.f); |
| 107 | + truthValues.append(QString::number(truth[i].file.get<float>("Subject"))); | ||
| 108 | + predictedValues.append(QString::number(predicted[i].file.get<float>("Subject"))); | ||
| 109 | } | 109 | } |
| 110 | 110 | ||
| 111 | QStringList rSource; | 111 | QStringList rSource; |
openbr/core/cluster.cpp
| @@ -278,7 +278,9 @@ void br::EvalClustering(const QString &csv, const QString &input) | @@ -278,7 +278,9 @@ void br::EvalClustering(const QString &csv, const QString &input) | ||
| 278 | { | 278 | { |
| 279 | qDebug("Evaluating %s against %s", qPrintable(csv), qPrintable(input)); | 279 | qDebug("Evaluating %s against %s", qPrintable(csv), qPrintable(input)); |
| 280 | 280 | ||
| 281 | - QList<float> labels = TemplateList::fromGallery(input).files().labels(); | 281 | + // We assume clustering algorithms store assigned cluster labels as integers (since the clusters are |
| 282 | + // not named). | ||
| 283 | + QList<int> labels = TemplateList::fromGallery(input).files().get<int>("Subject"); | ||
| 282 | 284 | ||
| 283 | QHash<int, int> labelToIndex; | 285 | QHash<int, int> labelToIndex; |
| 284 | int nClusters = 0; | 286 | int nClusters = 0; |
openbr/core/core.cpp
| @@ -78,7 +78,6 @@ struct AlgorithmCore | @@ -78,7 +78,6 @@ struct AlgorithmCore | ||
| 78 | const bool hasComparer = !distance.isNull(); | 78 | const bool hasComparer = !distance.isNull(); |
| 79 | out << hasComparer; | 79 | out << hasComparer; |
| 80 | if (hasComparer) distance->store(out); | 80 | if (hasComparer) distance->store(out); |
| 81 | - out << Globals->subjects; | ||
| 82 | 81 | ||
| 83 | // Compress and save to file | 82 | // Compress and save to file |
| 84 | QtUtils::writeFile(model, data, -1); | 83 | QtUtils::writeFile(model, data, -1); |
| @@ -98,7 +97,6 @@ struct AlgorithmCore | @@ -98,7 +97,6 @@ struct AlgorithmCore | ||
| 98 | transform->load(in); | 97 | transform->load(in); |
| 99 | bool hasDistance; in >> hasDistance; | 98 | bool hasDistance; in >> hasDistance; |
| 100 | if (hasDistance) distance->load(in); | 99 | if (hasDistance) distance->load(in); |
| 101 | - in >> Globals->subjects; | ||
| 102 | } | 100 | } |
| 103 | 101 | ||
| 104 | File getMemoryGallery(const File &file) const | 102 | File getMemoryGallery(const File &file) const |
openbr/core/opencvutils.cpp
| @@ -119,6 +119,17 @@ Mat OpenCVUtils::toMat(const QList<float> &src, int rows) | @@ -119,6 +119,17 @@ Mat OpenCVUtils::toMat(const QList<float> &src, int rows) | ||
| 119 | return dst; | 119 | return dst; |
| 120 | } | 120 | } |
| 121 | 121 | ||
| 122 | +Mat OpenCVUtils::toMat(const QList<int> &src, int rows) | ||
| 123 | +{ | ||
| 124 | + if (rows == -1) rows = src.size(); | ||
| 125 | + int columns = src.isEmpty() ? 0 : src.size() / rows; | ||
| 126 | + if (rows*columns != src.size()) qFatal("Invalid matrix size."); | ||
| 127 | + Mat dst(rows, columns, CV_32FC1); | ||
| 128 | + for (int i=0; i<src.size(); i++) | ||
| 129 | + dst.at<float>(i/columns,i%columns) = src[i]; | ||
| 130 | + return dst; | ||
| 131 | +} | ||
| 132 | + | ||
| 122 | Mat OpenCVUtils::toMat(const QList<Mat> &src) | 133 | Mat OpenCVUtils::toMat(const QList<Mat> &src) |
| 123 | { | 134 | { |
| 124 | if (src.isEmpty()) return Mat(); | 135 | if (src.isEmpty()) return Mat(); |
openbr/core/opencvutils.h
| @@ -35,6 +35,8 @@ namespace OpenCVUtils | @@ -35,6 +35,8 @@ namespace OpenCVUtils | ||
| 35 | 35 | ||
| 36 | // To image | 36 | // To image |
| 37 | cv::Mat toMat(const QList<float> &src, int rows = -1); | 37 | cv::Mat toMat(const QList<float> &src, int rows = -1); |
| 38 | + cv::Mat toMat(const QList<int> &src, int rows = -1); | ||
| 39 | + | ||
| 38 | cv::Mat toMat(const QList<cv::Mat> &src); // Data organized one matrix per row | 40 | cv::Mat toMat(const QList<cv::Mat> &src); // Data organized one matrix per row |
| 39 | cv::Mat toMatByRow(const QList<cv::Mat> &src); // Data organized one row per row | 41 | cv::Mat toMatByRow(const QList<cv::Mat> &src); // Data organized one row per row |
| 40 | 42 |
openbr/frvt2012.cpp
| @@ -132,7 +132,7 @@ int32_t SdkEstimator::estimate_age(const ONEFACE &input_face, int32_t &age) | @@ -132,7 +132,7 @@ int32_t SdkEstimator::estimate_age(const ONEFACE &input_face, int32_t &age) | ||
| 132 | TemplateList templates; | 132 | TemplateList templates; |
| 133 | templates.append(templateFromONEFACE(input_face)); | 133 | templates.append(templateFromONEFACE(input_face)); |
| 134 | templates >> *frvt2012_age_transform.data(); | 134 | templates >> *frvt2012_age_transform.data(); |
| 135 | - age = templates.first().file.label(); | 135 | + age = templates.first().file.get<float>("Subject"); |
| 136 | return templates.first().file.failed() ? 4 : 0; | 136 | return templates.first().file.failed() ? 4 : 0; |
| 137 | } | 137 | } |
| 138 | 138 | ||
| @@ -141,6 +141,6 @@ int32_t SdkEstimator::estimate_gender(const ONEFACE &input_face, int8_t &gender, | @@ -141,6 +141,6 @@ int32_t SdkEstimator::estimate_gender(const ONEFACE &input_face, int8_t &gender, | ||
| 141 | TemplateList templates; | 141 | TemplateList templates; |
| 142 | templates.append(templateFromONEFACE(input_face)); | 142 | templates.append(templateFromONEFACE(input_face)); |
| 143 | templates >> *frvt2012_gender_transform.data(); | 143 | templates >> *frvt2012_gender_transform.data(); |
| 144 | - mf = gender = templates.first().file.label(); | 144 | + mf = gender = templates.first().file.get<QString>("Subject") == "Male" ? 0 : 1; |
| 145 | return templates.first().file.failed() ? 4 : 0; | 145 | return templates.first().file.failed() ? 4 : 0; |
| 146 | } | 146 | } |
openbr/gui/classifier.cpp
| @@ -44,13 +44,13 @@ void Classifier::_classify(File file) | @@ -44,13 +44,13 @@ void Classifier::_classify(File file) | ||
| 44 | 44 | ||
| 45 | if (algorithm == "GenderClassification") { | 45 | if (algorithm == "GenderClassification") { |
| 46 | key = "Gender"; | 46 | key = "Gender"; |
| 47 | - value = (f.label() == 0 ? "Male" : "Female"); | 47 | + value = (f.get<QString>("Subject")); |
| 48 | } else if (algorithm == "AgeRegression") { | 48 | } else if (algorithm == "AgeRegression") { |
| 49 | key = "Age"; | 49 | key = "Age"; |
| 50 | - value = QString::number(int(f.label()+0.5)) + " Years"; | 50 | + value = QString::number(int(f.get<float>("Subject")+0.5)) + " Years"; |
| 51 | } else { | 51 | } else { |
| 52 | key = algorithm; | 52 | key = algorithm; |
| 53 | - value = QString::number(f.label()); | 53 | + value = f.get<QString>("Subject"); |
| 54 | } | 54 | } |
| 55 | break; | 55 | break; |
| 56 | } | 56 | } |
openbr/openbr_plugin.cpp
| @@ -167,34 +167,6 @@ bool File::getBool(const QString &key, bool defaultValue) const | @@ -167,34 +167,6 @@ bool File::getBool(const QString &key, bool defaultValue) const | ||
| 167 | return variant.value<bool>(); | 167 | return variant.value<bool>(); |
| 168 | } | 168 | } |
| 169 | 169 | ||
| 170 | -QString File::subject() const | ||
| 171 | -{ | ||
| 172 | - const QVariant l = m_metadata.value("Label"); | ||
| 173 | - if (!l.isNull()) return Globals->subjects.key(l.toFloat(), l.toString()); | ||
| 174 | - return m_metadata.value("Subject").toString(); | ||
| 175 | -} | ||
| 176 | - | ||
| 177 | -float File::label() const | ||
| 178 | -{ | ||
| 179 | - const QVariant l = m_metadata.value("Label"); | ||
| 180 | - if (!l.isNull()) return l.toFloat(); | ||
| 181 | - | ||
| 182 | - const QVariant s = m_metadata.value("Subject"); | ||
| 183 | - if (s.isNull()) return -1; | ||
| 184 | - | ||
| 185 | - const QString subject = s.toString(); | ||
| 186 | - | ||
| 187 | - bool is_num = false; | ||
| 188 | - float num = subject.toFloat(&is_num); | ||
| 189 | - if (is_num) return num; | ||
| 190 | - | ||
| 191 | - static QMutex mutex; | ||
| 192 | - QMutexLocker mutexLocker(&mutex); | ||
| 193 | - if (!Globals->subjects.contains(subject)) | ||
| 194 | - Globals->subjects.insert(subject, Globals->subjects.size()); | ||
| 195 | - return Globals->subjects.value(subject); | ||
| 196 | -} | ||
| 197 | - | ||
| 198 | QList<QPointF> File::namedPoints() const | 170 | QList<QPointF> File::namedPoints() const |
| 199 | { | 171 | { |
| 200 | QList<QPointF> landmarks; | 172 | QList<QPointF> landmarks; |
| @@ -360,14 +332,6 @@ void FileList::sort(const QString& key) | @@ -360,14 +332,6 @@ void FileList::sort(const QString& key) | ||
| 360 | *this = sortedList; | 332 | *this = sortedList; |
| 361 | } | 333 | } |
| 362 | 334 | ||
| 363 | -QList<float> FileList::labels() const | ||
| 364 | -{ | ||
| 365 | - QList<float> labels; labels.reserve(size()); | ||
| 366 | - foreach (const File &f, *this) | ||
| 367 | - labels.append(f.label()); | ||
| 368 | - return labels; | ||
| 369 | -} | ||
| 370 | - | ||
| 371 | QList<int> FileList::crossValidationPartitions() const | 335 | QList<int> FileList::crossValidationPartitions() const |
| 372 | { | 336 | { |
| 373 | QList<int> crossValidationPartitions; crossValidationPartitions.reserve(size()); | 337 | QList<int> crossValidationPartitions; crossValidationPartitions.reserve(size()); |
| @@ -449,7 +413,7 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) | @@ -449,7 +413,7 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) | ||
| 449 | 413 | ||
| 450 | newTemplates[i].file.set("Partition", -1); | 414 | newTemplates[i].file.set("Partition", -1); |
| 451 | } else { | 415 | } else { |
| 452 | - const QByteArray md5 = QCryptographicHash::hash(newTemplates[i].file.subject().toLatin1(), QCryptographicHash::Md5); | 416 | + const QByteArray md5 = QCryptographicHash::hash(newTemplates[i].file.get<QString>("Subject").toLatin1(), QCryptographicHash::Md5); |
| 453 | // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow | 417 | // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow |
| 454 | newTemplates[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); | 418 | newTemplates[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); |
| 455 | } | 419 | } |
| @@ -469,11 +433,13 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) | @@ -469,11 +433,13 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) | ||
| 469 | return templates; | 433 | return templates; |
| 470 | } | 434 | } |
| 471 | 435 | ||
| 472 | -TemplateList TemplateList::relabel(const TemplateList &tl) | 436 | +// indexes some property, assigns an integer id to each unique value of propName |
| 437 | +// stores the index values in "Label" of the output template list | ||
| 438 | +TemplateList TemplateList::relabel(const TemplateList &tl, const QString & propName) | ||
| 473 | { | 439 | { |
| 474 | - const QList<int> originalLabels = tl.labels<int>(); | ||
| 475 | - QHash<int,int> labelTable; | ||
| 476 | - foreach (int label, originalLabels) | 440 | + const QList<QString> originalLabels = tl.get<QString>(propName); |
| 441 | + QHash<QString,int> labelTable; | ||
| 442 | + foreach (const QString & label, originalLabels) | ||
| 477 | if (!labelTable.contains(label)) | 443 | if (!labelTable.contains(label)) |
| 478 | labelTable.insert(label, labelTable.size()); | 444 | labelTable.insert(label, labelTable.size()); |
| 479 | 445 | ||
| @@ -483,6 +449,52 @@ TemplateList TemplateList::relabel(const TemplateList &tl) | @@ -483,6 +449,52 @@ TemplateList TemplateList::relabel(const TemplateList &tl) | ||
| 483 | return result; | 449 | return result; |
| 484 | } | 450 | } |
| 485 | 451 | ||
| 452 | +QList<int> TemplateList::indexProperty(const QString & propName, QHash<QString, int> * valueMap,QHash<int, QVariant> * reverseLookup) const | ||
| 453 | +{ | ||
| 454 | + QHash<QString, int> dummyForwards; | ||
| 455 | + QHash<int, QVariant> dummyBackwards; | ||
| 456 | + | ||
| 457 | + if (!valueMap) valueMap = &dummyForwards; | ||
| 458 | + if (!reverseLookup) reverseLookup = &dummyBackwards; | ||
| 459 | + | ||
| 460 | + return indexProperty(propName, *valueMap, *reverseLookup); | ||
| 461 | +} | ||
| 462 | + | ||
| 463 | +QList<int> TemplateList::indexProperty(const QString & propName, QHash<QString, int> & valueMap, QHash<int, QVariant> & reverseLookup) const | ||
| 464 | +{ | ||
| 465 | + valueMap.clear(); | ||
| 466 | + reverseLookup.clear(); | ||
| 467 | + | ||
| 468 | + const QList<QVariant> originalLabels = values(propName); | ||
| 469 | + foreach (const QVariant & label, originalLabels) { | ||
| 470 | + QString labelString = label.toString(); | ||
| 471 | + if (!valueMap.contains(labelString)) { | ||
| 472 | + reverseLookup.insert(valueMap.size(), label); | ||
| 473 | + valueMap.insert(labelString, valueMap.size()); | ||
| 474 | + } | ||
| 475 | + } | ||
| 476 | + | ||
| 477 | + QList<int> result; | ||
| 478 | + for (int i=0; i<originalLabels.size(); i++) | ||
| 479 | + result.append(valueMap[originalLabels[i].toString()]); | ||
| 480 | + | ||
| 481 | + return result; | ||
| 482 | +} | ||
| 483 | + | ||
| 484 | +// uses -1 for missing values | ||
| 485 | +QList<int> TemplateList::applyIndex(const QString & propName, const QHash<QString, int> & valueMap) const | ||
| 486 | +{ | ||
| 487 | + const QList<QString> originalLabels = get<QString>(propName); | ||
| 488 | + | ||
| 489 | + QList<int> result; | ||
| 490 | + for (int i=0; i<originalLabels.size(); i++) { | ||
| 491 | + if (!valueMap.contains(originalLabels[i])) result.append(-1); | ||
| 492 | + else result.append(valueMap[originalLabels[i]]); | ||
| 493 | + } | ||
| 494 | + | ||
| 495 | + return result; | ||
| 496 | +} | ||
| 497 | + | ||
| 486 | /* Object - public methods */ | 498 | /* Object - public methods */ |
| 487 | QStringList Object::parameters() const | 499 | QStringList Object::parameters() const |
| 488 | { | 500 | { |
| @@ -1021,10 +1033,6 @@ MatrixOutput *MatrixOutput::make(const FileList &targetFiles, const FileList &qu | @@ -1021,10 +1033,6 @@ MatrixOutput *MatrixOutput::make(const FileList &targetFiles, const FileList &qu | ||
| 1021 | /* MatrixOutput - protected methods */ | 1033 | /* MatrixOutput - protected methods */ |
| 1022 | QString MatrixOutput::toString(int row, int column) const | 1034 | QString MatrixOutput::toString(int row, int column) const |
| 1023 | { | 1035 | { |
| 1024 | - if (targetFiles[column] == "Subject") { | ||
| 1025 | - const int label = data.at<float>(row,column); | ||
| 1026 | - return Globals->subjects.key(label, QString::number(label)); | ||
| 1027 | - } | ||
| 1028 | return QString::number(data.at<float>(row,column)); | 1036 | return QString::number(data.at<float>(row,column)); |
| 1029 | } | 1037 | } |
| 1030 | 1038 |
openbr/openbr_plugin.h
| @@ -254,8 +254,6 @@ struct BR_EXPORT File | @@ -254,8 +254,6 @@ struct BR_EXPORT File | ||
| 254 | return variant.value<T>(); | 254 | return variant.value<T>(); |
| 255 | } | 255 | } |
| 256 | 256 | ||
| 257 | - QString subject() const; /*!< \brief Looks up the subject from the file's label. */ | ||
| 258 | - float label() const; /*!< \brief Convenience function for retrieving the file's \c Label. */ | ||
| 259 | inline bool failed() const { return getBool("FTE") || getBool("FTO"); } /*!< \brief Returns \c true if the file failed to open or enroll, \c false otherwise. */ | 257 | inline bool failed() const { return getBool("FTE") || getBool("FTO"); } /*!< \brief Returns \c true if the file failed to open or enroll, \c false otherwise. */ |
| 260 | 258 | ||
| 261 | QList<QPointF> namedPoints() const; /*!< \brief Returns points convertible from metadata keys. */ | 259 | QList<QPointF> namedPoints() const; /*!< \brief Returns points convertible from metadata keys. */ |
| @@ -299,7 +297,24 @@ struct BR_EXPORT FileList : public QList<File> | @@ -299,7 +297,24 @@ struct BR_EXPORT FileList : public QList<File> | ||
| 299 | QStringList flat() const; /*!< \brief Returns br::File::flat() for each file in the list. */ | 297 | QStringList flat() const; /*!< \brief Returns br::File::flat() for each file in the list. */ |
| 300 | QStringList names() const; /*!< \brief Returns #br::File::name for each file in the list. */ | 298 | QStringList names() const; /*!< \brief Returns #br::File::name for each file in the list. */ |
| 301 | void sort(const QString& key); /*!< \brief Sort the list based on metadata. */ | 299 | void sort(const QString& key); /*!< \brief Sort the list based on metadata. */ |
| 302 | - QList<float> labels() const; /*!< \brief Returns br::File::label() for each file in the list. */ | 300 | + /*!< \brief Returns values associated with the input propName for each file in the list. */ |
| 301 | + template<typename T> | ||
| 302 | + QList<T> get(const QString & propName) const | ||
| 303 | + { | ||
| 304 | + QList<T> values; values.reserve(size()); | ||
| 305 | + foreach (const File &f, *this) | ||
| 306 | + values.append(f.get<T>(propName)); | ||
| 307 | + return values; | ||
| 308 | + } | ||
| 309 | + template<typename T> | ||
| 310 | + QList<T> get(const QString & propName, T defaultValue) const | ||
| 311 | + { | ||
| 312 | + QList<T> values; values.reserve(size()); | ||
| 313 | + foreach (const File &f, *this) | ||
| 314 | + values.append(f.contains(propName) ? f.get<T>(propName) : defaultValue); | ||
| 315 | + return values; | ||
| 316 | + } | ||
| 317 | + | ||
| 303 | QList<int> crossValidationPartitions() const; /*!< \brief Returns the cross-validation partition (default=0) for each file in the list. */ | 318 | QList<int> crossValidationPartitions() const; /*!< \brief Returns the cross-validation partition (default=0) for each file in the list. */ |
| 304 | int failures() const; /*!< \brief Returns the number of files with br::File::failed(). */ | 319 | int failures() const; /*!< \brief Returns the number of files with br::File::failed(). */ |
| 305 | }; | 320 | }; |
| @@ -383,7 +398,14 @@ struct TemplateList : public QList<Template> | @@ -383,7 +398,14 @@ struct TemplateList : public QList<Template> | ||
| 383 | TemplateList(const QList<Template> &templates) : uniform(false) { append(templates); } /*!< \brief Initialize the template list from another template list. */ | 398 | TemplateList(const QList<Template> &templates) : uniform(false) { append(templates); } /*!< \brief Initialize the template list from another template list. */ |
| 384 | TemplateList(const QList<File> &files) : uniform(false) { foreach (const File &file, files) append(file); } /*!< \brief Initialize the template list from a file list. */ | 399 | TemplateList(const QList<File> &files) : uniform(false) { foreach (const File &file, files) append(file); } /*!< \brief Initialize the template list from a file list. */ |
| 385 | BR_EXPORT static TemplateList fromGallery(const File &gallery); /*!< \brief Create a template list from a br::Gallery. */ | 400 | BR_EXPORT static TemplateList fromGallery(const File &gallery); /*!< \brief Create a template list from a br::Gallery. */ |
| 386 | - BR_EXPORT static TemplateList relabel(const TemplateList &tl); /*!< \brief Ensure labels are in the range [0,numClasses-1]. */ | 401 | + |
| 402 | + /*!< \brief Ensure labels are in the range [0,numClasses-1]. */ | ||
| 403 | + BR_EXPORT static TemplateList relabel(const TemplateList & tl, const QString & propName); | ||
| 404 | + | ||
| 405 | + QList<int> indexProperty(const QString & propName, QHash<QString, int> * valueMap=NULL,QHash<int, QVariant> * reverseLookup = NULL) const; | ||
| 406 | + QList<int> indexProperty(const QString & propName, QHash<QString, int> & valueMap, QHash<int, QVariant> & reverseLookup) const; | ||
| 407 | + QList<int> applyIndex(const QString & propName, const QHash<QString, int> & valueMap) const; | ||
| 408 | + | ||
| 387 | 409 | ||
| 388 | /*! | 410 | /*! |
| 389 | * \brief Returns the total number of bytes in all the templates. | 411 | * \brief Returns the total number of bytes in all the templates. |
| @@ -457,23 +479,30 @@ struct TemplateList : public QList<Template> | @@ -457,23 +479,30 @@ struct TemplateList : public QList<Template> | ||
| 457 | /*! | 479 | /*! |
| 458 | * \brief Returns br::Template::label() for each template in the list. | 480 | * \brief Returns br::Template::label() for each template in the list. |
| 459 | */ | 481 | */ |
| 460 | - template <typename T> | ||
| 461 | - QList<T> labels() const | 482 | + template<typename T> |
| 483 | + QList<T> get(const QString & propName) const | ||
| 484 | + { | ||
| 485 | + QList<T> values; values.reserve(size()); | ||
| 486 | + foreach (const Template &t, *this) values.append(t.file.get<T>(propName)); | ||
| 487 | + return values; | ||
| 488 | + } | ||
| 489 | + QList<QVariant> values(const QString & propName) const | ||
| 462 | { | 490 | { |
| 463 | - QList<T> labels; labels.reserve(size()); | ||
| 464 | - foreach (const Template &t, *this) labels.append(t.file.label()); | ||
| 465 | - return labels; | 491 | + QList<QVariant> values; values.reserve(size()); |
| 492 | + foreach (const Template &t, *this) values.append(t.file.value(propName)); | ||
| 493 | + return values; | ||
| 466 | } | 494 | } |
| 467 | 495 | ||
| 468 | /*! | 496 | /*! |
| 469 | * \brief Returns the number of occurences for each label in the list. | 497 | * \brief Returns the number of occurences for each label in the list. |
| 470 | */ | 498 | */ |
| 471 | - QMap<int,int> labelCounts(bool excludeFailures = false) const | 499 | + template<typename T> |
| 500 | + QMap<T,int> countValues(const QString & propName, bool excludeFailures = false) const | ||
| 472 | { | 501 | { |
| 473 | - QMap<int, int> labelCounts; | 502 | + QMap<T, int> labelCounts; |
| 474 | foreach (const File &file, files()) | 503 | foreach (const File &file, files()) |
| 475 | if (!excludeFailures || !file.failed()) | 504 | if (!excludeFailures || !file.failed()) |
| 476 | - labelCounts[file.label()]++; | 505 | + labelCounts[file.get<T>(propName)]++; |
| 477 | return labelCounts; | 506 | return labelCounts; |
| 478 | } | 507 | } |
| 479 | 508 | ||
| @@ -649,7 +678,6 @@ public: | @@ -649,7 +678,6 @@ public: | ||
| 649 | BR_PROPERTY(int, crossValidate, 0) | 678 | BR_PROPERTY(int, crossValidate, 0) |
| 650 | 679 | ||
| 651 | QHash<QString,QString> abbreviations; /*!< \brief Used by br::Transform::make() to expand abbreviated algorithms into their complete definitions. */ | 680 | QHash<QString,QString> abbreviations; /*!< \brief Used by br::Transform::make() to expand abbreviated algorithms into their complete definitions. */ |
| 652 | - QHash<QString,int> subjects; /*!< \brief Used by classifiers to associate text class labels with unique integers IDs. */ | ||
| 653 | QTime startTime; /*!< \brief Used to estimate timeRemaining(). */ | 681 | QTime startTime; /*!< \brief Used to estimate timeRemaining(). */ |
| 654 | 682 | ||
| 655 | /*! | 683 | /*! |
openbr/plugins/cluster.cpp
| @@ -111,13 +111,13 @@ class KNNTransform : public Transform | @@ -111,13 +111,13 @@ class KNNTransform : public Transform | ||
| 111 | QHash<QString, float> votes; | 111 | QHash<QString, float> votes; |
| 112 | const int max = (k < 1) ? sortedScores.size() : std::min(k, sortedScores.size()); | 112 | const int max = (k < 1) ? sortedScores.size() : std::min(k, sortedScores.size()); |
| 113 | for (int j=0; j<max; j++) | 113 | for (int j=0; j<max; j++) |
| 114 | - votes[gallery[sortedScores[j].second].file.subject()] += (weighted ? sortedScores[j].first : 1); | 114 | + votes[gallery[sortedScores[j].second].file.get<QString>("Subject")] += (weighted ? sortedScores[j].first : 1); |
| 115 | subjects.append(votes.keys()[votes.values().indexOf(Common::Max(votes.values()))]); | 115 | subjects.append(votes.keys()[votes.values().indexOf(Common::Max(votes.values()))]); |
| 116 | 116 | ||
| 117 | // Remove subject from consideration | 117 | // Remove subject from consideration |
| 118 | if (subjects.size() < numSubjects) | 118 | if (subjects.size() < numSubjects) |
| 119 | for (int j=sortedScores.size()-1; j>=0; j--) | 119 | for (int j=sortedScores.size()-1; j>=0; j--) |
| 120 | - if (gallery[sortedScores[j].second].file.subject() == subjects.last()) | 120 | + if (gallery[sortedScores[j].second].file.get<QString>("Subject") == subjects.last()) |
| 121 | sortedScores.removeAt(j); | 121 | sortedScores.removeAt(j); |
| 122 | } | 122 | } |
| 123 | 123 |
openbr/plugins/eigen3.cpp
| @@ -329,7 +329,8 @@ class LDATransform : public Transform | @@ -329,7 +329,8 @@ class LDATransform : public Transform | ||
| 329 | 329 | ||
| 330 | void train(const TemplateList &_trainingSet) | 330 | void train(const TemplateList &_trainingSet) |
| 331 | { | 331 | { |
| 332 | - TemplateList trainingSet = TemplateList::relabel(_trainingSet); | 332 | + // creates "Label" |
| 333 | + TemplateList trainingSet = TemplateList::relabel(_trainingSet, "Subject"); | ||
| 333 | 334 | ||
| 334 | int instances = trainingSet.size(); | 335 | int instances = trainingSet.size(); |
| 335 | 336 | ||
| @@ -342,13 +343,13 @@ class LDATransform : public Transform | @@ -342,13 +343,13 @@ class LDATransform : public Transform | ||
| 342 | 343 | ||
| 343 | TemplateList ldaTrainingSet; | 344 | TemplateList ldaTrainingSet; |
| 344 | static_cast<Transform*>(&pca)->project(trainingSet, ldaTrainingSet); | 345 | static_cast<Transform*>(&pca)->project(trainingSet, ldaTrainingSet); |
| 345 | - ldaTrainingSet = TemplateList::relabel(ldaTrainingSet); | ||
| 346 | 346 | ||
| 347 | int dimsIn = ldaTrainingSet.first().m().rows * ldaTrainingSet.first().m().cols; | 347 | int dimsIn = ldaTrainingSet.first().m().rows * ldaTrainingSet.first().m().cols; |
| 348 | 348 | ||
| 349 | // OpenBR ensures that class values range from 0 to numClasses-1. | 349 | // OpenBR ensures that class values range from 0 to numClasses-1. |
| 350 | - QList<int> classes = trainingSet.labels<int>(); | ||
| 351 | - QMap<int, int> classCounts = trainingSet.labelCounts(); | 350 | + // Label exists because we created it earlier with relabel |
| 351 | + QList<int> classes = trainingSet.get<int>("Label"); | ||
| 352 | + QMap<int, int> classCounts = trainingSet.countValues<int>("Label"); | ||
| 352 | const int numClasses = classCounts.size(); | 353 | const int numClasses = classCounts.size(); |
| 353 | 354 | ||
| 354 | // Map Eigen into OpenCV | 355 | // Map Eigen into OpenCV |
openbr/plugins/gallery.cpp
| @@ -65,13 +65,13 @@ class arffGallery : public Gallery | @@ -65,13 +65,13 @@ class arffGallery : public Gallery | ||
| 65 | const int dimensions = t.m().rows * t.m().cols; | 65 | const int dimensions = t.m().rows * t.m().cols; |
| 66 | for (int i=0; i<dimensions; i++) | 66 | for (int i=0; i<dimensions; i++) |
| 67 | arffFile.write(qPrintable("@ATTRIBUTE v" + QString::number(i) + " REAL\n")); | 67 | arffFile.write(qPrintable("@ATTRIBUTE v" + QString::number(i) + " REAL\n")); |
| 68 | - arffFile.write(qPrintable("@ATTRIBUTE class {" + QStringList(Globals->subjects.keys()).join(',') + "}\n")); | 68 | + arffFile.write(qPrintable("@ATTRIBUTE class string\n")); |
| 69 | 69 | ||
| 70 | arffFile.write("\n@DATA\n"); | 70 | arffFile.write("\n@DATA\n"); |
| 71 | } | 71 | } |
| 72 | 72 | ||
| 73 | arffFile.write(qPrintable(OpenCVUtils::matrixToStringList(t).join(','))); | 73 | arffFile.write(qPrintable(OpenCVUtils::matrixToStringList(t).join(','))); |
| 74 | - arffFile.write(qPrintable(",'" + t.file.subject() + "'\n")); | 74 | + arffFile.write(qPrintable(",'" + t.file.get<QString>("Subject") + "'\n")); |
| 75 | } | 75 | } |
| 76 | }; | 76 | }; |
| 77 | 77 | ||
| @@ -517,11 +517,7 @@ class csvGallery : public Gallery | @@ -517,11 +517,7 @@ class csvGallery : public Gallery | ||
| 517 | 517 | ||
| 518 | static QString getCSVElement(const QString &key, const QVariant &value, bool header) | 518 | static QString getCSVElement(const QString &key, const QVariant &value, bool header) |
| 519 | { | 519 | { |
| 520 | - if ((key == "Label") && !header) { | ||
| 521 | - QString stringLabel = Globals->subjects.key(value.value<int>()); | ||
| 522 | - if (stringLabel.isEmpty()) return value.value<QString>(); | ||
| 523 | - else return stringLabel; | ||
| 524 | - } else if (value.canConvert<QString>()) { | 520 | + if (value.canConvert<QString>()) { |
| 525 | if (header) return key; | 521 | if (header) return key; |
| 526 | else return value.value<QString>(); | 522 | else return value.value<QString>(); |
| 527 | } else if (value.canConvert<QPointF>()) { | 523 | } else if (value.canConvert<QPointF>()) { |
| @@ -878,7 +874,7 @@ class statGallery : public Gallery | @@ -878,7 +874,7 @@ class statGallery : public Gallery | ||
| 878 | 874 | ||
| 879 | void write(const Template &t) | 875 | void write(const Template &t) |
| 880 | { | 876 | { |
| 881 | - subjects.insert(t.file.subject()); | 877 | + subjects.insert(t.file.get<QString>("Subject")); |
| 882 | bytes.append(t.bytes()); | 878 | bytes.append(t.bytes()); |
| 883 | } | 879 | } |
| 884 | }; | 880 | }; |
openbr/plugins/independent.cpp
| @@ -20,21 +20,23 @@ static TemplateList Downsample(const TemplateList &templates, const Transform *t | @@ -20,21 +20,23 @@ static TemplateList Downsample(const TemplateList &templates, const Transform *t | ||
| 20 | const bool atLeast = transform->instances < 0; | 20 | const bool atLeast = transform->instances < 0; |
| 21 | const int instances = abs(transform->instances); | 21 | const int instances = abs(transform->instances); |
| 22 | 22 | ||
| 23 | - QList<int> allLabels = templates.labels<int>(); | ||
| 24 | - QList<int> uniqueLabels = allLabels.toSet().toList(); | 23 | + QList<QString> allLabels = templates.get<QString>("Subject"); |
| 24 | + QList<QString> uniqueLabels = allLabels.toSet().toList(); | ||
| 25 | qSort(uniqueLabels); | 25 | qSort(uniqueLabels); |
| 26 | 26 | ||
| 27 | - QMap<int,int> counts = templates.labelCounts(instances != std::numeric_limits<int>::max()); | 27 | + QMap<QString,int> counts = templates.countValues<QString>("Subject", instances != std::numeric_limits<int>::max()); |
| 28 | + | ||
| 28 | if ((instances != std::numeric_limits<int>::max()) && (transform->classes != std::numeric_limits<int>::max())) | 29 | if ((instances != std::numeric_limits<int>::max()) && (transform->classes != std::numeric_limits<int>::max())) |
| 29 | - foreach (int label, counts.keys()) | 30 | + foreach (const QString & label, counts.keys()) |
| 30 | if (counts[label] < instances) | 31 | if (counts[label] < instances) |
| 31 | counts.remove(label); | 32 | counts.remove(label); |
| 33 | + | ||
| 32 | uniqueLabels = counts.keys(); | 34 | uniqueLabels = counts.keys(); |
| 33 | if ((transform->classes != std::numeric_limits<int>::max()) && (uniqueLabels.size() < transform->classes)) | 35 | if ((transform->classes != std::numeric_limits<int>::max()) && (uniqueLabels.size() < transform->classes)) |
| 34 | qWarning("Downsample requested %d classes but only %d are available.", transform->classes, uniqueLabels.size()); | 36 | qWarning("Downsample requested %d classes but only %d are available.", transform->classes, uniqueLabels.size()); |
| 35 | 37 | ||
| 36 | Common::seedRNG(); | 38 | Common::seedRNG(); |
| 37 | - QList<int> selectedLabels = uniqueLabels; | 39 | + QList<QString> selectedLabels = uniqueLabels; |
| 38 | if (transform->classes < uniqueLabels.size()) { | 40 | if (transform->classes < uniqueLabels.size()) { |
| 39 | std::random_shuffle(selectedLabels.begin(), selectedLabels.end()); | 41 | std::random_shuffle(selectedLabels.begin(), selectedLabels.end()); |
| 40 | selectedLabels = selectedLabels.mid(0, transform->classes); | 42 | selectedLabels = selectedLabels.mid(0, transform->classes); |
| @@ -42,7 +44,7 @@ static TemplateList Downsample(const TemplateList &templates, const Transform *t | @@ -42,7 +44,7 @@ static TemplateList Downsample(const TemplateList &templates, const Transform *t | ||
| 42 | 44 | ||
| 43 | TemplateList downsample; | 45 | TemplateList downsample; |
| 44 | for (int i=0; i<selectedLabels.size(); i++) { | 46 | for (int i=0; i<selectedLabels.size(); i++) { |
| 45 | - const int selectedLabel = selectedLabels[i]; | 47 | + const QString selectedLabel = selectedLabels[i]; |
| 46 | QList<int> indices; | 48 | QList<int> indices; |
| 47 | for (int j=0; j<allLabels.size(); j++) | 49 | for (int j=0; j<allLabels.size(); j++) |
| 48 | if ((allLabels[j] == selectedLabel) && (!templates.value(j).file.get<bool>("FTE", false))) | 50 | if ((allLabels[j] == selectedLabel) && (!templates.value(j).file.get<bool>("FTE", false))) |
openbr/plugins/meta.cpp
| @@ -453,7 +453,6 @@ private: | @@ -453,7 +453,6 @@ private: | ||
| 453 | const QString &file = src.file; | 453 | const QString &file = src.file; |
| 454 | if (cache.contains(file)) { | 454 | if (cache.contains(file)) { |
| 455 | dst = cache[file]; | 455 | dst = cache[file]; |
| 456 | - dst.file.set("Label", src.file.label()); | ||
| 457 | } else { | 456 | } else { |
| 458 | transform->project(src, dst); | 457 | transform->project(src, dst); |
| 459 | cacheLock.lock(); | 458 | cacheLock.lock(); |
openbr/plugins/normalize.cpp
| @@ -126,7 +126,8 @@ private: | @@ -126,7 +126,8 @@ private: | ||
| 126 | { | 126 | { |
| 127 | Mat m; | 127 | Mat m; |
| 128 | OpenCVUtils::toMat(data.data()).convertTo(m, CV_64F); | 128 | OpenCVUtils::toMat(data.data()).convertTo(m, CV_64F); |
| 129 | - const QList<int> labels = data.labels<int>(); | 129 | + |
| 130 | + const QList<int> labels = data.indexProperty("Subject"); | ||
| 130 | const int dims = m.cols; | 131 | const int dims = m.cols; |
| 131 | 132 | ||
| 132 | vector<Mat> mv, av, bv; | 133 | vector<Mat> mv, av, bv; |
openbr/plugins/output.cpp
| @@ -145,8 +145,10 @@ class meltOutput : public MatrixOutput | @@ -145,8 +145,10 @@ class meltOutput : public MatrixOutput | ||
| 145 | 145 | ||
| 146 | QStringList lines; | 146 | QStringList lines; |
| 147 | if (file.baseName() != "terminal") lines.append(QString("Query,Target,Mask,Similarity%1").arg(keys)); | 147 | if (file.baseName() != "terminal") lines.append(QString("Query,Target,Mask,Similarity%1").arg(keys)); |
| 148 | - QList<float> queryLabels = queryFiles.labels(); | ||
| 149 | - QList<float> targetLabels = targetFiles.labels(); | 148 | + |
| 149 | + QList<QString> queryLabels = queryFiles.get<QString>("Subject"); | ||
| 150 | + QList<QString> targetLabels = targetFiles.get<QString>("Subject"); | ||
| 151 | + | ||
| 150 | for (int i=0; i<queryFiles.size(); i++) { | 152 | for (int i=0; i<queryFiles.size(); i++) { |
| 151 | for (int j=(selfSimilar ? i+1 : 0); j<targetFiles.size(); j++) { | 153 | for (int j=(selfSimilar ? i+1 : 0); j<targetFiles.size(); j++) { |
| 152 | const bool genuine = queryLabels[i] == targetLabels[j]; | 154 | const bool genuine = queryLabels[i] == targetLabels[j]; |
| @@ -296,7 +298,7 @@ class txtOutput : public MatrixOutput | @@ -296,7 +298,7 @@ class txtOutput : public MatrixOutput | ||
| 296 | if (file.isNull() || targetFiles.isEmpty() || queryFiles.isEmpty()) return; | 298 | if (file.isNull() || targetFiles.isEmpty() || queryFiles.isEmpty()) return; |
| 297 | QStringList lines; | 299 | QStringList lines; |
| 298 | foreach (const File &file, queryFiles) | 300 | foreach (const File &file, queryFiles) |
| 299 | - lines.append(file.name + " " + file.subject()); | 301 | + lines.append(file.name + " " + file.get<QString>("Subject")); |
| 300 | QtUtils::writeFile(file, lines); | 302 | QtUtils::writeFile(file, lines); |
| 301 | } | 303 | } |
| 302 | }; | 304 | }; |
openbr/plugins/pixel.cpp deleted
| 1 | -/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * | ||
| 2 | - * Copyright 2012 The MITRE Corporation * | ||
| 3 | - * * | ||
| 4 | - * Licensed under the Apache License, Version 2.0 (the "License"); * | ||
| 5 | - * you may not use this file except in compliance with the License. * | ||
| 6 | - * You may obtain a copy of the License at * | ||
| 7 | - * * | ||
| 8 | - * http://www.apache.org/licenses/LICENSE-2.0 * | ||
| 9 | - * * | ||
| 10 | - * Unless required by applicable law or agreed to in writing, software * | ||
| 11 | - * distributed under the License is distributed on an "AS IS" BASIS, * | ||
| 12 | - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * | ||
| 13 | - * See the License for the specific language governing permissions and * | ||
| 14 | - * limitations under the License. * | ||
| 15 | - * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ | ||
| 16 | - | ||
| 17 | -#include "openbr_internal.h" | ||
| 18 | - | ||
| 19 | -using namespace cv; | ||
| 20 | - | ||
| 21 | -namespace br | ||
| 22 | -{ | ||
| 23 | - | ||
| 24 | -/*! | ||
| 25 | - * \ingroup transforms | ||
| 26 | - * \brief Treat each pixel as a classification task | ||
| 27 | - * \author E. Taborsky \cite mmtaborsky | ||
| 28 | - */ | ||
| 29 | -class PerPixelClassifierTransform : public MetaTransform | ||
| 30 | -{ | ||
| 31 | - Q_OBJECT | ||
| 32 | - Q_PROPERTY(br::Transform* transform READ get_transform WRITE set_transform RESET reset_transform) | ||
| 33 | - Q_PROPERTY(int pixels READ get_pixels WRITE set_pixels RESET reset_pixels STORED false) | ||
| 34 | - Q_PROPERTY(int orient READ get_orient WRITE set_orient RESET reset_orient STORED false) | ||
| 35 | - BR_PROPERTY(br::Transform*, transform, NULL) | ||
| 36 | - BR_PROPERTY(int, pixels, 10000) | ||
| 37 | - BR_PROPERTY(bool, orient, false) | ||
| 38 | - | ||
| 39 | - /* | ||
| 40 | - Bins: | ||
| 41 | - |4|3|2| | ||
| 42 | - |5| |1| | ||
| 43 | - |6|7|8| | ||
| 44 | - */ | ||
| 45 | - | ||
| 46 | - QList<float> shift(int n, QList<float> &src) const | ||
| 47 | - { | ||
| 48 | - for (int i = 0; i < n; i++){ // Equivalent to src.append(src.takeFirst()) ? | ||
| 49 | - src.append(src.at(i)); | ||
| 50 | - src.removeFirst(); | ||
| 51 | - } | ||
| 52 | - return src; | ||
| 53 | - } | ||
| 54 | - | ||
| 55 | - void rotate(Template &src, Template &dst) const | ||
| 56 | - { | ||
| 57 | - int images = (src.m().cols)/9; | ||
| 58 | - dst = src; | ||
| 59 | - for (int i = 0; i < images; i++){ | ||
| 60 | - double a = src.m().at<float>(7+(i*9)); //top | ||
| 61 | - double b = src.m().at<float>(1+(i*9)); //bottom | ||
| 62 | - double c = src.m().at<float>(5+(i*9)); //right | ||
| 63 | - double d = src.m().at<float>(3+(i*9)); //left | ||
| 64 | - double orientation = atan2((a-b),(c-d)); | ||
| 65 | - int bin; | ||
| 66 | - if (orientation > 0){ | ||
| 67 | - bin = ((orientation/CV_PI)*4.0 +.5); | ||
| 68 | - } else { | ||
| 69 | - bin = 8.0 + ((orientation/CV_PI)*4.0 + .5); | ||
| 70 | - } | ||
| 71 | - | ||
| 72 | - // put things in an order that makes sense to rotate | ||
| 73 | - // blugh | ||
| 74 | - QList<float> orderedList; | ||
| 75 | - QList<float> rotatedList; | ||
| 76 | - orderedList.insert(0, src.m().at<float>(3+(i*9))); | ||
| 77 | - orderedList.insert(1, src.m().at<float>(6+(i*9))); | ||
| 78 | - orderedList.insert(2, src.m().at<float>(7+(i*9))); | ||
| 79 | - orderedList.insert(3, src.m().at<float>(8+(i*9))); | ||
| 80 | - orderedList.insert(4, src.m().at<float>(5+(i*9))); | ||
| 81 | - orderedList.insert(5, src.m().at<float>(2+(i*9))); | ||
| 82 | - orderedList.insert(6, src.m().at<float>(1+(i*9))); | ||
| 83 | - orderedList.insert(7, src.m().at<float>(0+(i*9))); | ||
| 84 | - | ||
| 85 | - rotatedList = shift(bin, orderedList); | ||
| 86 | - | ||
| 87 | - dst.m().at<float>(0+(i*9)) = rotatedList.at(7); | ||
| 88 | - dst.m().at<float>(1+(i*9)) = rotatedList.at(6); | ||
| 89 | - dst.m().at<float>(2+(i*9)) = rotatedList.at(5); | ||
| 90 | - dst.m().at<float>(3+(i*9)) = rotatedList.at(0); | ||
| 91 | - dst.m().at<float>(4+(i*9)) = src.m().at<float>(4+(i*9)); // middle pixel not in orderedList | ||
| 92 | - dst.m().at<float>(5+(i*9)) = rotatedList.at(4); | ||
| 93 | - dst.m().at<float>(6+(i*9)) = rotatedList.at(1); | ||
| 94 | - dst.m().at<float>(7+(i*9)) = rotatedList.at(2); | ||
| 95 | - dst.m().at<float>(8+(i*9)) = rotatedList.at(3); | ||
| 96 | - } | ||
| 97 | - } | ||
| 98 | - | ||
| 99 | - void train(const TemplateList &trainingSet) | ||
| 100 | - { | ||
| 101 | - TemplateList pixelTemplates = TemplateList(); | ||
| 102 | - const int length = trainingSet.length(); | ||
| 103 | - int pixelsPerImage = pixels/length; | ||
| 104 | - | ||
| 105 | - for (int i=0; i < length; i++){ | ||
| 106 | - Template src = trainingSet.at(i); | ||
| 107 | - | ||
| 108 | - const int mats = src.length(); | ||
| 109 | - const int rows = src.m().rows; | ||
| 110 | - const int cols = src.m().cols; | ||
| 111 | - | ||
| 112 | - RNG &rng = theRNG(); | ||
| 113 | - TemplateList srcPixelTemplates; | ||
| 114 | - | ||
| 115 | - for (int m=0; m < pixelsPerImage; m++){ | ||
| 116 | - int index = rng.uniform(0, rows*cols); | ||
| 117 | - Template temp = Template(src.file, cv::Mat(1, mats, CV_32F)); | ||
| 118 | - float *ptemp = (float*)temp.m().ptr(); | ||
| 119 | - for (int n=0; n < mats; n++){ | ||
| 120 | - uchar *psrc = src[n].ptr(); | ||
| 121 | - ptemp[n] = psrc[index]; | ||
| 122 | - } | ||
| 123 | - cv::Mat labelMat = src.file.value("labels").value<cv::Mat>(); | ||
| 124 | - uchar* plabel = labelMat.ptr(); | ||
| 125 | - temp.file.set("Label", plabel[index]); | ||
| 126 | - | ||
| 127 | - if (orient){ | ||
| 128 | - Template rotated; | ||
| 129 | - rotate(temp, rotated); | ||
| 130 | - srcPixelTemplates.append(rotated); | ||
| 131 | - } else { | ||
| 132 | - srcPixelTemplates.append(temp); | ||
| 133 | - } | ||
| 134 | - } | ||
| 135 | - pixelTemplates.append(srcPixelTemplates); | ||
| 136 | - } | ||
| 137 | - transform->train(pixelTemplates); | ||
| 138 | - } | ||
| 139 | - | ||
| 140 | - void project(const Template &src, Template &dst) const | ||
| 141 | - { | ||
| 142 | - const int mats = src.length(); | ||
| 143 | - const int rows = src.m().rows; | ||
| 144 | - const int cols = src.m().cols; | ||
| 145 | - | ||
| 146 | - dst = src; // Do we really want to copy all the src matrices into dst? | ||
| 147 | - dst.merge(Template(src.file, cv::Mat(src.m().rows, src.m().cols, CV_32F))); | ||
| 148 | - float *pdst = (float*) dst.m().ptr(); | ||
| 149 | - | ||
| 150 | - for (int r = 0; r < rows; r++){ | ||
| 151 | - for (int c = 0; c < cols; c++){ | ||
| 152 | - Template temp = Template(src.file, cv::Mat(1, mats, CV_32F)); | ||
| 153 | - Template dstTemp = Template(src.file, cv::Mat(1, mats, CV_32F)); | ||
| 154 | - | ||
| 155 | - for (int n=0; n < mats; n++){ | ||
| 156 | - const uchar *psrc = src[n].ptr(); | ||
| 157 | - float *ptemp = (float*)temp[0].ptr(); | ||
| 158 | - int index = r*cols + c; | ||
| 159 | - ptemp[n] = psrc[index]; | ||
| 160 | - } | ||
| 161 | - | ||
| 162 | - if (orient){ | ||
| 163 | - Template rotated = Template(src.file, cv::Mat(1, mats, CV_32F)); | ||
| 164 | - rotate(temp, rotated); | ||
| 165 | - temp = rotated; | ||
| 166 | - } | ||
| 167 | - | ||
| 168 | - transform->project(temp,dstTemp); | ||
| 169 | - pdst[r*cols+c] = dstTemp.file.label(); | ||
| 170 | - } | ||
| 171 | - } | ||
| 172 | - } | ||
| 173 | -}; | ||
| 174 | - | ||
| 175 | -BR_REGISTER(Transform, PerPixelClassifierTransform) | ||
| 176 | - | ||
| 177 | -/*! | ||
| 178 | - * \ingroup transforms | ||
| 179 | - * \brief Construct feature vectors of neighboring pixels | ||
| 180 | - * \author E. Taborsky \cite mmtaborsky | ||
| 181 | - */ | ||
| 182 | -class NeighborsTransform : public UntrainableMetaTransform | ||
| 183 | -{ | ||
| 184 | - Q_OBJECT | ||
| 185 | - | ||
| 186 | - void project(const Template &src, Template &dst) const | ||
| 187 | - { | ||
| 188 | - int rows = src.m().rows; | ||
| 189 | - int cols = src.m().cols; | ||
| 190 | - int mats = src.length(); | ||
| 191 | - dst.file = src.file; | ||
| 192 | - | ||
| 193 | - for (int n = 0; n < mats; n++){ //each matrix, except the last one, will be turned into 9 matrices | ||
| 194 | - const uchar *psrc = src[n].ptr(); | ||
| 195 | - for (int i = -1; i < 2; i++){ | ||
| 196 | - for (int j = -1; j < 2; j++){ // these nine matrices are shifted versions of the original | ||
| 197 | - cv::Mat mat = cv::Mat(rows, cols, CV_8UC1); | ||
| 198 | - uchar *pmat = (uchar*)mat.ptr(); | ||
| 199 | - for (int r = 0; r < rows; r++){ | ||
| 200 | - for (int c = 0; c < cols; c++){ | ||
| 201 | - int index = r*cols+c; | ||
| 202 | - int newIndex = index + i*cols + j; | ||
| 203 | - if ((newIndex < 0) || (newIndex >= rows*cols)){ | ||
| 204 | - pmat[index] = psrc[index]; | ||
| 205 | - } else { | ||
| 206 | - pmat[index] = psrc[newIndex]; | ||
| 207 | - } | ||
| 208 | - } | ||
| 209 | - } | ||
| 210 | - dst.push_back(mat); //add mat to dst | ||
| 211 | - } | ||
| 212 | - } | ||
| 213 | - } | ||
| 214 | - dst.push_back(src.m()); // add the last matrix | ||
| 215 | - } | ||
| 216 | -}; | ||
| 217 | - | ||
| 218 | -BR_REGISTER(Transform, NeighborsTransform) | ||
| 219 | - | ||
| 220 | -/*! | ||
| 221 | - * \ingroup transforms | ||
| 222 | - * \brief To binary vector | ||
| 223 | - * \author E. Taborsky \cite mmtaborsky | ||
| 224 | - */ | ||
| 225 | -class ToBinaryVectorTransform : public UntrainableMetaTransform | ||
| 226 | -{ | ||
| 227 | - Q_OBJECT | ||
| 228 | - Q_PROPERTY(br::Transform* transform READ get_transform WRITE set_transform RESET reset_transform STORED false) | ||
| 229 | - Q_PROPERTY(int length READ get_length WRITE set_length RESET reset_length STORED false) | ||
| 230 | - BR_PROPERTY(br::Transform*, transform, NULL) | ||
| 231 | - BR_PROPERTY(int, length, -1) | ||
| 232 | - | ||
| 233 | - //needs to be updated.. | ||
| 234 | - void project(const Template &src, Template &dst) const | ||
| 235 | - { | ||
| 236 | - | ||
| 237 | - dst = src; | ||
| 238 | - int mats = src.length(); | ||
| 239 | - for (int i = 0; i < mats; i++){ | ||
| 240 | - // Does this actually modify the data? | ||
| 241 | - dst[i]*(1.0/255.0); //scaling the input matrices to make the svm happier | ||
| 242 | - } | ||
| 243 | - for (int i = 0; i < length*(mats); i++){ | ||
| 244 | - dst.prepend(Template(src.file, Mat::zeros(src.m().rows, src.m().cols, CV_8U))); | ||
| 245 | - } | ||
| 246 | - | ||
| 247 | - // original pixel values at the end | ||
| 248 | - | ||
| 249 | - Template transformed; | ||
| 250 | - transformed.file = src.file; | ||
| 251 | - transform->project(src, transformed); | ||
| 252 | - | ||
| 253 | - int rows = transformed.m().rows; | ||
| 254 | - int cols = transformed.m().cols; | ||
| 255 | - | ||
| 256 | - for (int i = 0; i < mats; i++){ | ||
| 257 | - uchar *ptransformed = transformed[i].ptr(); | ||
| 258 | - for (int r = 0; r < rows; r++){ | ||
| 259 | - for (int c = 0; c < cols; c++){ | ||
| 260 | - uchar index = ptransformed[r*cols+c]; | ||
| 261 | - dst[index+(length*i)].at<uchar>(r,c) = 1; | ||
| 262 | - } | ||
| 263 | - } | ||
| 264 | - } | ||
| 265 | - } | ||
| 266 | -}; | ||
| 267 | - | ||
| 268 | -BR_REGISTER(Transform, ToBinaryVectorTransform) | ||
| 269 | - | ||
| 270 | -/*! | ||
| 271 | - * \ingroup transforms | ||
| 272 | - * \brief If "labels" is specified, makes the last matrix into metadata | ||
| 273 | - * \author E. Taborsky \cite mmtaborsky | ||
| 274 | - */ | ||
| 275 | - | ||
| 276 | -class ToMetadataTransform : public UntrainableMetaTransform | ||
| 277 | -{ | ||
| 278 | - Q_OBJECT | ||
| 279 | - | ||
| 280 | - void project(const Template &src, Template &dst) const | ||
| 281 | - { | ||
| 282 | - dst = src; | ||
| 283 | - if (dst.file.contains("labels")){ | ||
| 284 | - QVariant last = qVariantFromValue(dst.m()); | ||
| 285 | - dst.file.set("labels", last); | ||
| 286 | - dst.pop_back(); | ||
| 287 | - } | ||
| 288 | - } | ||
| 289 | - | ||
| 290 | -}; | ||
| 291 | - | ||
| 292 | -BR_REGISTER(Transform, ToMetadataTransform) | ||
| 293 | - | ||
| 294 | -} // namespace br | ||
| 295 | - | ||
| 296 | -#include "pixel.moc" |
openbr/plugins/quality.cpp
| @@ -26,10 +26,10 @@ class ImpostorUniquenessMeasureTransform : public Transform | @@ -26,10 +26,10 @@ class ImpostorUniquenessMeasureTransform : public Transform | ||
| 26 | 26 | ||
| 27 | float calculateIUM(const Template &probe, const TemplateList &gallery) const | 27 | float calculateIUM(const Template &probe, const TemplateList &gallery) const |
| 28 | { | 28 | { |
| 29 | - const int probeLabel = probe.file.label(); | 29 | + const QString probeLabel = probe.file.get<QString>("Subject"); |
| 30 | TemplateList subset = gallery; | 30 | TemplateList subset = gallery; |
| 31 | for (int j=subset.size()-1; j>=0; j--) | 31 | for (int j=subset.size()-1; j>=0; j--) |
| 32 | - if (subset[j].file.label() == probeLabel) | 32 | + if (subset[j].file.get<QString>("Subject") == probeLabel) |
| 33 | subset.removeAt(j); | 33 | subset.removeAt(j); |
| 34 | 34 | ||
| 35 | QList<float> scores = distance->compare(subset, probe); | 35 | QList<float> scores = distance->compare(subset, probe); |
| @@ -158,8 +158,7 @@ class MatchProbabilityDistance : public Distance | @@ -158,8 +158,7 @@ class MatchProbabilityDistance : public Distance | ||
| 158 | { | 158 | { |
| 159 | distance->train(src); | 159 | distance->train(src); |
| 160 | 160 | ||
| 161 | - const QList<int> labels = src.labels<int>(); | ||
| 162 | - | 161 | + const QList<int> labels = src.indexProperty("Subject"); |
| 163 | QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(src.size()), FileList(src.size()))); | 162 | QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(src.size()), FileList(src.size()))); |
| 164 | distance->compare(src, src, matrixOutput.data()); | 163 | distance->compare(src, src, matrixOutput.data()); |
| 165 | 164 | ||
| @@ -229,7 +228,7 @@ class HeatMapDistance : public Distance | @@ -229,7 +228,7 @@ class HeatMapDistance : public Distance | ||
| 229 | { | 228 | { |
| 230 | distance->train(src); | 229 | distance->train(src); |
| 231 | 230 | ||
| 232 | - const QList<int> labels = src.labels<int>(); | 231 | + const QList<int> labels = src.indexProperty("Subject"); |
| 233 | 232 | ||
| 234 | QList<TemplateList> patches; | 233 | QList<TemplateList> patches; |
| 235 | 234 | ||
| @@ -317,7 +316,7 @@ class UnitDistance : public Distance | @@ -317,7 +316,7 @@ class UnitDistance : public Distance | ||
| 317 | void train(const TemplateList &templates) | 316 | void train(const TemplateList &templates) |
| 318 | { | 317 | { |
| 319 | const TemplateList samples = templates.mid(0, 2000); | 318 | const TemplateList samples = templates.mid(0, 2000); |
| 320 | - const QList<float> sampleLabels = samples.labels<float>(); | 319 | + const QList<int> sampleLabels = samples.indexProperty("Subject"); |
| 321 | QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(samples.size()), FileList(samples.size()))); | 320 | QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(samples.size()), FileList(samples.size()))); |
| 322 | Distance::compare(samples, samples, matrixOutput.data()); | 321 | Distance::compare(samples, samples, matrixOutput.data()); |
| 323 | 322 |
openbr/plugins/quantize.cpp
| @@ -150,7 +150,7 @@ class BayesianQuantizationDistance : public Distance | @@ -150,7 +150,7 @@ class BayesianQuantizationDistance : public Distance | ||
| 150 | qFatal("Expected sigle matrix templates of type CV_8UC1!"); | 150 | qFatal("Expected sigle matrix templates of type CV_8UC1!"); |
| 151 | 151 | ||
| 152 | const Mat data = OpenCVUtils::toMat(src.data()); | 152 | const Mat data = OpenCVUtils::toMat(src.data()); |
| 153 | - const QList<int> templateLabels = src.labels<int>(); | 153 | + const QList<int> templateLabels = src.indexProperty("Subject"); |
| 154 | loglikelihoods = QVector<float>(data.cols*256, 0); | 154 | loglikelihoods = QVector<float>(data.cols*256, 0); |
| 155 | 155 | ||
| 156 | QFutureSynchronizer<void> futures; | 156 | QFutureSynchronizer<void> futures; |
| @@ -473,7 +473,8 @@ private: | @@ -473,7 +473,8 @@ private: | ||
| 473 | { | 473 | { |
| 474 | Mat data = OpenCVUtils::toMat(src.data()); | 474 | Mat data = OpenCVUtils::toMat(src.data()); |
| 475 | const int step = getStep(data.cols); | 475 | const int step = getStep(data.cols); |
| 476 | - const QList<int> labels = src.labels<int>(); | 476 | + |
| 477 | + const QList<int> labels = src.indexProperty("Subject"); | ||
| 477 | 478 | ||
| 478 | Mat &lut = ProductQuantizationLUTs[index]; | 479 | Mat &lut = ProductQuantizationLUTs[index]; |
| 479 | lut = Mat(getDims(data.cols), 256*(256+1)/2, CV_32FC1); | 480 | lut = Mat(getDims(data.cols), 256*(256+1)/2, CV_32FC1); |
openbr/plugins/quantize2.cpp
| @@ -77,7 +77,7 @@ class BayesianQuantizationTransform : public Transform | @@ -77,7 +77,7 @@ class BayesianQuantizationTransform : public Transform | ||
| 77 | void train(const TemplateList &src) | 77 | void train(const TemplateList &src) |
| 78 | { | 78 | { |
| 79 | const Mat data = OpenCVUtils::toMat(src.data()); | 79 | const Mat data = OpenCVUtils::toMat(src.data()); |
| 80 | - const QList<int> labels = src.labels<int>(); | 80 | + const QList<int> labels = src.indexProperty("Subject"); |
| 81 | 81 | ||
| 82 | thresholds = QVector<float>(256*data.cols); | 82 | thresholds = QVector<float>(256*data.cols); |
| 83 | 83 |
openbr/plugins/svm.cpp
| @@ -121,28 +121,46 @@ private: | @@ -121,28 +121,46 @@ private: | ||
| 121 | BR_PROPERTY(float, gamma, -1) | 121 | BR_PROPERTY(float, gamma, -1) |
| 122 | 122 | ||
| 123 | SVM svm; | 123 | SVM svm; |
| 124 | + QHash<QString, int> labelMap; | ||
| 125 | + QHash<int, QVariant> reverseLookup; | ||
| 124 | 126 | ||
| 125 | void train(const TemplateList &_data) | 127 | void train(const TemplateList &_data) |
| 126 | { | 128 | { |
| 127 | Mat data = OpenCVUtils::toMat(_data.data()); | 129 | Mat data = OpenCVUtils::toMat(_data.data()); |
| 128 | - Mat lab = OpenCVUtils::toMat(_data.labels<float>()); | 130 | + Mat lab; |
| 131 | + // If we are doing regression, assume subject has float values | ||
| 132 | + if (type == EPS_SVR || type == NU_SVR) { | ||
| 133 | + lab = OpenCVUtils::toMat(_data.get<float>("Subject")); | ||
| 134 | + } | ||
| 135 | + // If we are doing classification, assume subject has discrete values, map them | ||
| 136 | + // and store the mapping data | ||
| 137 | + else { | ||
| 138 | + QList<int> dataLabels = _data.indexProperty("Subject", labelMap, reverseLookup); | ||
| 139 | + lab = OpenCVUtils::toMat(dataLabels); | ||
| 140 | + } | ||
| 129 | trainSVM(svm, data, lab, kernel, type, C, gamma); | 141 | trainSVM(svm, data, lab, kernel, type, C, gamma); |
| 130 | } | 142 | } |
| 131 | 143 | ||
| 132 | void project(const Template &src, Template &dst) const | 144 | void project(const Template &src, Template &dst) const |
| 133 | { | 145 | { |
| 134 | dst = src; | 146 | dst = src; |
| 135 | - dst.file.set("Label", svm.predict(src.m().reshape(1, 1))); | 147 | + float prediction = svm.predict(src.m().reshape(1, 1)); |
| 148 | + if (type == EPS_SVR || type == NU_SVR) | ||
| 149 | + dst.file.set("Subject", prediction); | ||
| 150 | + else | ||
| 151 | + dst.file.set("Subject", reverseLookup[prediction]); | ||
| 136 | } | 152 | } |
| 137 | 153 | ||
| 138 | void store(QDataStream &stream) const | 154 | void store(QDataStream &stream) const |
| 139 | { | 155 | { |
| 140 | storeSVM(svm, stream); | 156 | storeSVM(svm, stream); |
| 157 | + stream << labelMap << reverseLookup; | ||
| 141 | } | 158 | } |
| 142 | 159 | ||
| 143 | void load(QDataStream &stream) | 160 | void load(QDataStream &stream) |
| 144 | { | 161 | { |
| 145 | loadSVM(svm, stream); | 162 | loadSVM(svm, stream); |
| 163 | + stream >> labelMap >> reverseLookup; | ||
| 146 | } | 164 | } |
| 147 | }; | 165 | }; |
| 148 | 166 | ||
| @@ -182,7 +200,7 @@ private: | @@ -182,7 +200,7 @@ private: | ||
| 182 | void train(const TemplateList &src) | 200 | void train(const TemplateList &src) |
| 183 | { | 201 | { |
| 184 | const Mat data = OpenCVUtils::toMat(src.data()); | 202 | const Mat data = OpenCVUtils::toMat(src.data()); |
| 185 | - const QList<int> lab = src.labels<int>(); | 203 | + const QList<int> lab = src.indexProperty("Subject"); |
| 186 | 204 | ||
| 187 | const int instances = data.rows * (data.rows+1) / 2; | 205 | const int instances = data.rows * (data.rows+1) / 2; |
| 188 | Mat deltaData(instances, data.cols, data.type()); | 206 | Mat deltaData(instances, data.cols, data.type()); |