diff --git a/sdk/core/core.cpp b/sdk/core/core.cpp index 3db24ed..b3a1940 100644 --- a/sdk/core/core.cpp +++ b/sdk/core/core.cpp @@ -252,7 +252,7 @@ private: if (words.size() > 2) qFatal("AlgorithmCore::init invalid algorithm format."); transform = QSharedPointer(Transform::make(words[0], NULL)); - if (words.size() > 1) distance = QSharedPointer(Factory::make("." + words[1])); + if (words.size() > 1) distance = QSharedPointer(Distance::make(words[1], NULL)); } }; diff --git a/sdk/openbr_plugin.cpp b/sdk/openbr_plugin.cpp index ec78665..13b3f84 100644 --- a/sdk/openbr_plugin.cpp +++ b/sdk/openbr_plugin.cpp @@ -497,6 +497,8 @@ QString Object::argument(int index) const return "[" + strings.join(",") + "]"; } else if (type == "br::Transform*") { return variant.value()->description(); + } else if (type == "br::Distance*") { + return variant.value()->description(); } else if (type == "QStringList") { return "[" + variant.toStringList().join(",") + "]"; } @@ -524,6 +526,8 @@ void Object::store(QDataStream &stream) const transform->store(stream); } else if (type == "br::Transform*") { property.read(this).value()->store(stream); + } else if (type == "br::Distance*") { + property.read(this).value()->store(stream); } else if (type == "bool") { stream << property.read(this).toBool(); } else if (type == "int") { @@ -556,6 +560,8 @@ void Object::load(QDataStream &stream) transform->load(stream); } else if (type == "br::Transform*") { property.read(this).value()->load(stream); + } else if (type == "br::Distance*") { + property.read(this).value()->load(stream); } else if (type == "bool") { bool value; stream >> value; @@ -620,6 +626,8 @@ void Object::setProperty(const QString &name, const QString &value) } } else if (type == "br::Transform*") { variant.setValue(Transform::make(value, this)); + } else if (type == "br::Distance*") { + variant.setValue(Distance::make(value, this)); } else if (type == "QStringList") { variant.setValue(parse(value.mid(1, value.size()-2))); } else if (type == "bool") { @@ -630,6 +638,7 @@ void Object::setProperty(const QString &name, const QString &value) } else { variant = value; } + if (!QObject::setProperty(qPrintable(name), variant) && !type.isEmpty()) qFatal("Failed to set %s::%s to: %s %s", metaObject()->className(), qPrintable(name), qPrintable(value), qPrintable(type)); @@ -786,6 +795,7 @@ void br::Context::initializeQt(QString sdkPath) qRegisterMetaType< QList >(); qRegisterMetaType< br::Transform* >(); qRegisterMetaType< QList >(); + qRegisterMetaType< br::Distance* >(); qRegisterMetaType< cv::Mat >(); qInstallMsgHandler(messageHandler); @@ -1270,42 +1280,17 @@ void Transform::backProject(const TemplateList &dst, TemplateList &src) const } /* Distance - public methods */ -void Distance::train(const TemplateList &templates) -{ - const TemplateList samples = templates.mid(0, 2000); - const QList sampleLabels = samples.labels(); - QScopedPointer memoryOutput(dynamic_cast(Output::make(".Matrix", FileList(samples.size()), FileList(samples.size())))); - compare(samples, samples, memoryOutput.data()); - - double genuineAccumulator, impostorAccumulator; - int genuineCount, impostorCount; - genuineAccumulator = impostorAccumulator = genuineCount = impostorCount = 0; - - for (int i=0; idata.at(i, j); - if (sampleLabels[i] == sampleLabels[j]) { - genuineAccumulator += val; - genuineCount++; - } else { - impostorAccumulator += val; - impostorCount++; - } - } - } - - if (genuineCount == 0) { qWarning("No genuine matches."); return; } - if (impostorCount == 0) { qWarning("No impostor matches."); return; } - - double genuineMean = genuineAccumulator / genuineCount; - double impostorMean = impostorAccumulator / impostorCount; - - if (genuineMean == impostorMean) { qWarning("Genuines and impostors are indistinguishable."); return; } +Distance *Distance::make(QString str, QObject *parent) +{ + // Check for custom transforms + if (Globals->abbreviations.contains(str)) + return make(Globals->abbreviations[str], parent); - a = 1.0/(genuineMean-impostorMean); - b = impostorMean; + File f = "." + str; + Distance *distance = Factory::make(f); - qDebug("a = %f, b = %f", a, b); + distance->setParent(parent); + return distance; } void Distance::compare(const TemplateList &target, const TemplateList &query, Output *output) const @@ -1336,7 +1321,7 @@ float Distance::compare(const Template &target, const Template &query) const return -std::numeric_limits::max(); } - return a * (_compare(target, query) - b); + return _compare(target, query); } QList Distance::compare(const TemplateList &targets, const Template &query) const diff --git a/sdk/openbr_plugin.h b/sdk/openbr_plugin.h index a045a63..84baad1 100644 --- a/sdk/openbr_plugin.h +++ b/sdk/openbr_plugin.h @@ -971,7 +971,6 @@ public: protected: Transform(bool independent = true); /*!< \brief Construct a transform. */ inline Transform *make(const QString &description) { return make(description, this); } /*!< \brief Make a subtransform. */ - }; /*! @@ -1065,19 +1064,19 @@ class BR_EXPORT Distance : public Object { Q_OBJECT - // Score normalization - Q_PROPERTY(float a READ get_a WRITE set_a RESET reset_a) - Q_PROPERTY(float b READ get_b WRITE set_b RESET reset_b) - BR_PROPERTY(float, a, 1) - BR_PROPERTY(float, b, 0) - public: + virtual ~Distance() {} + static Distance *make(QString str, QObject *parent); /*!< \brief Make a distance from a string. */ + static QSharedPointer fromAlgorithm(const QString &algorithm); /*!< \brief Retrieve an algorithm's distance. */ - virtual void train(const TemplateList &src); /*!< \brief Train the distance. */ + virtual void train(const TemplateList &src) { (void) src; } /*!< \brief Train the distance. */ virtual void compare(const TemplateList &target, const TemplateList &query, Output *output) const; /*!< \brief Compare two template lists. */ float compare(const Template &target, const Template &query) const; /*!< \brief Compute the normalized distance between two templates. */ QList compare(const TemplateList &targets, const Template &query) const; /*!< \brief Compute the normalized distance between a template and a template list. */ +protected: + inline Distance *make(const QString &description) { return make(description, this); } /*!< \brief Make a subdistance. */ + private: virtual void compareBlock(const TemplateList &target, const TemplateList &query, Output *output, int targetOffset, int queryOffset) const; virtual float _compare(const Template &a, const Template &b) const = 0; /*!< \brief Compute the distance between two templates. */ @@ -1132,6 +1131,7 @@ Q_DECLARE_METATYPE(QList) Q_DECLARE_METATYPE(QList) Q_DECLARE_METATYPE(br::Transform*) Q_DECLARE_METATYPE(QList) +Q_DECLARE_METATYPE(br::Distance*) Q_DECLARE_METATYPE(cv::Mat) #endif // __OPENBR_PLUGIN_H diff --git a/sdk/plugins/algorithms.cpp b/sdk/plugins/algorithms.cpp index d25208d..c841fdf 100644 --- a/sdk/plugins/algorithms.cpp +++ b/sdk/plugins/algorithms.cpp @@ -70,6 +70,7 @@ class AlgorithmsInitializer : public Initializer Globals->abbreviations.insert("FaceClassificationExtraction", "((Grid(7,7)+SIFTDescriptor(8)+ByRow)/DenseLBP+PCA(0.95,instances=-1)+Cat)"); Globals->abbreviations.insert("AgeRegressor", "Center(Range,instances=-1)+SVM(RBF,EPS_SVR,instances=100)"); Globals->abbreviations.insert("GenderClassifier", "Center(Range,instances=-1)+SVM(RBF,C_SVC,instances=4000)"); + Globals->abbreviations.insert("UCharL1", "Unit(ByteL1)"); } }; diff --git a/sdk/plugins/compare.cpp b/sdk/plugins/compare.cpp index 113e57f..3a6baf2 100644 --- a/sdk/plugins/compare.cpp +++ b/sdk/plugins/compare.cpp @@ -119,7 +119,7 @@ BR_REGISTER(Distance, DistDistance) * \brief Fast 8-bit L1 distance * \author Josh Klontz \cite jklontz */ -class UCharL1Distance : public Distance +class ByteL1Distance : public Distance { Q_OBJECT @@ -129,7 +129,7 @@ class UCharL1Distance : public Distance } }; -BR_REGISTER(Distance, UCharL1Distance) +BR_REGISTER(Distance, ByteL1Distance) /*! @@ -137,7 +137,7 @@ BR_REGISTER(Distance, UCharL1Distance) * \brief Fast 4-bit L1 distance * \author Josh Klontz \cite jklontz */ -class PackedUCharL1Distance : public Distance +class HalfByteL1Distance : public Distance { Q_OBJECT @@ -147,7 +147,7 @@ class PackedUCharL1Distance : public Distance } }; -BR_REGISTER(Distance, PackedUCharL1Distance) +BR_REGISTER(Distance, HalfByteL1Distance) /*! * \ingroup distances diff --git a/sdk/plugins/quality.cpp b/sdk/plugins/quality.cpp index 3c3f332..796bb86 100644 --- a/sdk/plugins/quality.cpp +++ b/sdk/plugins/quality.cpp @@ -18,14 +18,20 @@ class IUMTransform : public Transform Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false) Q_PROPERTY(double mean READ get_mean WRITE set_mean RESET reset_mean) Q_PROPERTY(double stddev READ get_stddev WRITE set_stddev RESET reset_stddev) - BR_PROPERTY(br::Distance*, distance, Factory::make(".Dist(L2)")) + BR_PROPERTY(br::Distance*, distance, Distance::make("Dist(L2)", this)) BR_PROPERTY(double, mean, 0) BR_PROPERTY(double, stddev, 1) br::TemplateList impostors; float calculateIUM(const Template &probe, const TemplateList &gallery) const { - QList scores = distance->compare(gallery, probe); + const int probeLabel = probe.file.label(); + TemplateList subset = gallery; + for (int j=subset.size()-1; j>=0; j--) + if (subset[j].file.label() == probeLabel) + subset.removeAt(j); + + QList scores = distance->compare(subset, probe); float min, max; Common::MinMax(scores, &min, &max); double mean; @@ -39,14 +45,8 @@ class IUMTransform : public Transform impostors = data; QList iums; iums.reserve(impostors.size()); - QList labels = impostors.labels(); - for (int i=0; i=0; j--) - if (labels[j] == labels[i]) - subset.removeAt(j); - iums.append(calculateIUM(impostors[i], subset)); - } + for (int i=0; i bins; + + KDE() : min(0), max(1) {} + KDE(const QList &scores) + { + Common::MinMax(scores, &min, &max); + Common::MeanStdDev(scores, &mean, &stddev); + double h = Common::KernelDensityBandwidth(scores); + const int size = 255; + bins.reserve(size); + for (int i=0; i= max) return bins.last(); + const float x = (score-min)/(max-min)*bins.size(); + const float y1 = bins[floor(x)]; + const float y2 = bins[ceil(x)]; + return y1 + (y2-y1)*(x-floor(x)); + } +}; + +QDataStream &operator<<(QDataStream &stream, const KDE &kde) +{ + return stream << kde.min << kde.max << kde.mean << kde.stddev << kde.bins; +} + +QDataStream &operator>>(QDataStream &stream, KDE &kde) +{ + return stream >> kde.min >> kde.max >> kde.mean >> kde.stddev >> kde.bins; +} + +/* Match Probability */ +struct MP +{ + KDE genuine, impostor; + MP() {} + MP(const QList &genuineScores, const QList &impostorScores) + : genuine(genuineScores), impostor(impostorScores) {} + float operator()(float score, bool gaussian = false, bool log = false) const + { + const float g = genuine(score, gaussian); + const float s = g / (impostor(score, gaussian) + g); + if (log) return (std::max(std::log10(s), -10.f) + 10)/10; + else return s; + } +}; + +QDataStream &operator<<(QDataStream &stream, const MP &nmp) +{ + return stream << nmp.genuine << nmp.impostor; +} + +QDataStream &operator>>(QDataStream &stream, MP &nmp) +{ + return stream >> nmp.genuine >> nmp.impostor; +} + +/*! + * \ingroup distances + * \brief Match Probability \cite klare12 + * \author Josh Klontz \cite jklontz + */ +class MPDistance : public Distance +{ + Q_OBJECT + Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false) + Q_PROPERTY(QString binKey READ get_binKey WRITE set_binKey RESET reset_binKey STORED false) + Q_PROPERTY(bool gaussian READ get_gaussian WRITE set_gaussian RESET reset_gaussian STORED false) + Q_PROPERTY(bool log READ get_log WRITE set_log RESET reset_log STORED false) + BR_PROPERTY(br::Distance*, distance, make("Dist(L2)")) + BR_PROPERTY(QString, binKey, "") + BR_PROPERTY(bool, gaussian, false) + BR_PROPERTY(bool, log, false) + + QHash mps; + + void train(const TemplateList &src) + { + distance->train(src); + + const QList labels = src.labels(); + QScopedPointer memoryOutput(dynamic_cast(Output::make(".Matrix", FileList(src.size()), FileList(src.size())))); + distance->compare(src, src, memoryOutput.data()); + + QHash< QString, QList > genuineScores, impostorScores; + for (int i=0; idata.at(i, j); + const QString bin = src[i].file.getString(binKey, ""); + if (labels[i] == labels[j]) genuineScores[bin].append(score); + else impostorScores[bin].append(score); + } + + foreach (const QString &key, genuineScores.keys()) + mps.insert(key, MP(genuineScores[key], impostorScores[key])); + } + + float _compare(const Template &target, const Template &query) const + { + return mps[query.file.getString(binKey, "")](distance->compare(target, query), gaussian, log); + } + + void store(QDataStream &stream) const + { + distance->store(stream); + stream << mps; + } + + void load(QDataStream &stream) + { + distance->load(stream); + stream >> mps; + } +}; + +BR_REGISTER(Distance, MPDistance) + +/*! + * \ingroup distances + * \brief Linear normalizes of a distance so the mean impostor score is 0 and the mean genuine score is 1. + * \author Josh Klontz \cite jklontz + */ +class UnitDistance : public Distance +{ + Q_OBJECT + Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance) + Q_PROPERTY(float a READ get_a WRITE set_a RESET reset_a) + Q_PROPERTY(float b READ get_b WRITE set_b RESET reset_b) + BR_PROPERTY(br::Distance*, distance, make("Dist(L2)")) + BR_PROPERTY(float, a, 1) + BR_PROPERTY(float, b, 0) + + void train(const TemplateList &templates) + { + const TemplateList samples = templates.mid(0, 2000); + const QList sampleLabels = samples.labels(); + QScopedPointer memoryOutput(dynamic_cast(Output::make(".Matrix", FileList(samples.size()), FileList(samples.size())))); + compare(samples, samples, memoryOutput.data()); + + double genuineAccumulator, impostorAccumulator; + int genuineCount, impostorCount; + genuineAccumulator = impostorAccumulator = genuineCount = impostorCount = 0; + + for (int i=0; idata.at(i, j); + if (sampleLabels[i] == sampleLabels[j]) { + genuineAccumulator += val; + genuineCount++; + } else { + impostorAccumulator += val; + impostorCount++; + } + } + } + + if (genuineCount == 0) { qWarning("No genuine matches."); return; } + if (impostorCount == 0) { qWarning("No impostor matches."); return; } + + double genuineMean = genuineAccumulator / genuineCount; + double impostorMean = impostorAccumulator / impostorCount; + + if (genuineMean == impostorMean) { qWarning("Genuines and impostors are indistinguishable."); return; } + + a = 1.0/(genuineMean-impostorMean); + b = impostorMean; + + qDebug("a = %f, b = %f", a, b); + } + + float _compare(const Template &target, const Template &query) const + { + return a * (distance->compare(target, query) - b); + } +}; + +BR_REGISTER(Distance, UnitDistance) + } // namespace br #include "quality.moc"