Commit 78adc5c5b87f428d61d2b7bc0376c5973b62d677
1 parent
7d0f5a9b
Added partition parameter to makeMask, fixed CMC plot bug only visible using crossValidate
Showing
4 changed files
with
19 additions
and
7 deletions
openbr/core/bee.cpp
| ... | ... | @@ -235,12 +235,17 @@ void BEE::writeMask(const Mat &m, const QString &mask, const QString &targetSigs |
| 235 | 235 | void BEE::makeMask(const QString &targetInput, const QString &queryInput, const QString &mask) |
| 236 | 236 | { |
| 237 | 237 | qDebug("Making mask from %s and %s to %s", qPrintable(targetInput), qPrintable(queryInput), qPrintable(mask)); |
| 238 | - FileList targes = TemplateList::fromGallery(targetInput).files(); | |
| 239 | - FileList queries = (queryInput == ".") ? targes : TemplateList::fromGallery(queryInput).files(); | |
| 240 | - writeMask(makeMask(targes, queries), mask, targetInput, queryInput); | |
| 238 | + FileList targets = TemplateList::fromGallery(targetInput).files(); | |
| 239 | + FileList queries = (queryInput == ".") ? targets : TemplateList::fromGallery(queryInput).files(); | |
| 240 | + int partitions = targets.first().get<int>("crossValidate"); | |
| 241 | + if (partitions == 0) writeMask(makeMask(targets, queries), mask, targetInput, queryInput); | |
| 242 | + else for (int i=0; i<partitions; i++) { | |
| 243 | + QString maskPartition = mask; | |
| 244 | + writeMask(makeMask(targets, queries, i), maskPartition.insert(maskPartition.indexOf('.'),"Partition_" + QString::number(i)), targetInput, queryInput); | |
| 245 | + } | |
| 241 | 246 | } |
| 242 | 247 | |
| 243 | -cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries) | |
| 248 | +cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries, int partition) | |
| 244 | 249 | { |
| 245 | 250 | QList<float> targetLabels = targets.labels(); |
| 246 | 251 | QList<float> queryLabels = queries.labels(); |
| ... | ... | @@ -263,6 +268,8 @@ cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries) |
| 263 | 268 | else if (labelA == -1) val = DontCare; |
| 264 | 269 | else if (labelB == -1) val = DontCare; |
| 265 | 270 | else if (partitionA != partitionB) val = DontCare; |
| 271 | + else if (partitionA != partition) val = DontCare; | |
| 272 | + else if (partitionB != partition) val = DontCare; | |
| 266 | 273 | else if (labelA == labelB) val = Match; |
| 267 | 274 | else val = NonMatch; |
| 268 | 275 | mask.at<Mask_t>(i,j) = val; | ... | ... |
openbr/core/bee.h
| ... | ... | @@ -46,7 +46,7 @@ namespace BEE |
| 46 | 46 | |
| 47 | 47 | // Write BEE files |
| 48 | 48 | void makeMask(const QString &targetInput, const QString &queryInput, const QString &mask); |
| 49 | - cv::Mat makeMask(const br::FileList &targets, const br::FileList &queries); | |
| 49 | + cv::Mat makeMask(const br::FileList &targets, const br::FileList &queries, int partition = 0); | |
| 50 | 50 | void combineMasks(const QStringList &inputMasks, const QString &outputMask, const QString &method); |
| 51 | 51 | } |
| 52 | 52 | ... | ... |
openbr/core/plot.cpp
| ... | ... | @@ -34,6 +34,7 @@ |
| 34 | 34 | #include "version.h" |
| 35 | 35 | #include "openbr/core/bee.h" |
| 36 | 36 | #include "openbr/core/common.h" |
| 37 | +#include "openbr/core/opencvutils.h" | |
| 37 | 38 | #include "openbr/core/qtutils.h" |
| 38 | 39 | |
| 39 | 40 | #undef FAR // Windows preprecessor definition |
| ... | ... | @@ -260,8 +261,10 @@ float Evaluate(const Mat &simmat, const Mat &mask, const QString &csv) |
| 260 | 261 | for (int i=1; i<=Max_Retrieval; i++) { |
| 261 | 262 | int realizedReturns = 0, possibleReturns = 0; |
| 262 | 263 | foreach (int firstGenuineReturn, firstGenuineReturns) { |
| 263 | - if (firstGenuineReturn > 0) possibleReturns++; | |
| 264 | - if (firstGenuineReturn <= i) realizedReturns++; | |
| 264 | + if (firstGenuineReturn > 0) { | |
| 265 | + possibleReturns++; | |
| 266 | + if (firstGenuineReturn <= i) realizedReturns++; | |
| 267 | + } | |
| 265 | 268 | } |
| 266 | 269 | const float retrievalRate = float(realizedReturns)/possibleReturns; |
| 267 | 270 | lines.append(qPrintable(QString("CMC,%1,%2").arg(QString::number(i), QString::number(retrievalRate)))); | ... | ... |
openbr/plugins/validate.cpp
| ... | ... | @@ -40,8 +40,10 @@ class CrossValidateTransform : public MetaTransform |
| 40 | 40 | for (int i=0; i<numPartitions; i++) { |
| 41 | 41 | TemplateList partitionedData = data; |
| 42 | 42 | for (int j=partitionedData.size()-1; j>=0; j--) |
| 43 | + // Remove all templates from partition i | |
| 43 | 44 | if (partitions[j] == i) |
| 44 | 45 | partitionedData.removeAt(j); |
| 46 | + // Train on the remaining templates | |
| 45 | 47 | if (Globals->parallelism) futures.addFuture(QtConcurrent::run(transforms[i], &Transform::train, partitionedData)); |
| 46 | 48 | else transforms[i]->train(partitionedData); |
| 47 | 49 | } | ... | ... |