Commit fae02bb07cd91a120687b0fac4ad452739d20ad5

Authored by Scott Klum
1 parent 0be609a0

Added duplicatePartitions/allPartitions distinction, heatmap distance

openbr/core/bee.cpp
... ... @@ -265,9 +265,6 @@ cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries,
265 265 QList<int> targetPartitions = targets.crossValidationPartitions();
266 266 QList<int> queryPartitions = queries.crossValidationPartitions();
267 267  
268   - for (int i = 0; i < 5; i++) qDebug() << "QueryPartition " << queries[i].fileName() << ": " << queryPartitions[i];
269   - for (int i = 0; i < 5; i++) qDebug() << "TargetPartition " << targets[i].fileName() << ": " << targetPartitions[i];
270   -
271 268 Mat mask(queries.size(), targets.size(), CV_8UC1);
272 269 for (int i=0; i<queries.size(); i++) {
273 270 const QString &fileA = queries[i];
... ... @@ -283,9 +280,10 @@ cv::Mat BEE::makeMask(const br::FileList &amp;targets, const br::FileList &amp;queries,
283 280 if (fileA == fileB) val = DontCare;
284 281 else if (labelA == -1) val = DontCare;
285 282 else if (labelB == -1) val = DontCare;
286   - else if (partitionA != partitionB) val = DontCare;
287 283 else if (partitionA != partition) val = DontCare;
  284 + else if (partitionB == -1) val = NonMatch;
288 285 else if (partitionB != partition) val = DontCare;
  286 + else if (partitionA != partitionB) val = DontCare;
289 287 else if (labelA == labelB) val = Match;
290 288 else val = NonMatch;
291 289 mask.at<Mask_t>(i,j) = val;
... ...
openbr/core/plot.cpp
... ... @@ -269,7 +269,7 @@ float Evaluate(const Mat &amp;simmat, const Mat &amp;mask, const QString &amp;csv)
269 269 }
270 270  
271 271 // Write Cumulative Match Characteristic (CMC) curve
272   - const int Max_Retrieval = 200;
  272 + const int Max_Retrieval = 100;
273 273 const int Report_Retrieval = 5;
274 274  
275 275 float reportRetrievalRate = -1;
... ... @@ -507,11 +507,11 @@ bool Plot(const QStringList &amp;files, const br::File &amp;destination, bool show)
507 507 QString(" + theme(aspect.ratio=1)\n\n")));
508 508  
509 509 p.file.write(qPrintable(QString("ggplot(CMC, aes(x=X, y=Y%1%2)) + xlab(\"Rank\") + ylab(\"Retrieval Rate\")").arg(p.major.size > 1 ? QString(" ,colour=factor(%1)").arg(p.major.header) : QString(), p.minor.size > 1 ? QString(", linetype=factor(%1)").arg(p.minor.header) : QString()) +
510   - ((p.major.smooth || p.minor.smooth) ? (minimalist ? " + stat_summary(geom=\"line\", fun.y=mean)" : " + stat_summary(geom=\"line\", fun.y=min, aes(linetype=\"Min/Max\")) + stat_summary(geom=\"line\", fun.y=max, aes(linetype=\"Min/Max\")) + stat_summary(geom=\"line\", fun.y=mean, aes(linetype=\"Mean\")) + scale_linetype_manual(\"Legend\", values=c(\"Mean\"=1, \"Min/Max\"=2))") : " + geom_line()") +
  510 + ((p.major.smooth || p.minor.smooth) ? (minimalist ? " + stat_summary(geom=\"line\", lwd=3, fun.y=mean)" : " + stat_summary(geom=\"line\", fun.y=min, aes(linetype=\"Min/Max\")) + stat_summary(geom=\"line\", fun.y=max, aes(linetype=\"Min/Max\")) + stat_summary(geom=\"line\", fun.y=mean, aes(linetype=\"Mean\")) + scale_linetype_manual(\"Legend\", values=c(\"Mean\"=1, \"Min/Max\"=2))") : " + geom_line()") +
511 511 (minimalist ? "" : " + scale_x_log10(labels=c(1,5,10,50,100), breaks=c(1,5,10,50,100)) + annotation_logticks(sides=\"b\")") +
512 512 (p.major.size > 1 ? getScale("colour", p.major.header, p.major.size) : QString()) +
513 513 (p.minor.size > 1 ? QString(" + scale_linetype_discrete(\"%1\")").arg(p.minor.header) : QString()) +
514   - QString(" + theme_minimal() + scale_y_continuous(labels=percent)\n\n")));
  514 + QString(" + theme_blank() + scale_y_continuous(labels=percent)\n\n")));
515 515  
516 516 p.file.write(qPrintable(QString("qplot(factor(%1)%2, data=BC, %3").arg(p.major.smooth ? (p.minor.header.isEmpty() ? "Algorithm" : p.minor.header) : p.major.header, (p.major.smooth || p.minor.smooth) ? ", Y" : "", (p.major.smooth || p.minor.smooth) ? "geom=\"boxplot\"" : "geom=\"bar\", position=\"dodge\", weight=Y") +
517 517 (p.major.size > 1 ? QString(", fill=factor(%1)").arg(p.major.header) : QString()) +
... ...
openbr/openbr_plugin.cpp
... ... @@ -401,7 +401,6 @@ TemplateList TemplateList::fromGallery(const br::File &amp;gallery)
401 401 {
402 402 TemplateList templates;
403 403 foreach (const br::File &file, gallery.split()) {
404   - qDebug() << file.name;
405 404 QScopedPointer<Gallery> i(Gallery::make(file));
406 405 TemplateList newTemplates = i->read();
407 406  
... ... @@ -432,7 +431,11 @@ TemplateList TemplateList::fromGallery(const br::File &amp;gallery)
432 431 newTemplates[i].file.set("Gallery", gallery.name);
433 432  
434 433 if (crossValidate > 0) {
435   - if (newTemplates[i].file.getBool("allPartitions")) {
  434 + if (newTemplates[i].file.getBool("duplicatePartitions")) {
  435 + // The duplicatePartitions flag is used to add target images
  436 + // crossValidate times to the simmat/mask
  437 + // when multiple training sets are being used
  438 +
436 439 // Set template to the first parition
437 440 newTemplates[i].file.set("Partition", QVariant(0));
438 441  
... ... @@ -442,6 +445,11 @@ TemplateList TemplateList::fromGallery(const br::File &amp;gallery)
442 445 allPartitionTemplate.file.set("Partition", j);
443 446 newTemplates.insert(i+1, allPartitionTemplate);
444 447 }
  448 + } else if (newTemplates[i].file.getBool("allPartitions")) {
  449 + // The allPartitions flag is used to add an extended set
  450 + // of target images to every partition
  451 +
  452 + newTemplates[i].file.set("Partition", -1);
445 453 } else {
446 454 const QByteArray md5 = QCryptographicHash::hash(newTemplates[i].file.subject().toLatin1(), QCryptographicHash::Md5);
447 455 // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow
... ...
openbr/plugins/algorithms.cpp
... ... @@ -52,6 +52,7 @@ class AlgorithmsInitializer : public Initializer
52 52 Globals->abbreviations.insert("SmallSURF", "Open+LimitSize(512)+KeyPointDetector(SURF)+KeyPointDescriptor(SURF):KeyPointMatcher(BruteForce)");
53 53 Globals->abbreviations.insert("ColorHist", "Open+LimitSize(512)!EnsureChannels(3)+SplitChannels+Hist(256,0,8)+Cat+Normalize(L1):L2");
54 54 Globals->abbreviations.insert("ImageClassification", "Open+CropSquare+LimitSize(256)+Cvt(Gray)+Gradient+Bin(0,360,9,true)+Merge+Integral+RecursiveIntegralSampler(4,2,8,Singleton(KMeans(256)))+Cat+CvtFloat+Hist(256)+KNN(5,Dist(L1),false,5)+Rename(KNN,Subject)");
  55 + Globals->abbreviations.insert("TanTriggs", "Blur(1.1)+Gamma(0.2)+DoG(1,2)+ContrastEq(0.1,10)");
55 56  
56 57 // Hash
57 58 Globals->abbreviations.insert("FileName", "Name+Identity:Identical");
... ...
openbr/plugins/distance.cpp
... ... @@ -198,7 +198,6 @@ class AverageDistance : public Distance
198 198  
199 199 float score = 0;
200 200 for (int i = 0; i < a.size(); i++) {
201   - qDebug() << "Computing score for: " << a.file.name << " vs. " << b.file.name;
202 201 score += distance->compare(a[i],b[i]);
203 202 }
204 203  
... ... @@ -295,10 +294,52 @@ class IdenticalDistance : public Distance
295 294 if (am.data[i] != bm.data[i]) return 0;
296 295 return 1;
297 296 }
298   -};
  297 +};
299 298  
300 299 BR_REGISTER(Distance, IdenticalDistance)
301 300  
  301 +class HeatMapDistance : public Distance
  302 +{
  303 + Q_OBJECT
  304 + Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false)
  305 + BR_PROPERTY(br::Distance*, distance, make("Dist(L2)"))
  306 + Q_PROPERTY(int rowSize READ get_rowSize WRITE set_rowSize RESET reset_rowSize STORED false)
  307 + BR_PROPERTY(int, rowSize, 1)
  308 +
  309 + void train(const TemplateList &src)
  310 + {
  311 + distance->train(src);
  312 + }
  313 +
  314 +
  315 + float compare(const Template &a, const Template &b) const
  316 + {
  317 + qFatal("HeatMap expects a TemplateList");
  318 +
  319 + (void) a; (void) b;
  320 + }
  321 +
  322 + void compare(const TemplateList &target, const TemplateList &query, Output *output) const
  323 + {
  324 + int i = 0;
  325 + int j = 0;
  326 + for (int index = 0; index < target.size(); index++) {
  327 + float score = distance->compare(target[index],query[index]);
  328 +
  329 + if (j >= rowSize) {
  330 + i++;
  331 + j = 0;
  332 + }
  333 +
  334 + output->setRelative(score, i, j);
  335 +
  336 + j++;
  337 + }
  338 + }
  339 +};
  340 +
  341 +BR_REGISTER(Distance, HeatMapDistance)
  342 +
302 343 } // namespace br
303 344  
304 345 #include "distance.moc"
... ...
openbr/plugins/regions.cpp
... ... @@ -228,7 +228,6 @@ class RectFromPointsTransform : public UntrainableTransform
228 228 }
229 229 }
230 230  
231   - // Padding is .05
232 231 double width = maxX-minX;
233 232 double deltaWidth = width*padding;
234 233 width += deltaWidth;
... ...
openbr/plugins/stasm.cpp
... ... @@ -21,7 +21,7 @@ class StasmInitializer : public Initializer
21 21 Globals->abbreviations.insert("RectFromStasmEyes","RectFromPoints([27, 28, 29, 30, 31, 32, 33, 34, 35, 36],0.125,6.0)+Resize(44,164)"); //
22 22 Globals->abbreviations.insert("RectFromStasmJaw","RectFromPoints([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],10)");
23 23 Globals->abbreviations.insert("RectFromStasmBrow","RectFromPoints([15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26],0.25,6.5)+Resize(44,230)");
24   - Globals->abbreviations.insert("RectFromStasmNose","RectFromPoints([38, 39, 40, 41, 42, 43, 44, 67],0.1,1.5)+Resize(44,44)");
  24 + Globals->abbreviations.insert("RectFromStasmNose","RectFromPoints([38, 39, 40, 41, 42, 43, 44, 67],0.15,1.25)+Resize(44,44)");
25 25 Globals->abbreviations.insert("RectFromStasmMouth","RectFromPoints([48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66],0.3,3.0)+Resize(26,68)");
26 26 }
27 27 };
... ... @@ -60,7 +60,6 @@ class StasmTransform : public UntrainableTransform
60 60 }
61 61  
62 62 for (int i = 0; i < numLandmarks; i++) {
63   - qDebug() << QPointF(landmarks[2 * i], landmarks[2 * i + 1]);
64 63 dst.file.appendPoint(QPointF(landmarks[2 * i], landmarks[2 * i + 1]));
65 64 }
66 65  
... ...