diff --git a/openbr/openbr_plugin.cpp b/openbr/openbr_plugin.cpp index 1740bde..f073913 100644 --- a/openbr/openbr_plugin.cpp +++ b/openbr/openbr_plugin.cpp @@ -408,7 +408,7 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) QStringList labels; for (int i=newTemplates.size()-1; i>=0; i--) { newTemplates[i].file.set("Index", i+templates.size()); - newTemplates[i].file.set("Gallery", gallery.name); + newTemplates[i].file.set("Gallery", file.name); QString label = newTemplates.at(i).file.get("Label"); // Have we seen this subject before? @@ -436,7 +436,7 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) } else { for (int i=newTemplates.size()-1; i>=0; i--) { newTemplates[i].file.set("Index", i+templates.size()); - newTemplates[i].file.set("Gallery", gallery.name); + newTemplates[i].file.set("Gallery", file.name); if (crossValidate > 0) { if (newTemplates[i].file.getBool("duplicatePartitions")) { diff --git a/openbr/plugins/independent.cpp b/openbr/plugins/independent.cpp index 156cc75..162566f 100644 --- a/openbr/plugins/independent.cpp +++ b/openbr/plugins/independent.cpp @@ -9,12 +9,13 @@ using namespace cv; namespace br { -static TemplateList Downsample(const TemplateList &templates, int classes, int instances, float fraction, const QString & inputVariable) +static TemplateList Downsample(const TemplateList &templates, int classes, int instances, float fraction, const QString & inputVariable, const QString &gallery) { // Return early when no downsampling is required if ((classes == std::numeric_limits::max()) && (instances == std::numeric_limits::max()) && - (fraction >= 1)) + (fraction >= 1) && + (gallery.isEmpty())) return templates; const bool atLeast = instances < 0; @@ -60,6 +61,11 @@ static TemplateList Downsample(const TemplateList &templates, int classes, int i downsample = downsample.mid(0, downsample.size()*fraction); } + if (!gallery.isEmpty()) + foreach(const Template &t, templates) + if (t.file.get("Gallery") == gallery) + downsample.append(t); + return downsample; } @@ -71,11 +77,13 @@ class DownsampleTrainingTransform : public Transform Q_PROPERTY(int instances READ get_instances WRITE set_instances RESET reset_instances STORED false) Q_PROPERTY(float fraction READ get_fraction WRITE set_fraction RESET reset_fraction STORED false) Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) + Q_PROPERTY(QString gallery READ get_gallery WRITE set_gallery RESET reset_gallery STORED false) BR_PROPERTY(br::Transform*, transform, NULL) BR_PROPERTY(int, classes, std::numeric_limits::max()) BR_PROPERTY(int, instances, std::numeric_limits::max()) BR_PROPERTY(float, fraction, 1) BR_PROPERTY(QString, inputVariable, "Label") + BR_PROPERTY(QString, gallery, QString()) void project(const Template & src, Template & dst) const { @@ -88,7 +96,7 @@ class DownsampleTrainingTransform : public Transform if (!transform || !transform->trainable) return; - TemplateList downsampled = Downsample(data, classes, instances, fraction, inputVariable); + TemplateList downsampled = Downsample(data, classes, instances, fraction, inputVariable, gallery); transform->train(downsampled); } };