Commit f7dff58cfe07ce9fb0341ac7f4ef4a067cc874f5

Authored by jklontz
2 parents f74da8dd f9c9e2d1

Merge pull request #7 from biometrics/kde

Implemented Imposter Score Normalization and Match Probability
sdk/core/core.cpp
@@ -252,7 +252,7 @@ private: @@ -252,7 +252,7 @@ private:
252 if (words.size() > 2) qFatal("AlgorithmCore::init invalid algorithm format."); 252 if (words.size() > 2) qFatal("AlgorithmCore::init invalid algorithm format.");
253 253
254 transform = QSharedPointer<Transform>(Transform::make(words[0], NULL)); 254 transform = QSharedPointer<Transform>(Transform::make(words[0], NULL));
255 - if (words.size() > 1) distance = QSharedPointer<Distance>(Factory<Distance>::make("." + words[1])); 255 + if (words.size() > 1) distance = QSharedPointer<Distance>(Distance::make(words[1], NULL));
256 } 256 }
257 }; 257 };
258 258
sdk/openbr_plugin.cpp
@@ -497,6 +497,8 @@ QString Object::argument(int index) const @@ -497,6 +497,8 @@ QString Object::argument(int index) const
497 return "[" + strings.join(",") + "]"; 497 return "[" + strings.join(",") + "]";
498 } else if (type == "br::Transform*") { 498 } else if (type == "br::Transform*") {
499 return variant.value<Transform*>()->description(); 499 return variant.value<Transform*>()->description();
  500 + } else if (type == "br::Distance*") {
  501 + return variant.value<Distance*>()->description();
500 } else if (type == "QStringList") { 502 } else if (type == "QStringList") {
501 return "[" + variant.toStringList().join(",") + "]"; 503 return "[" + variant.toStringList().join(",") + "]";
502 } 504 }
@@ -524,6 +526,8 @@ void Object::store(QDataStream &amp;stream) const @@ -524,6 +526,8 @@ void Object::store(QDataStream &amp;stream) const
524 transform->store(stream); 526 transform->store(stream);
525 } else if (type == "br::Transform*") { 527 } else if (type == "br::Transform*") {
526 property.read(this).value<Transform*>()->store(stream); 528 property.read(this).value<Transform*>()->store(stream);
  529 + } else if (type == "br::Distance*") {
  530 + property.read(this).value<Distance*>()->store(stream);
527 } else if (type == "bool") { 531 } else if (type == "bool") {
528 stream << property.read(this).toBool(); 532 stream << property.read(this).toBool();
529 } else if (type == "int") { 533 } else if (type == "int") {
@@ -556,6 +560,8 @@ void Object::load(QDataStream &amp;stream) @@ -556,6 +560,8 @@ void Object::load(QDataStream &amp;stream)
556 transform->load(stream); 560 transform->load(stream);
557 } else if (type == "br::Transform*") { 561 } else if (type == "br::Transform*") {
558 property.read(this).value<Transform*>()->load(stream); 562 property.read(this).value<Transform*>()->load(stream);
  563 + } else if (type == "br::Distance*") {
  564 + property.read(this).value<Distance*>()->load(stream);
559 } else if (type == "bool") { 565 } else if (type == "bool") {
560 bool value; 566 bool value;
561 stream >> value; 567 stream >> value;
@@ -620,6 +626,8 @@ void Object::setProperty(const QString &amp;name, const QString &amp;value) @@ -620,6 +626,8 @@ void Object::setProperty(const QString &amp;name, const QString &amp;value)
620 } 626 }
621 } else if (type == "br::Transform*") { 627 } else if (type == "br::Transform*") {
622 variant.setValue(Transform::make(value, this)); 628 variant.setValue(Transform::make(value, this));
  629 + } else if (type == "br::Distance*") {
  630 + variant.setValue(Distance::make(value, this));
623 } else if (type == "QStringList") { 631 } else if (type == "QStringList") {
624 variant.setValue(parse(value.mid(1, value.size()-2))); 632 variant.setValue(parse(value.mid(1, value.size()-2)));
625 } else if (type == "bool") { 633 } else if (type == "bool") {
@@ -630,6 +638,7 @@ void Object::setProperty(const QString &amp;name, const QString &amp;value) @@ -630,6 +638,7 @@ void Object::setProperty(const QString &amp;name, const QString &amp;value)
630 } else { 638 } else {
631 variant = value; 639 variant = value;
632 } 640 }
  641 +
633 if (!QObject::setProperty(qPrintable(name), variant) && !type.isEmpty()) 642 if (!QObject::setProperty(qPrintable(name), variant) && !type.isEmpty())
634 qFatal("Failed to set %s::%s to: %s %s", 643 qFatal("Failed to set %s::%s to: %s %s",
635 metaObject()->className(), qPrintable(name), qPrintable(value), qPrintable(type)); 644 metaObject()->className(), qPrintable(name), qPrintable(value), qPrintable(type));
@@ -786,6 +795,7 @@ void br::Context::initializeQt(QString sdkPath) @@ -786,6 +795,7 @@ void br::Context::initializeQt(QString sdkPath)
786 qRegisterMetaType< QList<int> >(); 795 qRegisterMetaType< QList<int> >();
787 qRegisterMetaType< br::Transform* >(); 796 qRegisterMetaType< br::Transform* >();
788 qRegisterMetaType< QList<br::Transform*> >(); 797 qRegisterMetaType< QList<br::Transform*> >();
  798 + qRegisterMetaType< br::Distance* >();
789 qRegisterMetaType< cv::Mat >(); 799 qRegisterMetaType< cv::Mat >();
790 800
791 qInstallMsgHandler(messageHandler); 801 qInstallMsgHandler(messageHandler);
@@ -1270,42 +1280,17 @@ void Transform::backProject(const TemplateList &amp;dst, TemplateList &amp;src) const @@ -1270,42 +1280,17 @@ void Transform::backProject(const TemplateList &amp;dst, TemplateList &amp;src) const
1270 } 1280 }
1271 1281
1272 /* Distance - public methods */ 1282 /* Distance - public methods */
1273 -void Distance::train(const TemplateList &templates)  
1274 -{  
1275 - const TemplateList samples = templates.mid(0, 2000);  
1276 - const QList<float> sampleLabels = samples.labels<float>();  
1277 - QScopedPointer<MatrixOutput> memoryOutput(dynamic_cast<MatrixOutput*>(Output::make(".Matrix", FileList(samples.size()), FileList(samples.size()))));  
1278 - compare(samples, samples, memoryOutput.data());  
1279 -  
1280 - double genuineAccumulator, impostorAccumulator;  
1281 - int genuineCount, impostorCount;  
1282 - genuineAccumulator = impostorAccumulator = genuineCount = impostorCount = 0;  
1283 -  
1284 - for (int i=0; i<samples.size(); i++) {  
1285 - for (int j=0; j<i; j++) {  
1286 - const float val = memoryOutput.data()->data.at<float>(i, j);  
1287 - if (sampleLabels[i] == sampleLabels[j]) {  
1288 - genuineAccumulator += val;  
1289 - genuineCount++;  
1290 - } else {  
1291 - impostorAccumulator += val;  
1292 - impostorCount++;  
1293 - }  
1294 - }  
1295 - }  
1296 -  
1297 - if (genuineCount == 0) { qWarning("No genuine matches."); return; }  
1298 - if (impostorCount == 0) { qWarning("No impostor matches."); return; }  
1299 -  
1300 - double genuineMean = genuineAccumulator / genuineCount;  
1301 - double impostorMean = impostorAccumulator / impostorCount;  
1302 -  
1303 - if (genuineMean == impostorMean) { qWarning("Genuines and impostors are indistinguishable."); return; } 1283 +Distance *Distance::make(QString str, QObject *parent)
  1284 +{
  1285 + // Check for custom transforms
  1286 + if (Globals->abbreviations.contains(str))
  1287 + return make(Globals->abbreviations[str], parent);
1304 1288
1305 - a = 1.0/(genuineMean-impostorMean);  
1306 - b = impostorMean; 1289 + File f = "." + str;
  1290 + Distance *distance = Factory<Distance>::make(f);
1307 1291
1308 - qDebug("a = %f, b = %f", a, b); 1292 + distance->setParent(parent);
  1293 + return distance;
1309 } 1294 }
1310 1295
1311 void Distance::compare(const TemplateList &target, const TemplateList &query, Output *output) const 1296 void Distance::compare(const TemplateList &target, const TemplateList &query, Output *output) const
@@ -1336,7 +1321,7 @@ float Distance::compare(const Template &amp;target, const Template &amp;query) const @@ -1336,7 +1321,7 @@ float Distance::compare(const Template &amp;target, const Template &amp;query) const
1336 return -std::numeric_limits<float>::max(); 1321 return -std::numeric_limits<float>::max();
1337 } 1322 }
1338 1323
1339 - return a * (_compare(target, query) - b); 1324 + return _compare(target, query);
1340 } 1325 }
1341 1326
1342 QList<float> Distance::compare(const TemplateList &targets, const Template &query) const 1327 QList<float> Distance::compare(const TemplateList &targets, const Template &query) const
sdk/openbr_plugin.h
@@ -971,7 +971,6 @@ public: @@ -971,7 +971,6 @@ public:
971 protected: 971 protected:
972 Transform(bool independent = true); /*!< \brief Construct a transform. */ 972 Transform(bool independent = true); /*!< \brief Construct a transform. */
973 inline Transform *make(const QString &description) { return make(description, this); } /*!< \brief Make a subtransform. */ 973 inline Transform *make(const QString &description) { return make(description, this); } /*!< \brief Make a subtransform. */
974 -  
975 }; 974 };
976 975
977 /*! 976 /*!
@@ -1065,19 +1064,19 @@ class BR_EXPORT Distance : public Object @@ -1065,19 +1064,19 @@ class BR_EXPORT Distance : public Object
1065 { 1064 {
1066 Q_OBJECT 1065 Q_OBJECT
1067 1066
1068 - // Score normalization  
1069 - Q_PROPERTY(float a READ get_a WRITE set_a RESET reset_a)  
1070 - Q_PROPERTY(float b READ get_b WRITE set_b RESET reset_b)  
1071 - BR_PROPERTY(float, a, 1)  
1072 - BR_PROPERTY(float, b, 0)  
1073 -  
1074 public: 1067 public:
  1068 + virtual ~Distance() {}
  1069 + static Distance *make(QString str, QObject *parent); /*!< \brief Make a distance from a string. */
  1070 +
1075 static QSharedPointer<Distance> fromAlgorithm(const QString &algorithm); /*!< \brief Retrieve an algorithm's distance. */ 1071 static QSharedPointer<Distance> fromAlgorithm(const QString &algorithm); /*!< \brief Retrieve an algorithm's distance. */
1076 - virtual void train(const TemplateList &src); /*!< \brief Train the distance. */ 1072 + virtual void train(const TemplateList &src) { (void) src; } /*!< \brief Train the distance. */
1077 virtual void compare(const TemplateList &target, const TemplateList &query, Output *output) const; /*!< \brief Compare two template lists. */ 1073 virtual void compare(const TemplateList &target, const TemplateList &query, Output *output) const; /*!< \brief Compare two template lists. */
1078 float compare(const Template &target, const Template &query) const; /*!< \brief Compute the normalized distance between two templates. */ 1074 float compare(const Template &target, const Template &query) const; /*!< \brief Compute the normalized distance between two templates. */
1079 QList<float> compare(const TemplateList &targets, const Template &query) const; /*!< \brief Compute the normalized distance between a template and a template list. */ 1075 QList<float> compare(const TemplateList &targets, const Template &query) const; /*!< \brief Compute the normalized distance between a template and a template list. */
1080 1076
  1077 +protected:
  1078 + inline Distance *make(const QString &description) { return make(description, this); } /*!< \brief Make a subdistance. */
  1079 +
1081 private: 1080 private:
1082 virtual void compareBlock(const TemplateList &target, const TemplateList &query, Output *output, int targetOffset, int queryOffset) const; 1081 virtual void compareBlock(const TemplateList &target, const TemplateList &query, Output *output, int targetOffset, int queryOffset) const;
1083 virtual float _compare(const Template &a, const Template &b) const = 0; /*!< \brief Compute the distance between two templates. */ 1082 virtual float _compare(const Template &a, const Template &b) const = 0; /*!< \brief Compute the distance between two templates. */
@@ -1132,6 +1131,7 @@ Q_DECLARE_METATYPE(QList&lt;float&gt;) @@ -1132,6 +1131,7 @@ Q_DECLARE_METATYPE(QList&lt;float&gt;)
1132 Q_DECLARE_METATYPE(QList<int>) 1131 Q_DECLARE_METATYPE(QList<int>)
1133 Q_DECLARE_METATYPE(br::Transform*) 1132 Q_DECLARE_METATYPE(br::Transform*)
1134 Q_DECLARE_METATYPE(QList<br::Transform*>) 1133 Q_DECLARE_METATYPE(QList<br::Transform*>)
  1134 +Q_DECLARE_METATYPE(br::Distance*)
1135 Q_DECLARE_METATYPE(cv::Mat) 1135 Q_DECLARE_METATYPE(cv::Mat)
1136 1136
1137 #endif // __OPENBR_PLUGIN_H 1137 #endif // __OPENBR_PLUGIN_H
sdk/plugins/algorithms.cpp
@@ -70,6 +70,7 @@ class AlgorithmsInitializer : public Initializer @@ -70,6 +70,7 @@ class AlgorithmsInitializer : public Initializer
70 Globals->abbreviations.insert("FaceClassificationExtraction", "((Grid(7,7)+SIFTDescriptor(8)+ByRow)/DenseLBP+PCA(0.95,instances=-1)+Cat)"); 70 Globals->abbreviations.insert("FaceClassificationExtraction", "((Grid(7,7)+SIFTDescriptor(8)+ByRow)/DenseLBP+PCA(0.95,instances=-1)+Cat)");
71 Globals->abbreviations.insert("AgeRegressor", "Center(Range,instances=-1)+SVM(RBF,EPS_SVR,instances=100)"); 71 Globals->abbreviations.insert("AgeRegressor", "Center(Range,instances=-1)+SVM(RBF,EPS_SVR,instances=100)");
72 Globals->abbreviations.insert("GenderClassifier", "Center(Range,instances=-1)+SVM(RBF,C_SVC,instances=4000)"); 72 Globals->abbreviations.insert("GenderClassifier", "Center(Range,instances=-1)+SVM(RBF,C_SVC,instances=4000)");
  73 + Globals->abbreviations.insert("UCharL1", "Unit(ByteL1)");
73 } 74 }
74 }; 75 };
75 76
sdk/plugins/compare.cpp
@@ -119,7 +119,7 @@ BR_REGISTER(Distance, DistDistance) @@ -119,7 +119,7 @@ BR_REGISTER(Distance, DistDistance)
119 * \brief Fast 8-bit L1 distance 119 * \brief Fast 8-bit L1 distance
120 * \author Josh Klontz \cite jklontz 120 * \author Josh Klontz \cite jklontz
121 */ 121 */
122 -class UCharL1Distance : public Distance 122 +class ByteL1Distance : public Distance
123 { 123 {
124 Q_OBJECT 124 Q_OBJECT
125 125
@@ -129,7 +129,7 @@ class UCharL1Distance : public Distance @@ -129,7 +129,7 @@ class UCharL1Distance : public Distance
129 } 129 }
130 }; 130 };
131 131
132 -BR_REGISTER(Distance, UCharL1Distance) 132 +BR_REGISTER(Distance, ByteL1Distance)
133 133
134 134
135 /*! 135 /*!
@@ -137,7 +137,7 @@ BR_REGISTER(Distance, UCharL1Distance) @@ -137,7 +137,7 @@ BR_REGISTER(Distance, UCharL1Distance)
137 * \brief Fast 4-bit L1 distance 137 * \brief Fast 4-bit L1 distance
138 * \author Josh Klontz \cite jklontz 138 * \author Josh Klontz \cite jklontz
139 */ 139 */
140 -class PackedUCharL1Distance : public Distance 140 +class HalfByteL1Distance : public Distance
141 { 141 {
142 Q_OBJECT 142 Q_OBJECT
143 143
@@ -147,7 +147,7 @@ class PackedUCharL1Distance : public Distance @@ -147,7 +147,7 @@ class PackedUCharL1Distance : public Distance
147 } 147 }
148 }; 148 };
149 149
150 -BR_REGISTER(Distance, PackedUCharL1Distance) 150 +BR_REGISTER(Distance, HalfByteL1Distance)
151 151
152 /*! 152 /*!
153 * \ingroup distances 153 * \ingroup distances
sdk/plugins/quality.cpp
@@ -18,14 +18,20 @@ class IUMTransform : public Transform @@ -18,14 +18,20 @@ class IUMTransform : public Transform
18 Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false) 18 Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false)
19 Q_PROPERTY(double mean READ get_mean WRITE set_mean RESET reset_mean) 19 Q_PROPERTY(double mean READ get_mean WRITE set_mean RESET reset_mean)
20 Q_PROPERTY(double stddev READ get_stddev WRITE set_stddev RESET reset_stddev) 20 Q_PROPERTY(double stddev READ get_stddev WRITE set_stddev RESET reset_stddev)
21 - BR_PROPERTY(br::Distance*, distance, Factory<Distance>::make(".Dist(L2)")) 21 + BR_PROPERTY(br::Distance*, distance, Distance::make("Dist(L2)", this))
22 BR_PROPERTY(double, mean, 0) 22 BR_PROPERTY(double, mean, 0)
23 BR_PROPERTY(double, stddev, 1) 23 BR_PROPERTY(double, stddev, 1)
24 br::TemplateList impostors; 24 br::TemplateList impostors;
25 25
26 float calculateIUM(const Template &probe, const TemplateList &gallery) const 26 float calculateIUM(const Template &probe, const TemplateList &gallery) const
27 { 27 {
28 - QList<float> scores = distance->compare(gallery, probe); 28 + const int probeLabel = probe.file.label();
  29 + TemplateList subset = gallery;
  30 + for (int j=subset.size()-1; j>=0; j--)
  31 + if (subset[j].file.label() == probeLabel)
  32 + subset.removeAt(j);
  33 +
  34 + QList<float> scores = distance->compare(subset, probe);
29 float min, max; 35 float min, max;
30 Common::MinMax(scores, &min, &max); 36 Common::MinMax(scores, &min, &max);
31 double mean; 37 double mean;
@@ -39,14 +45,8 @@ class IUMTransform : public Transform @@ -39,14 +45,8 @@ class IUMTransform : public Transform
39 impostors = data; 45 impostors = data;
40 46
41 QList<float> iums; iums.reserve(impostors.size()); 47 QList<float> iums; iums.reserve(impostors.size());
42 - QList<int> labels = impostors.labels<int>();  
43 - for (int i=0; i<data.size(); i++) {  
44 - TemplateList subset = impostors;  
45 - for (int j=subset.size()-1; j>=0; j--)  
46 - if (labels[j] == labels[i])  
47 - subset.removeAt(j);  
48 - iums.append(calculateIUM(impostors[i], subset));  
49 - } 48 + for (int i=0; i<data.size(); i++)
  49 + iums.append(calculateIUM(impostors[i], impostors));
50 50
51 Common::MeanStdDev(iums, &mean, &stddev); 51 Common::MeanStdDev(iums, &mean, &stddev);
52 } 52 }
@@ -74,6 +74,194 @@ class IUMTransform : public Transform @@ -74,6 +74,194 @@ class IUMTransform : public Transform
74 74
75 BR_REGISTER(Transform, IUMTransform) 75 BR_REGISTER(Transform, IUMTransform)
76 76
  77 +/* Kernel Density Estimator */
  78 +struct KDE
  79 +{
  80 + float min, max;
  81 + double mean, stddev;
  82 + QList<float> bins;
  83 +
  84 + KDE() : min(0), max(1) {}
  85 + KDE(const QList<float> &scores)
  86 + {
  87 + Common::MinMax(scores, &min, &max);
  88 + Common::MeanStdDev(scores, &mean, &stddev);
  89 + double h = Common::KernelDensityBandwidth(scores);
  90 + const int size = 255;
  91 + bins.reserve(size);
  92 + for (int i=0; i<size; i++)
  93 + bins.append(Common::KernelDensityEstimation(scores, min + (max-min)*i/(size-1), h));
  94 + }
  95 +
  96 + float operator()(float score, bool gaussian = false) const
  97 + {
  98 + if (gaussian) return 1/(stddev*sqrt(2*CV_PI))*exp(-0.5*pow((score-mean)/stddev, 2));
  99 + if (score <= min) return bins.first();
  100 + if (score >= max) return bins.last();
  101 + const float x = (score-min)/(max-min)*bins.size();
  102 + const float y1 = bins[floor(x)];
  103 + const float y2 = bins[ceil(x)];
  104 + return y1 + (y2-y1)*(x-floor(x));
  105 + }
  106 +};
  107 +
  108 +QDataStream &operator<<(QDataStream &stream, const KDE &kde)
  109 +{
  110 + return stream << kde.min << kde.max << kde.mean << kde.stddev << kde.bins;
  111 +}
  112 +
  113 +QDataStream &operator>>(QDataStream &stream, KDE &kde)
  114 +{
  115 + return stream >> kde.min >> kde.max >> kde.mean >> kde.stddev >> kde.bins;
  116 +}
  117 +
  118 +/* Match Probability */
  119 +struct MP
  120 +{
  121 + KDE genuine, impostor;
  122 + MP() {}
  123 + MP(const QList<float> &genuineScores, const QList<float> &impostorScores)
  124 + : genuine(genuineScores), impostor(impostorScores) {}
  125 + float operator()(float score, bool gaussian = false, bool log = false) const
  126 + {
  127 + const float g = genuine(score, gaussian);
  128 + const float s = g / (impostor(score, gaussian) + g);
  129 + if (log) return (std::max(std::log10(s), -10.f) + 10)/10;
  130 + else return s;
  131 + }
  132 +};
  133 +
  134 +QDataStream &operator<<(QDataStream &stream, const MP &nmp)
  135 +{
  136 + return stream << nmp.genuine << nmp.impostor;
  137 +}
  138 +
  139 +QDataStream &operator>>(QDataStream &stream, MP &nmp)
  140 +{
  141 + return stream >> nmp.genuine >> nmp.impostor;
  142 +}
  143 +
  144 +/*!
  145 + * \ingroup distances
  146 + * \brief Match Probability \cite klare12
  147 + * \author Josh Klontz \cite jklontz
  148 + */
  149 +class MPDistance : public Distance
  150 +{
  151 + Q_OBJECT
  152 + Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false)
  153 + Q_PROPERTY(QString binKey READ get_binKey WRITE set_binKey RESET reset_binKey STORED false)
  154 + Q_PROPERTY(bool gaussian READ get_gaussian WRITE set_gaussian RESET reset_gaussian STORED false)
  155 + Q_PROPERTY(bool log READ get_log WRITE set_log RESET reset_log STORED false)
  156 + BR_PROPERTY(br::Distance*, distance, make("Dist(L2)"))
  157 + BR_PROPERTY(QString, binKey, "")
  158 + BR_PROPERTY(bool, gaussian, false)
  159 + BR_PROPERTY(bool, log, false)
  160 +
  161 + QHash<QString, MP> mps;
  162 +
  163 + void train(const TemplateList &src)
  164 + {
  165 + distance->train(src);
  166 +
  167 + const QList<int> labels = src.labels<int>();
  168 + QScopedPointer<MatrixOutput> memoryOutput(dynamic_cast<MatrixOutput*>(Output::make(".Matrix", FileList(src.size()), FileList(src.size()))));
  169 + distance->compare(src, src, memoryOutput.data());
  170 +
  171 + QHash< QString, QList<float> > genuineScores, impostorScores;
  172 + for (int i=0; i<src.size(); i++)
  173 + for (int j=0; j<i; j++) {
  174 + const float score = memoryOutput.data()->data.at<float>(i, j);
  175 + const QString bin = src[i].file.getString(binKey, "");
  176 + if (labels[i] == labels[j]) genuineScores[bin].append(score);
  177 + else impostorScores[bin].append(score);
  178 + }
  179 +
  180 + foreach (const QString &key, genuineScores.keys())
  181 + mps.insert(key, MP(genuineScores[key], impostorScores[key]));
  182 + }
  183 +
  184 + float _compare(const Template &target, const Template &query) const
  185 + {
  186 + return mps[query.file.getString(binKey, "")](distance->compare(target, query), gaussian, log);
  187 + }
  188 +
  189 + void store(QDataStream &stream) const
  190 + {
  191 + distance->store(stream);
  192 + stream << mps;
  193 + }
  194 +
  195 + void load(QDataStream &stream)
  196 + {
  197 + distance->load(stream);
  198 + stream >> mps;
  199 + }
  200 +};
  201 +
  202 +BR_REGISTER(Distance, MPDistance)
  203 +
  204 +/*!
  205 + * \ingroup distances
  206 + * \brief Linear normalizes of a distance so the mean impostor score is 0 and the mean genuine score is 1.
  207 + * \author Josh Klontz \cite jklontz
  208 + */
  209 +class UnitDistance : public Distance
  210 +{
  211 + Q_OBJECT
  212 + Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance)
  213 + Q_PROPERTY(float a READ get_a WRITE set_a RESET reset_a)
  214 + Q_PROPERTY(float b READ get_b WRITE set_b RESET reset_b)
  215 + BR_PROPERTY(br::Distance*, distance, make("Dist(L2)"))
  216 + BR_PROPERTY(float, a, 1)
  217 + BR_PROPERTY(float, b, 0)
  218 +
  219 + void train(const TemplateList &templates)
  220 + {
  221 + const TemplateList samples = templates.mid(0, 2000);
  222 + const QList<float> sampleLabels = samples.labels<float>();
  223 + QScopedPointer<MatrixOutput> memoryOutput(dynamic_cast<MatrixOutput*>(Output::make(".Matrix", FileList(samples.size()), FileList(samples.size()))));
  224 + compare(samples, samples, memoryOutput.data());
  225 +
  226 + double genuineAccumulator, impostorAccumulator;
  227 + int genuineCount, impostorCount;
  228 + genuineAccumulator = impostorAccumulator = genuineCount = impostorCount = 0;
  229 +
  230 + for (int i=0; i<samples.size(); i++) {
  231 + for (int j=0; j<i; j++) {
  232 + const float val = memoryOutput.data()->data.at<float>(i, j);
  233 + if (sampleLabels[i] == sampleLabels[j]) {
  234 + genuineAccumulator += val;
  235 + genuineCount++;
  236 + } else {
  237 + impostorAccumulator += val;
  238 + impostorCount++;
  239 + }
  240 + }
  241 + }
  242 +
  243 + if (genuineCount == 0) { qWarning("No genuine matches."); return; }
  244 + if (impostorCount == 0) { qWarning("No impostor matches."); return; }
  245 +
  246 + double genuineMean = genuineAccumulator / genuineCount;
  247 + double impostorMean = impostorAccumulator / impostorCount;
  248 +
  249 + if (genuineMean == impostorMean) { qWarning("Genuines and impostors are indistinguishable."); return; }
  250 +
  251 + a = 1.0/(genuineMean-impostorMean);
  252 + b = impostorMean;
  253 +
  254 + qDebug("a = %f, b = %f", a, b);
  255 + }
  256 +
  257 + float _compare(const Template &target, const Template &query) const
  258 + {
  259 + return a * (distance->compare(target, query) - b);
  260 + }
  261 +};
  262 +
  263 +BR_REGISTER(Distance, UnitDistance)
  264 +
77 } // namespace br 265 } // namespace br
78 266
79 #include "quality.moc" 267 #include "quality.moc"