Commit f7dff58cfe07ce9fb0341ac7f4ef4a067cc874f5
Merge pull request #7 from biometrics/kde
Implemented Imposter Score Normalization and Match Probability
Showing
6 changed files
with
232 additions
and
58 deletions
sdk/core/core.cpp
| @@ -252,7 +252,7 @@ private: | @@ -252,7 +252,7 @@ private: | ||
| 252 | if (words.size() > 2) qFatal("AlgorithmCore::init invalid algorithm format."); | 252 | if (words.size() > 2) qFatal("AlgorithmCore::init invalid algorithm format."); |
| 253 | 253 | ||
| 254 | transform = QSharedPointer<Transform>(Transform::make(words[0], NULL)); | 254 | transform = QSharedPointer<Transform>(Transform::make(words[0], NULL)); |
| 255 | - if (words.size() > 1) distance = QSharedPointer<Distance>(Factory<Distance>::make("." + words[1])); | 255 | + if (words.size() > 1) distance = QSharedPointer<Distance>(Distance::make(words[1], NULL)); |
| 256 | } | 256 | } |
| 257 | }; | 257 | }; |
| 258 | 258 |
sdk/openbr_plugin.cpp
| @@ -497,6 +497,8 @@ QString Object::argument(int index) const | @@ -497,6 +497,8 @@ QString Object::argument(int index) const | ||
| 497 | return "[" + strings.join(",") + "]"; | 497 | return "[" + strings.join(",") + "]"; |
| 498 | } else if (type == "br::Transform*") { | 498 | } else if (type == "br::Transform*") { |
| 499 | return variant.value<Transform*>()->description(); | 499 | return variant.value<Transform*>()->description(); |
| 500 | + } else if (type == "br::Distance*") { | ||
| 501 | + return variant.value<Distance*>()->description(); | ||
| 500 | } else if (type == "QStringList") { | 502 | } else if (type == "QStringList") { |
| 501 | return "[" + variant.toStringList().join(",") + "]"; | 503 | return "[" + variant.toStringList().join(",") + "]"; |
| 502 | } | 504 | } |
| @@ -524,6 +526,8 @@ void Object::store(QDataStream &stream) const | @@ -524,6 +526,8 @@ void Object::store(QDataStream &stream) const | ||
| 524 | transform->store(stream); | 526 | transform->store(stream); |
| 525 | } else if (type == "br::Transform*") { | 527 | } else if (type == "br::Transform*") { |
| 526 | property.read(this).value<Transform*>()->store(stream); | 528 | property.read(this).value<Transform*>()->store(stream); |
| 529 | + } else if (type == "br::Distance*") { | ||
| 530 | + property.read(this).value<Distance*>()->store(stream); | ||
| 527 | } else if (type == "bool") { | 531 | } else if (type == "bool") { |
| 528 | stream << property.read(this).toBool(); | 532 | stream << property.read(this).toBool(); |
| 529 | } else if (type == "int") { | 533 | } else if (type == "int") { |
| @@ -556,6 +560,8 @@ void Object::load(QDataStream &stream) | @@ -556,6 +560,8 @@ void Object::load(QDataStream &stream) | ||
| 556 | transform->load(stream); | 560 | transform->load(stream); |
| 557 | } else if (type == "br::Transform*") { | 561 | } else if (type == "br::Transform*") { |
| 558 | property.read(this).value<Transform*>()->load(stream); | 562 | property.read(this).value<Transform*>()->load(stream); |
| 563 | + } else if (type == "br::Distance*") { | ||
| 564 | + property.read(this).value<Distance*>()->load(stream); | ||
| 559 | } else if (type == "bool") { | 565 | } else if (type == "bool") { |
| 560 | bool value; | 566 | bool value; |
| 561 | stream >> value; | 567 | stream >> value; |
| @@ -620,6 +626,8 @@ void Object::setProperty(const QString &name, const QString &value) | @@ -620,6 +626,8 @@ void Object::setProperty(const QString &name, const QString &value) | ||
| 620 | } | 626 | } |
| 621 | } else if (type == "br::Transform*") { | 627 | } else if (type == "br::Transform*") { |
| 622 | variant.setValue(Transform::make(value, this)); | 628 | variant.setValue(Transform::make(value, this)); |
| 629 | + } else if (type == "br::Distance*") { | ||
| 630 | + variant.setValue(Distance::make(value, this)); | ||
| 623 | } else if (type == "QStringList") { | 631 | } else if (type == "QStringList") { |
| 624 | variant.setValue(parse(value.mid(1, value.size()-2))); | 632 | variant.setValue(parse(value.mid(1, value.size()-2))); |
| 625 | } else if (type == "bool") { | 633 | } else if (type == "bool") { |
| @@ -630,6 +638,7 @@ void Object::setProperty(const QString &name, const QString &value) | @@ -630,6 +638,7 @@ void Object::setProperty(const QString &name, const QString &value) | ||
| 630 | } else { | 638 | } else { |
| 631 | variant = value; | 639 | variant = value; |
| 632 | } | 640 | } |
| 641 | + | ||
| 633 | if (!QObject::setProperty(qPrintable(name), variant) && !type.isEmpty()) | 642 | if (!QObject::setProperty(qPrintable(name), variant) && !type.isEmpty()) |
| 634 | qFatal("Failed to set %s::%s to: %s %s", | 643 | qFatal("Failed to set %s::%s to: %s %s", |
| 635 | metaObject()->className(), qPrintable(name), qPrintable(value), qPrintable(type)); | 644 | metaObject()->className(), qPrintable(name), qPrintable(value), qPrintable(type)); |
| @@ -786,6 +795,7 @@ void br::Context::initializeQt(QString sdkPath) | @@ -786,6 +795,7 @@ void br::Context::initializeQt(QString sdkPath) | ||
| 786 | qRegisterMetaType< QList<int> >(); | 795 | qRegisterMetaType< QList<int> >(); |
| 787 | qRegisterMetaType< br::Transform* >(); | 796 | qRegisterMetaType< br::Transform* >(); |
| 788 | qRegisterMetaType< QList<br::Transform*> >(); | 797 | qRegisterMetaType< QList<br::Transform*> >(); |
| 798 | + qRegisterMetaType< br::Distance* >(); | ||
| 789 | qRegisterMetaType< cv::Mat >(); | 799 | qRegisterMetaType< cv::Mat >(); |
| 790 | 800 | ||
| 791 | qInstallMsgHandler(messageHandler); | 801 | qInstallMsgHandler(messageHandler); |
| @@ -1270,42 +1280,17 @@ void Transform::backProject(const TemplateList &dst, TemplateList &src) const | @@ -1270,42 +1280,17 @@ void Transform::backProject(const TemplateList &dst, TemplateList &src) const | ||
| 1270 | } | 1280 | } |
| 1271 | 1281 | ||
| 1272 | /* Distance - public methods */ | 1282 | /* Distance - public methods */ |
| 1273 | -void Distance::train(const TemplateList &templates) | ||
| 1274 | -{ | ||
| 1275 | - const TemplateList samples = templates.mid(0, 2000); | ||
| 1276 | - const QList<float> sampleLabels = samples.labels<float>(); | ||
| 1277 | - QScopedPointer<MatrixOutput> memoryOutput(dynamic_cast<MatrixOutput*>(Output::make(".Matrix", FileList(samples.size()), FileList(samples.size())))); | ||
| 1278 | - compare(samples, samples, memoryOutput.data()); | ||
| 1279 | - | ||
| 1280 | - double genuineAccumulator, impostorAccumulator; | ||
| 1281 | - int genuineCount, impostorCount; | ||
| 1282 | - genuineAccumulator = impostorAccumulator = genuineCount = impostorCount = 0; | ||
| 1283 | - | ||
| 1284 | - for (int i=0; i<samples.size(); i++) { | ||
| 1285 | - for (int j=0; j<i; j++) { | ||
| 1286 | - const float val = memoryOutput.data()->data.at<float>(i, j); | ||
| 1287 | - if (sampleLabels[i] == sampleLabels[j]) { | ||
| 1288 | - genuineAccumulator += val; | ||
| 1289 | - genuineCount++; | ||
| 1290 | - } else { | ||
| 1291 | - impostorAccumulator += val; | ||
| 1292 | - impostorCount++; | ||
| 1293 | - } | ||
| 1294 | - } | ||
| 1295 | - } | ||
| 1296 | - | ||
| 1297 | - if (genuineCount == 0) { qWarning("No genuine matches."); return; } | ||
| 1298 | - if (impostorCount == 0) { qWarning("No impostor matches."); return; } | ||
| 1299 | - | ||
| 1300 | - double genuineMean = genuineAccumulator / genuineCount; | ||
| 1301 | - double impostorMean = impostorAccumulator / impostorCount; | ||
| 1302 | - | ||
| 1303 | - if (genuineMean == impostorMean) { qWarning("Genuines and impostors are indistinguishable."); return; } | 1283 | +Distance *Distance::make(QString str, QObject *parent) |
| 1284 | +{ | ||
| 1285 | + // Check for custom transforms | ||
| 1286 | + if (Globals->abbreviations.contains(str)) | ||
| 1287 | + return make(Globals->abbreviations[str], parent); | ||
| 1304 | 1288 | ||
| 1305 | - a = 1.0/(genuineMean-impostorMean); | ||
| 1306 | - b = impostorMean; | 1289 | + File f = "." + str; |
| 1290 | + Distance *distance = Factory<Distance>::make(f); | ||
| 1307 | 1291 | ||
| 1308 | - qDebug("a = %f, b = %f", a, b); | 1292 | + distance->setParent(parent); |
| 1293 | + return distance; | ||
| 1309 | } | 1294 | } |
| 1310 | 1295 | ||
| 1311 | void Distance::compare(const TemplateList &target, const TemplateList &query, Output *output) const | 1296 | void Distance::compare(const TemplateList &target, const TemplateList &query, Output *output) const |
| @@ -1336,7 +1321,7 @@ float Distance::compare(const Template &target, const Template &query) const | @@ -1336,7 +1321,7 @@ float Distance::compare(const Template &target, const Template &query) const | ||
| 1336 | return -std::numeric_limits<float>::max(); | 1321 | return -std::numeric_limits<float>::max(); |
| 1337 | } | 1322 | } |
| 1338 | 1323 | ||
| 1339 | - return a * (_compare(target, query) - b); | 1324 | + return _compare(target, query); |
| 1340 | } | 1325 | } |
| 1341 | 1326 | ||
| 1342 | QList<float> Distance::compare(const TemplateList &targets, const Template &query) const | 1327 | QList<float> Distance::compare(const TemplateList &targets, const Template &query) const |
sdk/openbr_plugin.h
| @@ -971,7 +971,6 @@ public: | @@ -971,7 +971,6 @@ public: | ||
| 971 | protected: | 971 | protected: |
| 972 | Transform(bool independent = true); /*!< \brief Construct a transform. */ | 972 | Transform(bool independent = true); /*!< \brief Construct a transform. */ |
| 973 | inline Transform *make(const QString &description) { return make(description, this); } /*!< \brief Make a subtransform. */ | 973 | inline Transform *make(const QString &description) { return make(description, this); } /*!< \brief Make a subtransform. */ |
| 974 | - | ||
| 975 | }; | 974 | }; |
| 976 | 975 | ||
| 977 | /*! | 976 | /*! |
| @@ -1065,19 +1064,19 @@ class BR_EXPORT Distance : public Object | @@ -1065,19 +1064,19 @@ class BR_EXPORT Distance : public Object | ||
| 1065 | { | 1064 | { |
| 1066 | Q_OBJECT | 1065 | Q_OBJECT |
| 1067 | 1066 | ||
| 1068 | - // Score normalization | ||
| 1069 | - Q_PROPERTY(float a READ get_a WRITE set_a RESET reset_a) | ||
| 1070 | - Q_PROPERTY(float b READ get_b WRITE set_b RESET reset_b) | ||
| 1071 | - BR_PROPERTY(float, a, 1) | ||
| 1072 | - BR_PROPERTY(float, b, 0) | ||
| 1073 | - | ||
| 1074 | public: | 1067 | public: |
| 1068 | + virtual ~Distance() {} | ||
| 1069 | + static Distance *make(QString str, QObject *parent); /*!< \brief Make a distance from a string. */ | ||
| 1070 | + | ||
| 1075 | static QSharedPointer<Distance> fromAlgorithm(const QString &algorithm); /*!< \brief Retrieve an algorithm's distance. */ | 1071 | static QSharedPointer<Distance> fromAlgorithm(const QString &algorithm); /*!< \brief Retrieve an algorithm's distance. */ |
| 1076 | - virtual void train(const TemplateList &src); /*!< \brief Train the distance. */ | 1072 | + virtual void train(const TemplateList &src) { (void) src; } /*!< \brief Train the distance. */ |
| 1077 | virtual void compare(const TemplateList &target, const TemplateList &query, Output *output) const; /*!< \brief Compare two template lists. */ | 1073 | virtual void compare(const TemplateList &target, const TemplateList &query, Output *output) const; /*!< \brief Compare two template lists. */ |
| 1078 | float compare(const Template &target, const Template &query) const; /*!< \brief Compute the normalized distance between two templates. */ | 1074 | float compare(const Template &target, const Template &query) const; /*!< \brief Compute the normalized distance between two templates. */ |
| 1079 | QList<float> compare(const TemplateList &targets, const Template &query) const; /*!< \brief Compute the normalized distance between a template and a template list. */ | 1075 | QList<float> compare(const TemplateList &targets, const Template &query) const; /*!< \brief Compute the normalized distance between a template and a template list. */ |
| 1080 | 1076 | ||
| 1077 | +protected: | ||
| 1078 | + inline Distance *make(const QString &description) { return make(description, this); } /*!< \brief Make a subdistance. */ | ||
| 1079 | + | ||
| 1081 | private: | 1080 | private: |
| 1082 | virtual void compareBlock(const TemplateList &target, const TemplateList &query, Output *output, int targetOffset, int queryOffset) const; | 1081 | virtual void compareBlock(const TemplateList &target, const TemplateList &query, Output *output, int targetOffset, int queryOffset) const; |
| 1083 | virtual float _compare(const Template &a, const Template &b) const = 0; /*!< \brief Compute the distance between two templates. */ | 1082 | virtual float _compare(const Template &a, const Template &b) const = 0; /*!< \brief Compute the distance between two templates. */ |
| @@ -1132,6 +1131,7 @@ Q_DECLARE_METATYPE(QList<float>) | @@ -1132,6 +1131,7 @@ Q_DECLARE_METATYPE(QList<float>) | ||
| 1132 | Q_DECLARE_METATYPE(QList<int>) | 1131 | Q_DECLARE_METATYPE(QList<int>) |
| 1133 | Q_DECLARE_METATYPE(br::Transform*) | 1132 | Q_DECLARE_METATYPE(br::Transform*) |
| 1134 | Q_DECLARE_METATYPE(QList<br::Transform*>) | 1133 | Q_DECLARE_METATYPE(QList<br::Transform*>) |
| 1134 | +Q_DECLARE_METATYPE(br::Distance*) | ||
| 1135 | Q_DECLARE_METATYPE(cv::Mat) | 1135 | Q_DECLARE_METATYPE(cv::Mat) |
| 1136 | 1136 | ||
| 1137 | #endif // __OPENBR_PLUGIN_H | 1137 | #endif // __OPENBR_PLUGIN_H |
sdk/plugins/algorithms.cpp
| @@ -70,6 +70,7 @@ class AlgorithmsInitializer : public Initializer | @@ -70,6 +70,7 @@ class AlgorithmsInitializer : public Initializer | ||
| 70 | Globals->abbreviations.insert("FaceClassificationExtraction", "((Grid(7,7)+SIFTDescriptor(8)+ByRow)/DenseLBP+PCA(0.95,instances=-1)+Cat)"); | 70 | Globals->abbreviations.insert("FaceClassificationExtraction", "((Grid(7,7)+SIFTDescriptor(8)+ByRow)/DenseLBP+PCA(0.95,instances=-1)+Cat)"); |
| 71 | Globals->abbreviations.insert("AgeRegressor", "Center(Range,instances=-1)+SVM(RBF,EPS_SVR,instances=100)"); | 71 | Globals->abbreviations.insert("AgeRegressor", "Center(Range,instances=-1)+SVM(RBF,EPS_SVR,instances=100)"); |
| 72 | Globals->abbreviations.insert("GenderClassifier", "Center(Range,instances=-1)+SVM(RBF,C_SVC,instances=4000)"); | 72 | Globals->abbreviations.insert("GenderClassifier", "Center(Range,instances=-1)+SVM(RBF,C_SVC,instances=4000)"); |
| 73 | + Globals->abbreviations.insert("UCharL1", "Unit(ByteL1)"); | ||
| 73 | } | 74 | } |
| 74 | }; | 75 | }; |
| 75 | 76 |
sdk/plugins/compare.cpp
| @@ -119,7 +119,7 @@ BR_REGISTER(Distance, DistDistance) | @@ -119,7 +119,7 @@ BR_REGISTER(Distance, DistDistance) | ||
| 119 | * \brief Fast 8-bit L1 distance | 119 | * \brief Fast 8-bit L1 distance |
| 120 | * \author Josh Klontz \cite jklontz | 120 | * \author Josh Klontz \cite jklontz |
| 121 | */ | 121 | */ |
| 122 | -class UCharL1Distance : public Distance | 122 | +class ByteL1Distance : public Distance |
| 123 | { | 123 | { |
| 124 | Q_OBJECT | 124 | Q_OBJECT |
| 125 | 125 | ||
| @@ -129,7 +129,7 @@ class UCharL1Distance : public Distance | @@ -129,7 +129,7 @@ class UCharL1Distance : public Distance | ||
| 129 | } | 129 | } |
| 130 | }; | 130 | }; |
| 131 | 131 | ||
| 132 | -BR_REGISTER(Distance, UCharL1Distance) | 132 | +BR_REGISTER(Distance, ByteL1Distance) |
| 133 | 133 | ||
| 134 | 134 | ||
| 135 | /*! | 135 | /*! |
| @@ -137,7 +137,7 @@ BR_REGISTER(Distance, UCharL1Distance) | @@ -137,7 +137,7 @@ BR_REGISTER(Distance, UCharL1Distance) | ||
| 137 | * \brief Fast 4-bit L1 distance | 137 | * \brief Fast 4-bit L1 distance |
| 138 | * \author Josh Klontz \cite jklontz | 138 | * \author Josh Klontz \cite jklontz |
| 139 | */ | 139 | */ |
| 140 | -class PackedUCharL1Distance : public Distance | 140 | +class HalfByteL1Distance : public Distance |
| 141 | { | 141 | { |
| 142 | Q_OBJECT | 142 | Q_OBJECT |
| 143 | 143 | ||
| @@ -147,7 +147,7 @@ class PackedUCharL1Distance : public Distance | @@ -147,7 +147,7 @@ class PackedUCharL1Distance : public Distance | ||
| 147 | } | 147 | } |
| 148 | }; | 148 | }; |
| 149 | 149 | ||
| 150 | -BR_REGISTER(Distance, PackedUCharL1Distance) | 150 | +BR_REGISTER(Distance, HalfByteL1Distance) |
| 151 | 151 | ||
| 152 | /*! | 152 | /*! |
| 153 | * \ingroup distances | 153 | * \ingroup distances |
sdk/plugins/quality.cpp
| @@ -18,14 +18,20 @@ class IUMTransform : public Transform | @@ -18,14 +18,20 @@ class IUMTransform : public Transform | ||
| 18 | Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false) | 18 | Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false) |
| 19 | Q_PROPERTY(double mean READ get_mean WRITE set_mean RESET reset_mean) | 19 | Q_PROPERTY(double mean READ get_mean WRITE set_mean RESET reset_mean) |
| 20 | Q_PROPERTY(double stddev READ get_stddev WRITE set_stddev RESET reset_stddev) | 20 | Q_PROPERTY(double stddev READ get_stddev WRITE set_stddev RESET reset_stddev) |
| 21 | - BR_PROPERTY(br::Distance*, distance, Factory<Distance>::make(".Dist(L2)")) | 21 | + BR_PROPERTY(br::Distance*, distance, Distance::make("Dist(L2)", this)) |
| 22 | BR_PROPERTY(double, mean, 0) | 22 | BR_PROPERTY(double, mean, 0) |
| 23 | BR_PROPERTY(double, stddev, 1) | 23 | BR_PROPERTY(double, stddev, 1) |
| 24 | br::TemplateList impostors; | 24 | br::TemplateList impostors; |
| 25 | 25 | ||
| 26 | float calculateIUM(const Template &probe, const TemplateList &gallery) const | 26 | float calculateIUM(const Template &probe, const TemplateList &gallery) const |
| 27 | { | 27 | { |
| 28 | - QList<float> scores = distance->compare(gallery, probe); | 28 | + const int probeLabel = probe.file.label(); |
| 29 | + TemplateList subset = gallery; | ||
| 30 | + for (int j=subset.size()-1; j>=0; j--) | ||
| 31 | + if (subset[j].file.label() == probeLabel) | ||
| 32 | + subset.removeAt(j); | ||
| 33 | + | ||
| 34 | + QList<float> scores = distance->compare(subset, probe); | ||
| 29 | float min, max; | 35 | float min, max; |
| 30 | Common::MinMax(scores, &min, &max); | 36 | Common::MinMax(scores, &min, &max); |
| 31 | double mean; | 37 | double mean; |
| @@ -39,14 +45,8 @@ class IUMTransform : public Transform | @@ -39,14 +45,8 @@ class IUMTransform : public Transform | ||
| 39 | impostors = data; | 45 | impostors = data; |
| 40 | 46 | ||
| 41 | QList<float> iums; iums.reserve(impostors.size()); | 47 | QList<float> iums; iums.reserve(impostors.size()); |
| 42 | - QList<int> labels = impostors.labels<int>(); | ||
| 43 | - for (int i=0; i<data.size(); i++) { | ||
| 44 | - TemplateList subset = impostors; | ||
| 45 | - for (int j=subset.size()-1; j>=0; j--) | ||
| 46 | - if (labels[j] == labels[i]) | ||
| 47 | - subset.removeAt(j); | ||
| 48 | - iums.append(calculateIUM(impostors[i], subset)); | ||
| 49 | - } | 48 | + for (int i=0; i<data.size(); i++) |
| 49 | + iums.append(calculateIUM(impostors[i], impostors)); | ||
| 50 | 50 | ||
| 51 | Common::MeanStdDev(iums, &mean, &stddev); | 51 | Common::MeanStdDev(iums, &mean, &stddev); |
| 52 | } | 52 | } |
| @@ -74,6 +74,194 @@ class IUMTransform : public Transform | @@ -74,6 +74,194 @@ class IUMTransform : public Transform | ||
| 74 | 74 | ||
| 75 | BR_REGISTER(Transform, IUMTransform) | 75 | BR_REGISTER(Transform, IUMTransform) |
| 76 | 76 | ||
| 77 | +/* Kernel Density Estimator */ | ||
| 78 | +struct KDE | ||
| 79 | +{ | ||
| 80 | + float min, max; | ||
| 81 | + double mean, stddev; | ||
| 82 | + QList<float> bins; | ||
| 83 | + | ||
| 84 | + KDE() : min(0), max(1) {} | ||
| 85 | + KDE(const QList<float> &scores) | ||
| 86 | + { | ||
| 87 | + Common::MinMax(scores, &min, &max); | ||
| 88 | + Common::MeanStdDev(scores, &mean, &stddev); | ||
| 89 | + double h = Common::KernelDensityBandwidth(scores); | ||
| 90 | + const int size = 255; | ||
| 91 | + bins.reserve(size); | ||
| 92 | + for (int i=0; i<size; i++) | ||
| 93 | + bins.append(Common::KernelDensityEstimation(scores, min + (max-min)*i/(size-1), h)); | ||
| 94 | + } | ||
| 95 | + | ||
| 96 | + float operator()(float score, bool gaussian = false) const | ||
| 97 | + { | ||
| 98 | + if (gaussian) return 1/(stddev*sqrt(2*CV_PI))*exp(-0.5*pow((score-mean)/stddev, 2)); | ||
| 99 | + if (score <= min) return bins.first(); | ||
| 100 | + if (score >= max) return bins.last(); | ||
| 101 | + const float x = (score-min)/(max-min)*bins.size(); | ||
| 102 | + const float y1 = bins[floor(x)]; | ||
| 103 | + const float y2 = bins[ceil(x)]; | ||
| 104 | + return y1 + (y2-y1)*(x-floor(x)); | ||
| 105 | + } | ||
| 106 | +}; | ||
| 107 | + | ||
| 108 | +QDataStream &operator<<(QDataStream &stream, const KDE &kde) | ||
| 109 | +{ | ||
| 110 | + return stream << kde.min << kde.max << kde.mean << kde.stddev << kde.bins; | ||
| 111 | +} | ||
| 112 | + | ||
| 113 | +QDataStream &operator>>(QDataStream &stream, KDE &kde) | ||
| 114 | +{ | ||
| 115 | + return stream >> kde.min >> kde.max >> kde.mean >> kde.stddev >> kde.bins; | ||
| 116 | +} | ||
| 117 | + | ||
| 118 | +/* Match Probability */ | ||
| 119 | +struct MP | ||
| 120 | +{ | ||
| 121 | + KDE genuine, impostor; | ||
| 122 | + MP() {} | ||
| 123 | + MP(const QList<float> &genuineScores, const QList<float> &impostorScores) | ||
| 124 | + : genuine(genuineScores), impostor(impostorScores) {} | ||
| 125 | + float operator()(float score, bool gaussian = false, bool log = false) const | ||
| 126 | + { | ||
| 127 | + const float g = genuine(score, gaussian); | ||
| 128 | + const float s = g / (impostor(score, gaussian) + g); | ||
| 129 | + if (log) return (std::max(std::log10(s), -10.f) + 10)/10; | ||
| 130 | + else return s; | ||
| 131 | + } | ||
| 132 | +}; | ||
| 133 | + | ||
| 134 | +QDataStream &operator<<(QDataStream &stream, const MP &nmp) | ||
| 135 | +{ | ||
| 136 | + return stream << nmp.genuine << nmp.impostor; | ||
| 137 | +} | ||
| 138 | + | ||
| 139 | +QDataStream &operator>>(QDataStream &stream, MP &nmp) | ||
| 140 | +{ | ||
| 141 | + return stream >> nmp.genuine >> nmp.impostor; | ||
| 142 | +} | ||
| 143 | + | ||
| 144 | +/*! | ||
| 145 | + * \ingroup distances | ||
| 146 | + * \brief Match Probability \cite klare12 | ||
| 147 | + * \author Josh Klontz \cite jklontz | ||
| 148 | + */ | ||
| 149 | +class MPDistance : public Distance | ||
| 150 | +{ | ||
| 151 | + Q_OBJECT | ||
| 152 | + Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false) | ||
| 153 | + Q_PROPERTY(QString binKey READ get_binKey WRITE set_binKey RESET reset_binKey STORED false) | ||
| 154 | + Q_PROPERTY(bool gaussian READ get_gaussian WRITE set_gaussian RESET reset_gaussian STORED false) | ||
| 155 | + Q_PROPERTY(bool log READ get_log WRITE set_log RESET reset_log STORED false) | ||
| 156 | + BR_PROPERTY(br::Distance*, distance, make("Dist(L2)")) | ||
| 157 | + BR_PROPERTY(QString, binKey, "") | ||
| 158 | + BR_PROPERTY(bool, gaussian, false) | ||
| 159 | + BR_PROPERTY(bool, log, false) | ||
| 160 | + | ||
| 161 | + QHash<QString, MP> mps; | ||
| 162 | + | ||
| 163 | + void train(const TemplateList &src) | ||
| 164 | + { | ||
| 165 | + distance->train(src); | ||
| 166 | + | ||
| 167 | + const QList<int> labels = src.labels<int>(); | ||
| 168 | + QScopedPointer<MatrixOutput> memoryOutput(dynamic_cast<MatrixOutput*>(Output::make(".Matrix", FileList(src.size()), FileList(src.size())))); | ||
| 169 | + distance->compare(src, src, memoryOutput.data()); | ||
| 170 | + | ||
| 171 | + QHash< QString, QList<float> > genuineScores, impostorScores; | ||
| 172 | + for (int i=0; i<src.size(); i++) | ||
| 173 | + for (int j=0; j<i; j++) { | ||
| 174 | + const float score = memoryOutput.data()->data.at<float>(i, j); | ||
| 175 | + const QString bin = src[i].file.getString(binKey, ""); | ||
| 176 | + if (labels[i] == labels[j]) genuineScores[bin].append(score); | ||
| 177 | + else impostorScores[bin].append(score); | ||
| 178 | + } | ||
| 179 | + | ||
| 180 | + foreach (const QString &key, genuineScores.keys()) | ||
| 181 | + mps.insert(key, MP(genuineScores[key], impostorScores[key])); | ||
| 182 | + } | ||
| 183 | + | ||
| 184 | + float _compare(const Template &target, const Template &query) const | ||
| 185 | + { | ||
| 186 | + return mps[query.file.getString(binKey, "")](distance->compare(target, query), gaussian, log); | ||
| 187 | + } | ||
| 188 | + | ||
| 189 | + void store(QDataStream &stream) const | ||
| 190 | + { | ||
| 191 | + distance->store(stream); | ||
| 192 | + stream << mps; | ||
| 193 | + } | ||
| 194 | + | ||
| 195 | + void load(QDataStream &stream) | ||
| 196 | + { | ||
| 197 | + distance->load(stream); | ||
| 198 | + stream >> mps; | ||
| 199 | + } | ||
| 200 | +}; | ||
| 201 | + | ||
| 202 | +BR_REGISTER(Distance, MPDistance) | ||
| 203 | + | ||
| 204 | +/*! | ||
| 205 | + * \ingroup distances | ||
| 206 | + * \brief Linear normalizes of a distance so the mean impostor score is 0 and the mean genuine score is 1. | ||
| 207 | + * \author Josh Klontz \cite jklontz | ||
| 208 | + */ | ||
| 209 | +class UnitDistance : public Distance | ||
| 210 | +{ | ||
| 211 | + Q_OBJECT | ||
| 212 | + Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance) | ||
| 213 | + Q_PROPERTY(float a READ get_a WRITE set_a RESET reset_a) | ||
| 214 | + Q_PROPERTY(float b READ get_b WRITE set_b RESET reset_b) | ||
| 215 | + BR_PROPERTY(br::Distance*, distance, make("Dist(L2)")) | ||
| 216 | + BR_PROPERTY(float, a, 1) | ||
| 217 | + BR_PROPERTY(float, b, 0) | ||
| 218 | + | ||
| 219 | + void train(const TemplateList &templates) | ||
| 220 | + { | ||
| 221 | + const TemplateList samples = templates.mid(0, 2000); | ||
| 222 | + const QList<float> sampleLabels = samples.labels<float>(); | ||
| 223 | + QScopedPointer<MatrixOutput> memoryOutput(dynamic_cast<MatrixOutput*>(Output::make(".Matrix", FileList(samples.size()), FileList(samples.size())))); | ||
| 224 | + compare(samples, samples, memoryOutput.data()); | ||
| 225 | + | ||
| 226 | + double genuineAccumulator, impostorAccumulator; | ||
| 227 | + int genuineCount, impostorCount; | ||
| 228 | + genuineAccumulator = impostorAccumulator = genuineCount = impostorCount = 0; | ||
| 229 | + | ||
| 230 | + for (int i=0; i<samples.size(); i++) { | ||
| 231 | + for (int j=0; j<i; j++) { | ||
| 232 | + const float val = memoryOutput.data()->data.at<float>(i, j); | ||
| 233 | + if (sampleLabels[i] == sampleLabels[j]) { | ||
| 234 | + genuineAccumulator += val; | ||
| 235 | + genuineCount++; | ||
| 236 | + } else { | ||
| 237 | + impostorAccumulator += val; | ||
| 238 | + impostorCount++; | ||
| 239 | + } | ||
| 240 | + } | ||
| 241 | + } | ||
| 242 | + | ||
| 243 | + if (genuineCount == 0) { qWarning("No genuine matches."); return; } | ||
| 244 | + if (impostorCount == 0) { qWarning("No impostor matches."); return; } | ||
| 245 | + | ||
| 246 | + double genuineMean = genuineAccumulator / genuineCount; | ||
| 247 | + double impostorMean = impostorAccumulator / impostorCount; | ||
| 248 | + | ||
| 249 | + if (genuineMean == impostorMean) { qWarning("Genuines and impostors are indistinguishable."); return; } | ||
| 250 | + | ||
| 251 | + a = 1.0/(genuineMean-impostorMean); | ||
| 252 | + b = impostorMean; | ||
| 253 | + | ||
| 254 | + qDebug("a = %f, b = %f", a, b); | ||
| 255 | + } | ||
| 256 | + | ||
| 257 | + float _compare(const Template &target, const Template &query) const | ||
| 258 | + { | ||
| 259 | + return a * (distance->compare(target, query) - b); | ||
| 260 | + } | ||
| 261 | +}; | ||
| 262 | + | ||
| 263 | +BR_REGISTER(Distance, UnitDistance) | ||
| 264 | + | ||
| 77 | } // namespace br | 265 | } // namespace br |
| 78 | 266 | ||
| 79 | #include "quality.moc" | 267 | #include "quality.moc" |