diff --git a/openbr/plugins/quality.cpp b/openbr/plugins/quality.cpp index 469009e..11d4bbd 100644 --- a/openbr/plugins/quality.cpp +++ b/openbr/plugins/quality.cpp @@ -77,6 +77,12 @@ class ImpostorUniquenessMeasureTransform : public Transform BR_REGISTER(Transform, ImpostorUniquenessMeasureTransform) + +float KDEPointer(const QList *scores, double x, double h) +{ + return Common::KernelDensityEstimation(*scores, x, h); +} + /* Kernel Density Estimator */ struct KDE { @@ -85,20 +91,35 @@ struct KDE QList bins; KDE() : min(0), max(1), mean(0), stddev(1) {} - KDE(const QList &scores) + + KDE(const QList &scores, bool trainKDE) { Common::MinMax(scores, &min, &max); Common::MeanStdDev(scores, &mean, &stddev); + + if (!trainKDE) + return; + double h = Common::KernelDensityBandwidth(scores); const int size = 255; bins.reserve(size); - for (int i=0; i futures; + + for (int i=0; i < size; i++) + futures.addFuture(QtConcurrent::run(KDEPointer, &scores, min + (max-min)*i/(size-1), h)); + futures.waitForFinished(); + + foreach(const QFuture & future, futures.futures()) + bins.append(future.result()); } float operator()(float score, bool gaussian = true) const { if (gaussian) return 1/(stddev*sqrt(2*CV_PI))*exp(-0.5*pow((score-mean)/stddev, 2)); + if (bins.empty()) + return -std::numeric_limits::max(); + if (score <= min) return bins.first(); if (score >= max) return bins.last(); const float x = (score-min)/(max-min)*bins.size(); @@ -123,8 +144,8 @@ struct MP { KDE genuine, impostor; MP() {} - MP(const QList &genuineScores, const QList &impostorScores) - : genuine(genuineScores), impostor(impostorScores) {} + MP(const QList &genuineScores, const QList &impostorScores, bool trainKDE) + : genuine(genuineScores, trainKDE), impostor(impostorScores, trainKDE) {} float operator()(float score, bool gaussian = true) const { const float g = genuine(score, gaussian); @@ -165,7 +186,7 @@ class MatchProbabilityDistance : public Distance const QList labels = src.indexProperty(inputVariable); QScopedPointer matrixOutput(MatrixOutput::make(FileList(src.size()), FileList(src.size()))); distance->compare(src, src, matrixOutput.data()); - + QList genuineScores, impostorScores; genuineScores.reserve(labels.size()); impostorScores.reserve(labels.size()*labels.size()); @@ -178,8 +199,8 @@ class MatchProbabilityDistance : public Distance else impostorScores.append(score); } } - - mp = MP(genuineScores, impostorScores); + + mp = MP(genuineScores, impostorScores, !gaussian); } float compare(const Template &target, const Template &query) const