Commit 78adc5c5b87f428d61d2b7bc0376c5973b62d677

Authored by Scott Klum
1 parent 7d0f5a9b

Added partition parameter to makeMask, fixed CMC plot bug only visible using crossValidate

openbr/core/bee.cpp
@@ -235,12 +235,17 @@ void BEE::writeMask(const Mat &m, const QString &mask, const QString &targetSigs @@ -235,12 +235,17 @@ void BEE::writeMask(const Mat &m, const QString &mask, const QString &targetSigs
235 void BEE::makeMask(const QString &targetInput, const QString &queryInput, const QString &mask) 235 void BEE::makeMask(const QString &targetInput, const QString &queryInput, const QString &mask)
236 { 236 {
237 qDebug("Making mask from %s and %s to %s", qPrintable(targetInput), qPrintable(queryInput), qPrintable(mask)); 237 qDebug("Making mask from %s and %s to %s", qPrintable(targetInput), qPrintable(queryInput), qPrintable(mask));
238 - FileList targes = TemplateList::fromGallery(targetInput).files();  
239 - FileList queries = (queryInput == ".") ? targes : TemplateList::fromGallery(queryInput).files();  
240 - writeMask(makeMask(targes, queries), mask, targetInput, queryInput); 238 + FileList targets = TemplateList::fromGallery(targetInput).files();
  239 + FileList queries = (queryInput == ".") ? targets : TemplateList::fromGallery(queryInput).files();
  240 + int partitions = targets.first().get<int>("crossValidate");
  241 + if (partitions == 0) writeMask(makeMask(targets, queries), mask, targetInput, queryInput);
  242 + else for (int i=0; i<partitions; i++) {
  243 + QString maskPartition = mask;
  244 + writeMask(makeMask(targets, queries, i), maskPartition.insert(maskPartition.indexOf('.'),"Partition_" + QString::number(i)), targetInput, queryInput);
  245 + }
241 } 246 }
242 247
243 -cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries) 248 +cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries, int partition)
244 { 249 {
245 QList<float> targetLabels = targets.labels(); 250 QList<float> targetLabels = targets.labels();
246 QList<float> queryLabels = queries.labels(); 251 QList<float> queryLabels = queries.labels();
@@ -263,6 +268,8 @@ cv::Mat BEE::makeMask(const br::FileList &amp;targets, const br::FileList &amp;queries) @@ -263,6 +268,8 @@ cv::Mat BEE::makeMask(const br::FileList &amp;targets, const br::FileList &amp;queries)
263 else if (labelA == -1) val = DontCare; 268 else if (labelA == -1) val = DontCare;
264 else if (labelB == -1) val = DontCare; 269 else if (labelB == -1) val = DontCare;
265 else if (partitionA != partitionB) val = DontCare; 270 else if (partitionA != partitionB) val = DontCare;
  271 + else if (partitionA != partition) val = DontCare;
  272 + else if (partitionB != partition) val = DontCare;
266 else if (labelA == labelB) val = Match; 273 else if (labelA == labelB) val = Match;
267 else val = NonMatch; 274 else val = NonMatch;
268 mask.at<Mask_t>(i,j) = val; 275 mask.at<Mask_t>(i,j) = val;
openbr/core/bee.h
@@ -46,7 +46,7 @@ namespace BEE @@ -46,7 +46,7 @@ namespace BEE
46 46
47 // Write BEE files 47 // Write BEE files
48 void makeMask(const QString &targetInput, const QString &queryInput, const QString &mask); 48 void makeMask(const QString &targetInput, const QString &queryInput, const QString &mask);
49 - cv::Mat makeMask(const br::FileList &targets, const br::FileList &queries); 49 + cv::Mat makeMask(const br::FileList &targets, const br::FileList &queries, int partition = 0);
50 void combineMasks(const QStringList &inputMasks, const QString &outputMask, const QString &method); 50 void combineMasks(const QStringList &inputMasks, const QString &outputMask, const QString &method);
51 } 51 }
52 52
openbr/core/plot.cpp
@@ -34,6 +34,7 @@ @@ -34,6 +34,7 @@
34 #include "version.h" 34 #include "version.h"
35 #include "openbr/core/bee.h" 35 #include "openbr/core/bee.h"
36 #include "openbr/core/common.h" 36 #include "openbr/core/common.h"
  37 +#include "openbr/core/opencvutils.h"
37 #include "openbr/core/qtutils.h" 38 #include "openbr/core/qtutils.h"
38 39
39 #undef FAR // Windows preprecessor definition 40 #undef FAR // Windows preprecessor definition
@@ -260,8 +261,10 @@ float Evaluate(const Mat &amp;simmat, const Mat &amp;mask, const QString &amp;csv) @@ -260,8 +261,10 @@ float Evaluate(const Mat &amp;simmat, const Mat &amp;mask, const QString &amp;csv)
260 for (int i=1; i<=Max_Retrieval; i++) { 261 for (int i=1; i<=Max_Retrieval; i++) {
261 int realizedReturns = 0, possibleReturns = 0; 262 int realizedReturns = 0, possibleReturns = 0;
262 foreach (int firstGenuineReturn, firstGenuineReturns) { 263 foreach (int firstGenuineReturn, firstGenuineReturns) {
263 - if (firstGenuineReturn > 0) possibleReturns++;  
264 - if (firstGenuineReturn <= i) realizedReturns++; 264 + if (firstGenuineReturn > 0) {
  265 + possibleReturns++;
  266 + if (firstGenuineReturn <= i) realizedReturns++;
  267 + }
265 } 268 }
266 const float retrievalRate = float(realizedReturns)/possibleReturns; 269 const float retrievalRate = float(realizedReturns)/possibleReturns;
267 lines.append(qPrintable(QString("CMC,%1,%2").arg(QString::number(i), QString::number(retrievalRate)))); 270 lines.append(qPrintable(QString("CMC,%1,%2").arg(QString::number(i), QString::number(retrievalRate))));
openbr/plugins/validate.cpp
@@ -40,8 +40,10 @@ class CrossValidateTransform : public MetaTransform @@ -40,8 +40,10 @@ class CrossValidateTransform : public MetaTransform
40 for (int i=0; i<numPartitions; i++) { 40 for (int i=0; i<numPartitions; i++) {
41 TemplateList partitionedData = data; 41 TemplateList partitionedData = data;
42 for (int j=partitionedData.size()-1; j>=0; j--) 42 for (int j=partitionedData.size()-1; j>=0; j--)
  43 + // Remove all templates from partition i
43 if (partitions[j] == i) 44 if (partitions[j] == i)
44 partitionedData.removeAt(j); 45 partitionedData.removeAt(j);
  46 + // Train on the remaining templates
45 if (Globals->parallelism) futures.addFuture(QtConcurrent::run(transforms[i], &Transform::train, partitionedData)); 47 if (Globals->parallelism) futures.addFuture(QtConcurrent::run(transforms[i], &Transform::train, partitionedData));
46 else transforms[i]->train(partitionedData); 48 else transforms[i]->train(partitionedData);
47 } 49 }