From 66001a7191876851e6420a60b84146aea62249a7 Mon Sep 17 00:00:00 2001 From: Josh Klontz Date: Mon, 11 Feb 2013 11:23:13 -0500 Subject: [PATCH] implemented PipeDistance --- sdk/openbr_plugin.cpp | 35 ++++++++++++++++++++++++++++------- sdk/openbr_plugin.h | 1 + sdk/plugins/algorithms.cpp | 4 ++-- sdk/plugins/distance.cpp | 40 +++++++++++++++++++++++++++++++++++++++- sdk/plugins/quality.cpp | 31 ------------------------------- sdk/plugins/validate.cpp | 35 ++++++++++++++++++++++++++--------- 6 files changed, 96 insertions(+), 50 deletions(-) diff --git a/sdk/openbr_plugin.cpp b/sdk/openbr_plugin.cpp index f6f0345..45f5dd9 100644 --- a/sdk/openbr_plugin.cpp +++ b/sdk/openbr_plugin.cpp @@ -522,6 +522,9 @@ QString Object::argument(int index) const } else if (type == "QList") { foreach (Transform *transform, variant.value< QList >()) strings.append(transform->description()); + } else if (type == "QList") { + foreach (Distance *distance, variant.value< QList >()) + strings.append(distance->description()); } else { qFatal("Unrecognized type: %s", qPrintable(type)); } @@ -556,6 +559,9 @@ void Object::store(QDataStream &stream) const if (type == "QList") { foreach (Transform *transform, property.read(this).value< QList >()) transform->store(stream); + } else if (type == "QList") { + foreach (Distance *distance, property.read(this).value< QList >()) + distance->store(stream); } else if (type == "br::Transform*") { property.read(this).value()->store(stream); } else if (type == "br::Distance*") { @@ -590,6 +596,9 @@ void Object::load(QDataStream &stream) if (type == "QList") { foreach (Transform *transform, property.read(this).value< QList >()) transform->load(stream); + } else if (type == "QList") { + foreach (Distance *distance, property.read(this).value< QList >()) + distance->load(stream); } else if (type == "br::Transform*") { property.read(this).value()->load(stream); } else if (type == "br::Distance*") { @@ -653,6 +662,11 @@ void Object::setProperty(const QString &name, const QString &value) foreach (const QString &string, strings) values.append(Transform::make(string, this)); variant.setValue(values); + } else if (type == "QList") { + QList values; + foreach (const QString &string, strings) + values.append(Distance::make(string, this)); + variant.setValue(values); } else { qFatal("Unrecognized type: %s", qPrintable(type)); } @@ -828,6 +842,7 @@ void br::Context::initializeQt(QString sdkPath) qRegisterMetaType< br::Transform* >(); qRegisterMetaType< QList >(); qRegisterMetaType< br::Distance* >(); + qRegisterMetaType< QList >(); qRegisterMetaType< cv::Mat >(); qInstallMsgHandler(messageHandler); @@ -1311,15 +1326,21 @@ void Transform::backProject(const TemplateList &dst, TemplateList &src) const /* Distance - public methods */ Distance *Distance::make(QString str, QObject *parent) { - // Check for custom transforms - if (Globals->abbreviations.contains(str)) - return make(Globals->abbreviations[str], parent); + // Check for custom transforms + if (Globals->abbreviations.contains(str)) + return make(Globals->abbreviations[str], parent); + + { // Check for use of '+' as shorthand for Pipe(...) + QStringList words = parse(str, '+'); + if (words.size() > 1) + return make("Pipe([" + words.join(",") + "])", parent); + } - File f = "." + str; - Distance *distance = Factory::make(f); + File f = "." + str; + Distance *distance = Factory::make(f); - distance->setParent(parent); - return distance; + distance->setParent(parent); + return distance; } void Distance::compare(const TemplateList &target, const TemplateList &query, Output *output) const diff --git a/sdk/openbr_plugin.h b/sdk/openbr_plugin.h index 8fd4ba0..976cda3 100644 --- a/sdk/openbr_plugin.h +++ b/sdk/openbr_plugin.h @@ -1136,6 +1136,7 @@ Q_DECLARE_METATYPE(QList) Q_DECLARE_METATYPE(br::Transform*) Q_DECLARE_METATYPE(QList) Q_DECLARE_METATYPE(br::Distance*) +Q_DECLARE_METATYPE(QList) Q_DECLARE_METATYPE(cv::Mat) #endif // __OPENBR_PLUGIN_H diff --git a/sdk/plugins/algorithms.cpp b/sdk/plugins/algorithms.cpp index a4a140b..b4d53b5 100644 --- a/sdk/plugins/algorithms.cpp +++ b/sdk/plugins/algorithms.cpp @@ -32,7 +32,7 @@ class AlgorithmsInitializer : public Initializer { // Face Globals->abbreviations.insert("FaceRecognition", "FaceDetection!!++:UCharL1"); - Globals->abbreviations.insert("FaceRecognitionNoTraining", "FaceDetection!ASEFEyes+Affine(86,86,0.25,0.35)!Blur(1.1)+Gamma(0.2)+DoG(1,2)+ContrastEq(0.1,10)+Mask+LBP(1,2)+RectRegions(8,8,6,6)+Hist(59)+Cat:Dist(ChiSquared)"); + Globals->abbreviations.insert("FaceRecognitionNoTraining", "FaceDetection!ASEFEyes+Affine(86,86,0.25,0.35)!Blur(1.1)+Gamma(0.2)+DoG(1,2)+ContrastEq(0.1,10)+Mask+LBP(1,2)+RectRegions(8,8,6,6)+Hist(59)+Cat:ChiSquared"); Globals->abbreviations.insert("GenderClassification", "FaceDetection!!++Discard"); Globals->abbreviations.insert("AgeRegression", "FaceDetection!!++Discard"); Globals->abbreviations.insert("FaceQuality", "Open!Cascade(FrontalFace)+ASEFEyes+Affine(64,64,0.25,0.35)+ImageQuality+Cvt(Gray)+DFFS+Discard"); @@ -45,7 +45,7 @@ class AlgorithmsInitializer : public Initializer Globals->abbreviations.insert("SURF", "Open+KeyPointDetector(SURF)+KeyPointDescriptor(SURF):KeyPointMatcher(BruteForce)"); Globals->abbreviations.insert("SmallSIFT", "Open+LimitSize(512)+KeyPointDetector(SIFT)+KeyPointDescriptor(SIFT):KeyPointMatcher(BruteForce)"); Globals->abbreviations.insert("SmallSURF", "Open+LimitSize(512)+KeyPointDetector(SURF)+KeyPointDescriptor(SURF):KeyPointMatcher(BruteForce)"); - Globals->abbreviations.insert("ColorHist", "Open+LimitSize(512)!EnsureChannels(3)+SplitChannels+Hist(256,0,8)+Cat+Normalize(L1):Dist(L2)"); + Globals->abbreviations.insert("ColorHist", "Open+LimitSize(512)!EnsureChannels(3)+SplitChannels+Hist(256,0,8)+Cat+Normalize(L1):L2"); // Hash Globals->abbreviations.insert("FileName", "Name+Identity:Identical"); diff --git a/sdk/plugins/distance.cpp b/sdk/plugins/distance.cpp index e7030b4..e9a6849 100644 --- a/sdk/plugins/distance.cpp +++ b/sdk/plugins/distance.cpp @@ -14,6 +14,7 @@ * limitations under the License. * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ +#include #include #include @@ -101,7 +102,6 @@ private: for (int col=0; col(row,col); const float query = b.at(row,col); - dot += target * query; magA += target * target; magB += query * query; @@ -139,6 +139,44 @@ BR_REGISTER(Distance, DefaultDistance) /*! * \ingroup distances + * \brief Distances in series. + * \author Josh Klontz \cite jklontz + * + * The templates are compared using each br::Distance in order. + * If the result of the comparison with any given distance is -std::numeric_limits::max() then this result is returned early. + * Otherwise the returned result is the value of comparing the templates using the last br::Distance. + */ +class PipeDistance : public Distance +{ + Q_OBJECT + Q_PROPERTY(QList distances READ get_distances WRITE set_distances RESET reset_distances) + BR_PROPERTY(QList, distances, QList()) + + void train(const TemplateList &data) + { + QList< QFuture > futures; + foreach (br::Distance *distance, distances) + if (Globals->parallelism) futures.append(QtConcurrent::run(distance, &Distance::train, data)); + else distance->train(data); + Globals->trackFutures(futures); + } + + float compare(const Template &a, const Template &b) const + { + float result = -std::numeric_limits::max(); + foreach (br::Distance *distance, distances) { + result = distance->compare(a, b); + if (result == -std::numeric_limits::max()) + return result; + } + return result; + } +}; + +BR_REGISTER(Distance, PipeDistance) + +/*! + * \ingroup distances * \brief Fast 8-bit L1 distance * \author Josh Klontz \cite jklontz */ diff --git a/sdk/plugins/quality.cpp b/sdk/plugins/quality.cpp index 3a55e6a..a9a42ee 100644 --- a/sdk/plugins/quality.cpp +++ b/sdk/plugins/quality.cpp @@ -260,37 +260,6 @@ class UnitDistance : public Distance BR_REGISTER(Distance, UnitDistance) -/*! - * \ingroup distances - * \brief Check target metadata before matching templates. - * \author Josh Klontz \cite jklontz - */ -class MetadataDistance : public Distance -{ - Q_OBJECT - Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance) - BR_PROPERTY(br::Distance*, distance, make("Dist(L2)")) - - void train(const TemplateList &src) - { - distance->train(src); - } - - float compare(const Template &a, const Template &b) const - { - foreach (const QString &filter, Globals->demographicFilters.keys()) { - const QString metadata = a.file.getString(filter, ""); - if (metadata.isEmpty()) continue; - const QRegExp re(Globals->demographicFilters[filter]); - if (re.indexIn(metadata) == -1) - return -std::numeric_limits::max(); - } - return distance->compare(a, b); - } -}; - -BR_REGISTER(Distance, MetadataDistance) - } // namespace br #include "quality.moc" diff --git a/sdk/plugins/validate.cpp b/sdk/plugins/validate.cpp index d11c50c..ccea539 100644 --- a/sdk/plugins/validate.cpp +++ b/sdk/plugins/validate.cpp @@ -79,25 +79,42 @@ BR_REGISTER(Transform, CrossValidateTransform) class CrossValidateDistance : public Distance { Q_OBJECT - Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance) - BR_PROPERTY(br::Distance*, distance, make("Dist(L2)")) - - void train(const TemplateList &src) - { - distance->train(src); - } float compare(const Template &a, const Template &b) const { const int partitionA = a.file.getInt("Cross_Validation_Partition", 0); const int partitionB = b.file.getInt("Cross_Validation_Partition", 0); - if (partitionA != partitionB) return -std::numeric_limits::max(); - return distance->compare(a, b); + return (partitionA != partitionB) ? -std::numeric_limits::max() : 0; } }; BR_REGISTER(Distance, CrossValidateDistance) +/*! + * \ingroup distances + * \brief Checks target metadata. + * \author Josh Klontz \cite jklontz + */ +class MetadataDistance : public Distance +{ + Q_OBJECT + + float compare(const Template &a, const Template &b) const + { + (void) b; + foreach (const QString &filter, Globals->demographicFilters.keys()) { + const QString metadata = a.file.getString(filter, ""); + if (metadata.isEmpty()) continue; + const QRegExp re(Globals->demographicFilters[filter]); + if (re.indexIn(metadata) == -1) + return -std::numeric_limits::max(); + } + return 0; + } +}; + +BR_REGISTER(Distance, MetadataDistance) + } // namespace br #include "validate.moc" -- libgit2 0.21.4