Commit ad88ddc7992e5ad6a48dba0d1362667318e8058e

Authored by Charles Otto
1 parent 93733b79

Faster MatchProbability training

Parallelize KDE lookup table building, also introduce a flag, and
don't do KDE training at all, if Gaussians will instead be used to
model the score distributions
Showing 1 changed file with 29 additions and 8 deletions
openbr/plugins/quality.cpp
@@ -77,6 +77,12 @@ class ImpostorUniquenessMeasureTransform : public Transform @@ -77,6 +77,12 @@ class ImpostorUniquenessMeasureTransform : public Transform
77 77
78 BR_REGISTER(Transform, ImpostorUniquenessMeasureTransform) 78 BR_REGISTER(Transform, ImpostorUniquenessMeasureTransform)
79 79
  80 +
  81 +float KDEPointer(const QList<float> *scores, double x, double h)
  82 +{
  83 + return Common::KernelDensityEstimation(*scores, x, h);
  84 +}
  85 +
80 /* Kernel Density Estimator */ 86 /* Kernel Density Estimator */
81 struct KDE 87 struct KDE
82 { 88 {
@@ -85,20 +91,35 @@ struct KDE @@ -85,20 +91,35 @@ struct KDE
85 QList<float> bins; 91 QList<float> bins;
86 92
87 KDE() : min(0), max(1), mean(0), stddev(1) {} 93 KDE() : min(0), max(1), mean(0), stddev(1) {}
88 - KDE(const QList<float> &scores) 94 +
  95 + KDE(const QList<float> &scores, bool trainKDE)
89 { 96 {
90 Common::MinMax(scores, &min, &max); 97 Common::MinMax(scores, &min, &max);
91 Common::MeanStdDev(scores, &mean, &stddev); 98 Common::MeanStdDev(scores, &mean, &stddev);
  99 +
  100 + if (!trainKDE)
  101 + return;
  102 +
92 double h = Common::KernelDensityBandwidth(scores); 103 double h = Common::KernelDensityBandwidth(scores);
93 const int size = 255; 104 const int size = 255;
94 bins.reserve(size); 105 bins.reserve(size);
95 - for (int i=0; i<size; i++)  
96 - bins.append(Common::KernelDensityEstimation(scores, min + (max-min)*i/(size-1), h)); 106 +
  107 + QFutureSynchronizer<float> futures;
  108 +
  109 + for (int i=0; i < size; i++)
  110 + futures.addFuture(QtConcurrent::run(KDEPointer, &scores, min + (max-min)*i/(size-1), h));
  111 + futures.waitForFinished();
  112 +
  113 + foreach(const QFuture<float> & future, futures.futures())
  114 + bins.append(future.result());
97 } 115 }
98 116
99 float operator()(float score, bool gaussian = true) const 117 float operator()(float score, bool gaussian = true) const
100 { 118 {
101 if (gaussian) return 1/(stddev*sqrt(2*CV_PI))*exp(-0.5*pow((score-mean)/stddev, 2)); 119 if (gaussian) return 1/(stddev*sqrt(2*CV_PI))*exp(-0.5*pow((score-mean)/stddev, 2));
  120 + if (bins.empty())
  121 + return -std::numeric_limits<float>::max();
  122 +
102 if (score <= min) return bins.first(); 123 if (score <= min) return bins.first();
103 if (score >= max) return bins.last(); 124 if (score >= max) return bins.last();
104 const float x = (score-min)/(max-min)*bins.size(); 125 const float x = (score-min)/(max-min)*bins.size();
@@ -123,8 +144,8 @@ struct MP @@ -123,8 +144,8 @@ struct MP
123 { 144 {
124 KDE genuine, impostor; 145 KDE genuine, impostor;
125 MP() {} 146 MP() {}
126 - MP(const QList<float> &genuineScores, const QList<float> &impostorScores)  
127 - : genuine(genuineScores), impostor(impostorScores) {} 147 + MP(const QList<float> &genuineScores, const QList<float> &impostorScores, bool trainKDE)
  148 + : genuine(genuineScores, trainKDE), impostor(impostorScores, trainKDE) {}
128 float operator()(float score, bool gaussian = true) const 149 float operator()(float score, bool gaussian = true) const
129 { 150 {
130 const float g = genuine(score, gaussian); 151 const float g = genuine(score, gaussian);
@@ -165,7 +186,7 @@ class MatchProbabilityDistance : public Distance @@ -165,7 +186,7 @@ class MatchProbabilityDistance : public Distance
165 const QList<int> labels = src.indexProperty(inputVariable); 186 const QList<int> labels = src.indexProperty(inputVariable);
166 QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(src.size()), FileList(src.size()))); 187 QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(src.size()), FileList(src.size())));
167 distance->compare(src, src, matrixOutput.data()); 188 distance->compare(src, src, matrixOutput.data());
168 - 189 +
169 QList<float> genuineScores, impostorScores; 190 QList<float> genuineScores, impostorScores;
170 genuineScores.reserve(labels.size()); 191 genuineScores.reserve(labels.size());
171 impostorScores.reserve(labels.size()*labels.size()); 192 impostorScores.reserve(labels.size()*labels.size());
@@ -178,8 +199,8 @@ class MatchProbabilityDistance : public Distance @@ -178,8 +199,8 @@ class MatchProbabilityDistance : public Distance
178 else impostorScores.append(score); 199 else impostorScores.append(score);
179 } 200 }
180 } 201 }
181 -  
182 - mp = MP(genuineScores, impostorScores); 202 +
  203 + mp = MP(genuineScores, impostorScores, !gaussian);
183 } 204 }
184 205
185 float compare(const Template &target, const Template &query) const 206 float compare(const Template &target, const Template &query) const