Commit ad88ddc7992e5ad6a48dba0d1362667318e8058e
1 parent
93733b79
Faster MatchProbability training
Parallelize KDE lookup table building, also introduce a flag, and don't do KDE training at all, if Gaussians will instead be used to model the score distributions
Showing
1 changed file
with
29 additions
and
8 deletions
openbr/plugins/quality.cpp
| @@ -77,6 +77,12 @@ class ImpostorUniquenessMeasureTransform : public Transform | @@ -77,6 +77,12 @@ class ImpostorUniquenessMeasureTransform : public Transform | ||
| 77 | 77 | ||
| 78 | BR_REGISTER(Transform, ImpostorUniquenessMeasureTransform) | 78 | BR_REGISTER(Transform, ImpostorUniquenessMeasureTransform) |
| 79 | 79 | ||
| 80 | + | ||
| 81 | +float KDEPointer(const QList<float> *scores, double x, double h) | ||
| 82 | +{ | ||
| 83 | + return Common::KernelDensityEstimation(*scores, x, h); | ||
| 84 | +} | ||
| 85 | + | ||
| 80 | /* Kernel Density Estimator */ | 86 | /* Kernel Density Estimator */ |
| 81 | struct KDE | 87 | struct KDE |
| 82 | { | 88 | { |
| @@ -85,20 +91,35 @@ struct KDE | @@ -85,20 +91,35 @@ struct KDE | ||
| 85 | QList<float> bins; | 91 | QList<float> bins; |
| 86 | 92 | ||
| 87 | KDE() : min(0), max(1), mean(0), stddev(1) {} | 93 | KDE() : min(0), max(1), mean(0), stddev(1) {} |
| 88 | - KDE(const QList<float> &scores) | 94 | + |
| 95 | + KDE(const QList<float> &scores, bool trainKDE) | ||
| 89 | { | 96 | { |
| 90 | Common::MinMax(scores, &min, &max); | 97 | Common::MinMax(scores, &min, &max); |
| 91 | Common::MeanStdDev(scores, &mean, &stddev); | 98 | Common::MeanStdDev(scores, &mean, &stddev); |
| 99 | + | ||
| 100 | + if (!trainKDE) | ||
| 101 | + return; | ||
| 102 | + | ||
| 92 | double h = Common::KernelDensityBandwidth(scores); | 103 | double h = Common::KernelDensityBandwidth(scores); |
| 93 | const int size = 255; | 104 | const int size = 255; |
| 94 | bins.reserve(size); | 105 | bins.reserve(size); |
| 95 | - for (int i=0; i<size; i++) | ||
| 96 | - bins.append(Common::KernelDensityEstimation(scores, min + (max-min)*i/(size-1), h)); | 106 | + |
| 107 | + QFutureSynchronizer<float> futures; | ||
| 108 | + | ||
| 109 | + for (int i=0; i < size; i++) | ||
| 110 | + futures.addFuture(QtConcurrent::run(KDEPointer, &scores, min + (max-min)*i/(size-1), h)); | ||
| 111 | + futures.waitForFinished(); | ||
| 112 | + | ||
| 113 | + foreach(const QFuture<float> & future, futures.futures()) | ||
| 114 | + bins.append(future.result()); | ||
| 97 | } | 115 | } |
| 98 | 116 | ||
| 99 | float operator()(float score, bool gaussian = true) const | 117 | float operator()(float score, bool gaussian = true) const |
| 100 | { | 118 | { |
| 101 | if (gaussian) return 1/(stddev*sqrt(2*CV_PI))*exp(-0.5*pow((score-mean)/stddev, 2)); | 119 | if (gaussian) return 1/(stddev*sqrt(2*CV_PI))*exp(-0.5*pow((score-mean)/stddev, 2)); |
| 120 | + if (bins.empty()) | ||
| 121 | + return -std::numeric_limits<float>::max(); | ||
| 122 | + | ||
| 102 | if (score <= min) return bins.first(); | 123 | if (score <= min) return bins.first(); |
| 103 | if (score >= max) return bins.last(); | 124 | if (score >= max) return bins.last(); |
| 104 | const float x = (score-min)/(max-min)*bins.size(); | 125 | const float x = (score-min)/(max-min)*bins.size(); |
| @@ -123,8 +144,8 @@ struct MP | @@ -123,8 +144,8 @@ struct MP | ||
| 123 | { | 144 | { |
| 124 | KDE genuine, impostor; | 145 | KDE genuine, impostor; |
| 125 | MP() {} | 146 | MP() {} |
| 126 | - MP(const QList<float> &genuineScores, const QList<float> &impostorScores) | ||
| 127 | - : genuine(genuineScores), impostor(impostorScores) {} | 147 | + MP(const QList<float> &genuineScores, const QList<float> &impostorScores, bool trainKDE) |
| 148 | + : genuine(genuineScores, trainKDE), impostor(impostorScores, trainKDE) {} | ||
| 128 | float operator()(float score, bool gaussian = true) const | 149 | float operator()(float score, bool gaussian = true) const |
| 129 | { | 150 | { |
| 130 | const float g = genuine(score, gaussian); | 151 | const float g = genuine(score, gaussian); |
| @@ -165,7 +186,7 @@ class MatchProbabilityDistance : public Distance | @@ -165,7 +186,7 @@ class MatchProbabilityDistance : public Distance | ||
| 165 | const QList<int> labels = src.indexProperty(inputVariable); | 186 | const QList<int> labels = src.indexProperty(inputVariable); |
| 166 | QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(src.size()), FileList(src.size()))); | 187 | QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(src.size()), FileList(src.size()))); |
| 167 | distance->compare(src, src, matrixOutput.data()); | 188 | distance->compare(src, src, matrixOutput.data()); |
| 168 | - | 189 | + |
| 169 | QList<float> genuineScores, impostorScores; | 190 | QList<float> genuineScores, impostorScores; |
| 170 | genuineScores.reserve(labels.size()); | 191 | genuineScores.reserve(labels.size()); |
| 171 | impostorScores.reserve(labels.size()*labels.size()); | 192 | impostorScores.reserve(labels.size()*labels.size()); |
| @@ -178,8 +199,8 @@ class MatchProbabilityDistance : public Distance | @@ -178,8 +199,8 @@ class MatchProbabilityDistance : public Distance | ||
| 178 | else impostorScores.append(score); | 199 | else impostorScores.append(score); |
| 179 | } | 200 | } |
| 180 | } | 201 | } |
| 181 | - | ||
| 182 | - mp = MP(genuineScores, impostorScores); | 202 | + |
| 203 | + mp = MP(genuineScores, impostorScores, !gaussian); | ||
| 183 | } | 204 | } |
| 184 | 205 | ||
| 185 | float compare(const Template &target, const Template &query) const | 206 | float compare(const Template &target, const Template &query) const |