From 78adc5c5b87f428d61d2b7bc0376c5973b62d677 Mon Sep 17 00:00:00 2001 From: Scott Klum Date: Mon, 25 Mar 2013 20:59:00 -0400 Subject: [PATCH] Added partition parameter to makeMask, fixed CMC plot bug only visible using crossValidate --- openbr/core/bee.cpp | 15 +++++++++++---- openbr/core/bee.h | 2 +- openbr/core/plot.cpp | 7 +++++-- openbr/plugins/validate.cpp | 2 ++ 4 files changed, 19 insertions(+), 7 deletions(-) diff --git a/openbr/core/bee.cpp b/openbr/core/bee.cpp index d224886..93d61d6 100644 --- a/openbr/core/bee.cpp +++ b/openbr/core/bee.cpp @@ -235,12 +235,17 @@ void BEE::writeMask(const Mat &m, const QString &mask, const QString &targetSigs void BEE::makeMask(const QString &targetInput, const QString &queryInput, const QString &mask) { qDebug("Making mask from %s and %s to %s", qPrintable(targetInput), qPrintable(queryInput), qPrintable(mask)); - FileList targes = TemplateList::fromGallery(targetInput).files(); - FileList queries = (queryInput == ".") ? targes : TemplateList::fromGallery(queryInput).files(); - writeMask(makeMask(targes, queries), mask, targetInput, queryInput); + FileList targets = TemplateList::fromGallery(targetInput).files(); + FileList queries = (queryInput == ".") ? targets : TemplateList::fromGallery(queryInput).files(); + int partitions = targets.first().get("crossValidate"); + if (partitions == 0) writeMask(makeMask(targets, queries), mask, targetInput, queryInput); + else for (int i=0; i targetLabels = targets.labels(); QList queryLabels = queries.labels(); @@ -263,6 +268,8 @@ cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries) else if (labelA == -1) val = DontCare; else if (labelB == -1) val = DontCare; else if (partitionA != partitionB) val = DontCare; + else if (partitionA != partition) val = DontCare; + else if (partitionB != partition) val = DontCare; else if (labelA == labelB) val = Match; else val = NonMatch; mask.at(i,j) = val; diff --git a/openbr/core/bee.h b/openbr/core/bee.h index 3cd9fcd..c015d35 100644 --- a/openbr/core/bee.h +++ b/openbr/core/bee.h @@ -46,7 +46,7 @@ namespace BEE // Write BEE files void makeMask(const QString &targetInput, const QString &queryInput, const QString &mask); - cv::Mat makeMask(const br::FileList &targets, const br::FileList &queries); + cv::Mat makeMask(const br::FileList &targets, const br::FileList &queries, int partition = 0); void combineMasks(const QStringList &inputMasks, const QString &outputMask, const QString &method); } diff --git a/openbr/core/plot.cpp b/openbr/core/plot.cpp index 4641bd8..7f24d23 100644 --- a/openbr/core/plot.cpp +++ b/openbr/core/plot.cpp @@ -34,6 +34,7 @@ #include "version.h" #include "openbr/core/bee.h" #include "openbr/core/common.h" +#include "openbr/core/opencvutils.h" #include "openbr/core/qtutils.h" #undef FAR // Windows preprecessor definition @@ -260,8 +261,10 @@ float Evaluate(const Mat &simmat, const Mat &mask, const QString &csv) for (int i=1; i<=Max_Retrieval; i++) { int realizedReturns = 0, possibleReturns = 0; foreach (int firstGenuineReturn, firstGenuineReturns) { - if (firstGenuineReturn > 0) possibleReturns++; - if (firstGenuineReturn <= i) realizedReturns++; + if (firstGenuineReturn > 0) { + possibleReturns++; + if (firstGenuineReturn <= i) realizedReturns++; + } } const float retrievalRate = float(realizedReturns)/possibleReturns; lines.append(qPrintable(QString("CMC,%1,%2").arg(QString::number(i), QString::number(retrievalRate)))); diff --git a/openbr/plugins/validate.cpp b/openbr/plugins/validate.cpp index a7dc81a..a94afea 100644 --- a/openbr/plugins/validate.cpp +++ b/openbr/plugins/validate.cpp @@ -40,8 +40,10 @@ class CrossValidateTransform : public MetaTransform for (int i=0; i=0; j--) + // Remove all templates from partition i if (partitions[j] == i) partitionedData.removeAt(j); + // Train on the remaining templates if (Globals->parallelism) futures.addFuture(QtConcurrent::run(transforms[i], &Transform::train, partitionedData)); else transforms[i]->train(partitionedData); } -- libgit2 0.21.4