Commit f7dff58cfe07ce9fb0341ac7f4ef4a067cc874f5

Authored by jklontz
2 parents f74da8dd f9c9e2d1

Merge pull request #7 from biometrics/kde

Implemented Imposter Score Normalization and Match Probability
sdk/core/core.cpp
... ... @@ -252,7 +252,7 @@ private:
252 252 if (words.size() > 2) qFatal("AlgorithmCore::init invalid algorithm format.");
253 253  
254 254 transform = QSharedPointer<Transform>(Transform::make(words[0], NULL));
255   - if (words.size() > 1) distance = QSharedPointer<Distance>(Factory<Distance>::make("." + words[1]));
  255 + if (words.size() > 1) distance = QSharedPointer<Distance>(Distance::make(words[1], NULL));
256 256 }
257 257 };
258 258  
... ...
sdk/openbr_plugin.cpp
... ... @@ -497,6 +497,8 @@ QString Object::argument(int index) const
497 497 return "[" + strings.join(",") + "]";
498 498 } else if (type == "br::Transform*") {
499 499 return variant.value<Transform*>()->description();
  500 + } else if (type == "br::Distance*") {
  501 + return variant.value<Distance*>()->description();
500 502 } else if (type == "QStringList") {
501 503 return "[" + variant.toStringList().join(",") + "]";
502 504 }
... ... @@ -524,6 +526,8 @@ void Object::store(QDataStream &amp;stream) const
524 526 transform->store(stream);
525 527 } else if (type == "br::Transform*") {
526 528 property.read(this).value<Transform*>()->store(stream);
  529 + } else if (type == "br::Distance*") {
  530 + property.read(this).value<Distance*>()->store(stream);
527 531 } else if (type == "bool") {
528 532 stream << property.read(this).toBool();
529 533 } else if (type == "int") {
... ... @@ -556,6 +560,8 @@ void Object::load(QDataStream &amp;stream)
556 560 transform->load(stream);
557 561 } else if (type == "br::Transform*") {
558 562 property.read(this).value<Transform*>()->load(stream);
  563 + } else if (type == "br::Distance*") {
  564 + property.read(this).value<Distance*>()->load(stream);
559 565 } else if (type == "bool") {
560 566 bool value;
561 567 stream >> value;
... ... @@ -620,6 +626,8 @@ void Object::setProperty(const QString &amp;name, const QString &amp;value)
620 626 }
621 627 } else if (type == "br::Transform*") {
622 628 variant.setValue(Transform::make(value, this));
  629 + } else if (type == "br::Distance*") {
  630 + variant.setValue(Distance::make(value, this));
623 631 } else if (type == "QStringList") {
624 632 variant.setValue(parse(value.mid(1, value.size()-2)));
625 633 } else if (type == "bool") {
... ... @@ -630,6 +638,7 @@ void Object::setProperty(const QString &amp;name, const QString &amp;value)
630 638 } else {
631 639 variant = value;
632 640 }
  641 +
633 642 if (!QObject::setProperty(qPrintable(name), variant) && !type.isEmpty())
634 643 qFatal("Failed to set %s::%s to: %s %s",
635 644 metaObject()->className(), qPrintable(name), qPrintable(value), qPrintable(type));
... ... @@ -786,6 +795,7 @@ void br::Context::initializeQt(QString sdkPath)
786 795 qRegisterMetaType< QList<int> >();
787 796 qRegisterMetaType< br::Transform* >();
788 797 qRegisterMetaType< QList<br::Transform*> >();
  798 + qRegisterMetaType< br::Distance* >();
789 799 qRegisterMetaType< cv::Mat >();
790 800  
791 801 qInstallMsgHandler(messageHandler);
... ... @@ -1270,42 +1280,17 @@ void Transform::backProject(const TemplateList &amp;dst, TemplateList &amp;src) const
1270 1280 }
1271 1281  
1272 1282 /* Distance - public methods */
1273   -void Distance::train(const TemplateList &templates)
1274   -{
1275   - const TemplateList samples = templates.mid(0, 2000);
1276   - const QList<float> sampleLabels = samples.labels<float>();
1277   - QScopedPointer<MatrixOutput> memoryOutput(dynamic_cast<MatrixOutput*>(Output::make(".Matrix", FileList(samples.size()), FileList(samples.size()))));
1278   - compare(samples, samples, memoryOutput.data());
1279   -
1280   - double genuineAccumulator, impostorAccumulator;
1281   - int genuineCount, impostorCount;
1282   - genuineAccumulator = impostorAccumulator = genuineCount = impostorCount = 0;
1283   -
1284   - for (int i=0; i<samples.size(); i++) {
1285   - for (int j=0; j<i; j++) {
1286   - const float val = memoryOutput.data()->data.at<float>(i, j);
1287   - if (sampleLabels[i] == sampleLabels[j]) {
1288   - genuineAccumulator += val;
1289   - genuineCount++;
1290   - } else {
1291   - impostorAccumulator += val;
1292   - impostorCount++;
1293   - }
1294   - }
1295   - }
1296   -
1297   - if (genuineCount == 0) { qWarning("No genuine matches."); return; }
1298   - if (impostorCount == 0) { qWarning("No impostor matches."); return; }
1299   -
1300   - double genuineMean = genuineAccumulator / genuineCount;
1301   - double impostorMean = impostorAccumulator / impostorCount;
1302   -
1303   - if (genuineMean == impostorMean) { qWarning("Genuines and impostors are indistinguishable."); return; }
  1283 +Distance *Distance::make(QString str, QObject *parent)
  1284 +{
  1285 + // Check for custom transforms
  1286 + if (Globals->abbreviations.contains(str))
  1287 + return make(Globals->abbreviations[str], parent);
1304 1288  
1305   - a = 1.0/(genuineMean-impostorMean);
1306   - b = impostorMean;
  1289 + File f = "." + str;
  1290 + Distance *distance = Factory<Distance>::make(f);
1307 1291  
1308   - qDebug("a = %f, b = %f", a, b);
  1292 + distance->setParent(parent);
  1293 + return distance;
1309 1294 }
1310 1295  
1311 1296 void Distance::compare(const TemplateList &target, const TemplateList &query, Output *output) const
... ... @@ -1336,7 +1321,7 @@ float Distance::compare(const Template &amp;target, const Template &amp;query) const
1336 1321 return -std::numeric_limits<float>::max();
1337 1322 }
1338 1323  
1339   - return a * (_compare(target, query) - b);
  1324 + return _compare(target, query);
1340 1325 }
1341 1326  
1342 1327 QList<float> Distance::compare(const TemplateList &targets, const Template &query) const
... ...
sdk/openbr_plugin.h
... ... @@ -971,7 +971,6 @@ public:
971 971 protected:
972 972 Transform(bool independent = true); /*!< \brief Construct a transform. */
973 973 inline Transform *make(const QString &description) { return make(description, this); } /*!< \brief Make a subtransform. */
974   -
975 974 };
976 975  
977 976 /*!
... ... @@ -1065,19 +1064,19 @@ class BR_EXPORT Distance : public Object
1065 1064 {
1066 1065 Q_OBJECT
1067 1066  
1068   - // Score normalization
1069   - Q_PROPERTY(float a READ get_a WRITE set_a RESET reset_a)
1070   - Q_PROPERTY(float b READ get_b WRITE set_b RESET reset_b)
1071   - BR_PROPERTY(float, a, 1)
1072   - BR_PROPERTY(float, b, 0)
1073   -
1074 1067 public:
  1068 + virtual ~Distance() {}
  1069 + static Distance *make(QString str, QObject *parent); /*!< \brief Make a distance from a string. */
  1070 +
1075 1071 static QSharedPointer<Distance> fromAlgorithm(const QString &algorithm); /*!< \brief Retrieve an algorithm's distance. */
1076   - virtual void train(const TemplateList &src); /*!< \brief Train the distance. */
  1072 + virtual void train(const TemplateList &src) { (void) src; } /*!< \brief Train the distance. */
1077 1073 virtual void compare(const TemplateList &target, const TemplateList &query, Output *output) const; /*!< \brief Compare two template lists. */
1078 1074 float compare(const Template &target, const Template &query) const; /*!< \brief Compute the normalized distance between two templates. */
1079 1075 QList<float> compare(const TemplateList &targets, const Template &query) const; /*!< \brief Compute the normalized distance between a template and a template list. */
1080 1076  
  1077 +protected:
  1078 + inline Distance *make(const QString &description) { return make(description, this); } /*!< \brief Make a subdistance. */
  1079 +
1081 1080 private:
1082 1081 virtual void compareBlock(const TemplateList &target, const TemplateList &query, Output *output, int targetOffset, int queryOffset) const;
1083 1082 virtual float _compare(const Template &a, const Template &b) const = 0; /*!< \brief Compute the distance between two templates. */
... ... @@ -1132,6 +1131,7 @@ Q_DECLARE_METATYPE(QList&lt;float&gt;)
1132 1131 Q_DECLARE_METATYPE(QList<int>)
1133 1132 Q_DECLARE_METATYPE(br::Transform*)
1134 1133 Q_DECLARE_METATYPE(QList<br::Transform*>)
  1134 +Q_DECLARE_METATYPE(br::Distance*)
1135 1135 Q_DECLARE_METATYPE(cv::Mat)
1136 1136  
1137 1137 #endif // __OPENBR_PLUGIN_H
... ...
sdk/plugins/algorithms.cpp
... ... @@ -70,6 +70,7 @@ class AlgorithmsInitializer : public Initializer
70 70 Globals->abbreviations.insert("FaceClassificationExtraction", "((Grid(7,7)+SIFTDescriptor(8)+ByRow)/DenseLBP+PCA(0.95,instances=-1)+Cat)");
71 71 Globals->abbreviations.insert("AgeRegressor", "Center(Range,instances=-1)+SVM(RBF,EPS_SVR,instances=100)");
72 72 Globals->abbreviations.insert("GenderClassifier", "Center(Range,instances=-1)+SVM(RBF,C_SVC,instances=4000)");
  73 + Globals->abbreviations.insert("UCharL1", "Unit(ByteL1)");
73 74 }
74 75 };
75 76  
... ...
sdk/plugins/compare.cpp
... ... @@ -119,7 +119,7 @@ BR_REGISTER(Distance, DistDistance)
119 119 * \brief Fast 8-bit L1 distance
120 120 * \author Josh Klontz \cite jklontz
121 121 */
122   -class UCharL1Distance : public Distance
  122 +class ByteL1Distance : public Distance
123 123 {
124 124 Q_OBJECT
125 125  
... ... @@ -129,7 +129,7 @@ class UCharL1Distance : public Distance
129 129 }
130 130 };
131 131  
132   -BR_REGISTER(Distance, UCharL1Distance)
  132 +BR_REGISTER(Distance, ByteL1Distance)
133 133  
134 134  
135 135 /*!
... ... @@ -137,7 +137,7 @@ BR_REGISTER(Distance, UCharL1Distance)
137 137 * \brief Fast 4-bit L1 distance
138 138 * \author Josh Klontz \cite jklontz
139 139 */
140   -class PackedUCharL1Distance : public Distance
  140 +class HalfByteL1Distance : public Distance
141 141 {
142 142 Q_OBJECT
143 143  
... ... @@ -147,7 +147,7 @@ class PackedUCharL1Distance : public Distance
147 147 }
148 148 };
149 149  
150   -BR_REGISTER(Distance, PackedUCharL1Distance)
  150 +BR_REGISTER(Distance, HalfByteL1Distance)
151 151  
152 152 /*!
153 153 * \ingroup distances
... ...
sdk/plugins/quality.cpp
... ... @@ -18,14 +18,20 @@ class IUMTransform : public Transform
18 18 Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false)
19 19 Q_PROPERTY(double mean READ get_mean WRITE set_mean RESET reset_mean)
20 20 Q_PROPERTY(double stddev READ get_stddev WRITE set_stddev RESET reset_stddev)
21   - BR_PROPERTY(br::Distance*, distance, Factory<Distance>::make(".Dist(L2)"))
  21 + BR_PROPERTY(br::Distance*, distance, Distance::make("Dist(L2)", this))
22 22 BR_PROPERTY(double, mean, 0)
23 23 BR_PROPERTY(double, stddev, 1)
24 24 br::TemplateList impostors;
25 25  
26 26 float calculateIUM(const Template &probe, const TemplateList &gallery) const
27 27 {
28   - QList<float> scores = distance->compare(gallery, probe);
  28 + const int probeLabel = probe.file.label();
  29 + TemplateList subset = gallery;
  30 + for (int j=subset.size()-1; j>=0; j--)
  31 + if (subset[j].file.label() == probeLabel)
  32 + subset.removeAt(j);
  33 +
  34 + QList<float> scores = distance->compare(subset, probe);
29 35 float min, max;
30 36 Common::MinMax(scores, &min, &max);
31 37 double mean;
... ... @@ -39,14 +45,8 @@ class IUMTransform : public Transform
39 45 impostors = data;
40 46  
41 47 QList<float> iums; iums.reserve(impostors.size());
42   - QList<int> labels = impostors.labels<int>();
43   - for (int i=0; i<data.size(); i++) {
44   - TemplateList subset = impostors;
45   - for (int j=subset.size()-1; j>=0; j--)
46   - if (labels[j] == labels[i])
47   - subset.removeAt(j);
48   - iums.append(calculateIUM(impostors[i], subset));
49   - }
  48 + for (int i=0; i<data.size(); i++)
  49 + iums.append(calculateIUM(impostors[i], impostors));
50 50  
51 51 Common::MeanStdDev(iums, &mean, &stddev);
52 52 }
... ... @@ -74,6 +74,194 @@ class IUMTransform : public Transform
74 74  
75 75 BR_REGISTER(Transform, IUMTransform)
76 76  
  77 +/* Kernel Density Estimator */
  78 +struct KDE
  79 +{
  80 + float min, max;
  81 + double mean, stddev;
  82 + QList<float> bins;
  83 +
  84 + KDE() : min(0), max(1) {}
  85 + KDE(const QList<float> &scores)
  86 + {
  87 + Common::MinMax(scores, &min, &max);
  88 + Common::MeanStdDev(scores, &mean, &stddev);
  89 + double h = Common::KernelDensityBandwidth(scores);
  90 + const int size = 255;
  91 + bins.reserve(size);
  92 + for (int i=0; i<size; i++)
  93 + bins.append(Common::KernelDensityEstimation(scores, min + (max-min)*i/(size-1), h));
  94 + }
  95 +
  96 + float operator()(float score, bool gaussian = false) const
  97 + {
  98 + if (gaussian) return 1/(stddev*sqrt(2*CV_PI))*exp(-0.5*pow((score-mean)/stddev, 2));
  99 + if (score <= min) return bins.first();
  100 + if (score >= max) return bins.last();
  101 + const float x = (score-min)/(max-min)*bins.size();
  102 + const float y1 = bins[floor(x)];
  103 + const float y2 = bins[ceil(x)];
  104 + return y1 + (y2-y1)*(x-floor(x));
  105 + }
  106 +};
  107 +
  108 +QDataStream &operator<<(QDataStream &stream, const KDE &kde)
  109 +{
  110 + return stream << kde.min << kde.max << kde.mean << kde.stddev << kde.bins;
  111 +}
  112 +
  113 +QDataStream &operator>>(QDataStream &stream, KDE &kde)
  114 +{
  115 + return stream >> kde.min >> kde.max >> kde.mean >> kde.stddev >> kde.bins;
  116 +}
  117 +
  118 +/* Match Probability */
  119 +struct MP
  120 +{
  121 + KDE genuine, impostor;
  122 + MP() {}
  123 + MP(const QList<float> &genuineScores, const QList<float> &impostorScores)
  124 + : genuine(genuineScores), impostor(impostorScores) {}
  125 + float operator()(float score, bool gaussian = false, bool log = false) const
  126 + {
  127 + const float g = genuine(score, gaussian);
  128 + const float s = g / (impostor(score, gaussian) + g);
  129 + if (log) return (std::max(std::log10(s), -10.f) + 10)/10;
  130 + else return s;
  131 + }
  132 +};
  133 +
  134 +QDataStream &operator<<(QDataStream &stream, const MP &nmp)
  135 +{
  136 + return stream << nmp.genuine << nmp.impostor;
  137 +}
  138 +
  139 +QDataStream &operator>>(QDataStream &stream, MP &nmp)
  140 +{
  141 + return stream >> nmp.genuine >> nmp.impostor;
  142 +}
  143 +
  144 +/*!
  145 + * \ingroup distances
  146 + * \brief Match Probability \cite klare12
  147 + * \author Josh Klontz \cite jklontz
  148 + */
  149 +class MPDistance : public Distance
  150 +{
  151 + Q_OBJECT
  152 + Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false)
  153 + Q_PROPERTY(QString binKey READ get_binKey WRITE set_binKey RESET reset_binKey STORED false)
  154 + Q_PROPERTY(bool gaussian READ get_gaussian WRITE set_gaussian RESET reset_gaussian STORED false)
  155 + Q_PROPERTY(bool log READ get_log WRITE set_log RESET reset_log STORED false)
  156 + BR_PROPERTY(br::Distance*, distance, make("Dist(L2)"))
  157 + BR_PROPERTY(QString, binKey, "")
  158 + BR_PROPERTY(bool, gaussian, false)
  159 + BR_PROPERTY(bool, log, false)
  160 +
  161 + QHash<QString, MP> mps;
  162 +
  163 + void train(const TemplateList &src)
  164 + {
  165 + distance->train(src);
  166 +
  167 + const QList<int> labels = src.labels<int>();
  168 + QScopedPointer<MatrixOutput> memoryOutput(dynamic_cast<MatrixOutput*>(Output::make(".Matrix", FileList(src.size()), FileList(src.size()))));
  169 + distance->compare(src, src, memoryOutput.data());
  170 +
  171 + QHash< QString, QList<float> > genuineScores, impostorScores;
  172 + for (int i=0; i<src.size(); i++)
  173 + for (int j=0; j<i; j++) {
  174 + const float score = memoryOutput.data()->data.at<float>(i, j);
  175 + const QString bin = src[i].file.getString(binKey, "");
  176 + if (labels[i] == labels[j]) genuineScores[bin].append(score);
  177 + else impostorScores[bin].append(score);
  178 + }
  179 +
  180 + foreach (const QString &key, genuineScores.keys())
  181 + mps.insert(key, MP(genuineScores[key], impostorScores[key]));
  182 + }
  183 +
  184 + float _compare(const Template &target, const Template &query) const
  185 + {
  186 + return mps[query.file.getString(binKey, "")](distance->compare(target, query), gaussian, log);
  187 + }
  188 +
  189 + void store(QDataStream &stream) const
  190 + {
  191 + distance->store(stream);
  192 + stream << mps;
  193 + }
  194 +
  195 + void load(QDataStream &stream)
  196 + {
  197 + distance->load(stream);
  198 + stream >> mps;
  199 + }
  200 +};
  201 +
  202 +BR_REGISTER(Distance, MPDistance)
  203 +
  204 +/*!
  205 + * \ingroup distances
  206 + * \brief Linear normalizes of a distance so the mean impostor score is 0 and the mean genuine score is 1.
  207 + * \author Josh Klontz \cite jklontz
  208 + */
  209 +class UnitDistance : public Distance
  210 +{
  211 + Q_OBJECT
  212 + Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance)
  213 + Q_PROPERTY(float a READ get_a WRITE set_a RESET reset_a)
  214 + Q_PROPERTY(float b READ get_b WRITE set_b RESET reset_b)
  215 + BR_PROPERTY(br::Distance*, distance, make("Dist(L2)"))
  216 + BR_PROPERTY(float, a, 1)
  217 + BR_PROPERTY(float, b, 0)
  218 +
  219 + void train(const TemplateList &templates)
  220 + {
  221 + const TemplateList samples = templates.mid(0, 2000);
  222 + const QList<float> sampleLabels = samples.labels<float>();
  223 + QScopedPointer<MatrixOutput> memoryOutput(dynamic_cast<MatrixOutput*>(Output::make(".Matrix", FileList(samples.size()), FileList(samples.size()))));
  224 + compare(samples, samples, memoryOutput.data());
  225 +
  226 + double genuineAccumulator, impostorAccumulator;
  227 + int genuineCount, impostorCount;
  228 + genuineAccumulator = impostorAccumulator = genuineCount = impostorCount = 0;
  229 +
  230 + for (int i=0; i<samples.size(); i++) {
  231 + for (int j=0; j<i; j++) {
  232 + const float val = memoryOutput.data()->data.at<float>(i, j);
  233 + if (sampleLabels[i] == sampleLabels[j]) {
  234 + genuineAccumulator += val;
  235 + genuineCount++;
  236 + } else {
  237 + impostorAccumulator += val;
  238 + impostorCount++;
  239 + }
  240 + }
  241 + }
  242 +
  243 + if (genuineCount == 0) { qWarning("No genuine matches."); return; }
  244 + if (impostorCount == 0) { qWarning("No impostor matches."); return; }
  245 +
  246 + double genuineMean = genuineAccumulator / genuineCount;
  247 + double impostorMean = impostorAccumulator / impostorCount;
  248 +
  249 + if (genuineMean == impostorMean) { qWarning("Genuines and impostors are indistinguishable."); return; }
  250 +
  251 + a = 1.0/(genuineMean-impostorMean);
  252 + b = impostorMean;
  253 +
  254 + qDebug("a = %f, b = %f", a, b);
  255 + }
  256 +
  257 + float _compare(const Template &target, const Template &query) const
  258 + {
  259 + return a * (distance->compare(target, query) - b);
  260 + }
  261 +};
  262 +
  263 +BR_REGISTER(Distance, UnitDistance)
  264 +
77 265 } // namespace br
78 266  
79 267 #include "quality.moc"
... ...