Commit 6046e0b9bbbd0109c371d27003f61edb8d0b6f27

Authored by sklum
2 parents 9deea8e7 f95f259e

Merge pull request #148 from biometrics/score_level_fusion

Score level fusion support
openbr/core/core.cpp
... ... @@ -56,11 +56,6 @@ struct AlgorithmCore
56 56  
57 57 TemplateList data(TemplateList::fromGallery(input));
58 58  
59   - // set the Train bool metadata, in case a Transform's project
60   - // needs to know if it's called during train or enroll
61   - for (int i=0; i<data.size(); i++)
62   - data[i].file.set("Train", true);
63   -
64 59 if (transform.isNull()) qFatal("Null transform.");
65 60 qDebug("%d Training Files", data.size());
66 61  
... ... @@ -463,7 +458,7 @@ void br::Convert(const File &amp;fileType, const File &amp;inputFile, const File &amp;output
463 458  
464 459 if ((targetFiles.size() != m.cols || queryFiles.size() != m.rows)
465 460 && (m.cols != 1 || targetFiles.size() != m.rows || queryFiles.size() != m.rows))
466   - qFatal("Similarity matrix and file size mismatch.");
  461 + qFatal("Similarity matrix (%d, %d) and header (%d, %d) size mismatch.", m.rows, m.cols, queryFiles.size(), targetFiles.size());
467 462  
468 463 QSharedPointer<Output> o(Factory<Output>::make(outputFile));
469 464 o->initialize(targetFiles, queryFiles);
... ...
openbr/core/fuse.cpp
... ... @@ -30,6 +30,9 @@ using namespace cv;
30 30  
31 31 static void normalizeMatrix(Mat &matrix, const Mat &mask, const QString &method)
32 32 {
  33 + if (matrix.rows != mask.rows && matrix.cols != mask.cols)
  34 + qFatal("Similarity matrix (%d, %d) and mask (%d, %d) size mismatch.", matrix.rows, matrix.cols, mask.rows, mask.cols);
  35 +
33 36 if (method == "None") return;
34 37  
35 38 QList<float> vals; vals.reserve(matrix.rows*matrix.cols);
... ...
openbr/openbr_plugin.h
... ... @@ -483,7 +483,7 @@ struct TemplateList : public QList&lt;Template&gt;
483 483 }
484 484  
485 485 /*!
486   - * \brief Returns a #br::TemplateList containing templates with one matrix at the specified index \em index.
  486 + * \brief Returns a list of #br::TemplateList with each #br::Template in a given #br::TemplateList containing the number of matrices specified by \em partitionSizes.
487 487 */
488 488 QList<TemplateList> partition(const QList<int> &partitionSizes) const
489 489 {
... ...
openbr/plugins/distance.cpp
... ... @@ -179,58 +179,88 @@ BR_REGISTER(Distance, PipeDistance)
179 179  
180 180 /*!
181 181 * \ingroup distances
182   - * \brief Computes an operation on distances across multiple matrices of compared templates
  182 + * \brief Fuses similarity scores across multiple matrices of compared templates
183 183 * \author Scott Klum \cite sklum
184 184 * \note Operation: Mean, sum, min, max are supported.
185 185 */
186   -class OperationDistance : public Distance
  186 +class FuseDistance : public Distance
187 187 {
188 188 Q_OBJECT
189 189 Q_ENUMS(Operation)
190   - Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false)
  190 + Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false)
191 191 Q_PROPERTY(Operation operation READ get_operation WRITE set_operation RESET reset_operation STORED false)
192 192  
  193 + QList<br::Distance*> distances;
  194 +
193 195 public:
194 196 /*!< */
195 197 enum Operation {Mean, Sum, Max, Min};
196 198  
197 199 private:
198   - BR_PROPERTY(br::Distance*, distance, make("Dist(L2)"))
  200 + BR_PROPERTY(QString, description, "IdenticalDistance")
199 201 BR_PROPERTY(Operation, operation, Mean)
200 202  
201 203 void train(const TemplateList &src)
202 204 {
203   - distance->train(src);
  205 + // Partition the templates by matrix
  206 + QList<int> split;
  207 + for (int i=0; i<src.at(0).size(); i++) split.append(1);
  208 +
  209 + QList<TemplateList> partitionedSrc = src.partition(split);
  210 +
  211 + while (distances.size() < partitionedSrc.size())
  212 + distances.append(make(description));
  213 +
  214 + // Train on each of the partitions
  215 + for (int i=0; i<distances.size(); i++)
  216 + distances[i]->train(partitionedSrc[i]);
204 217 }
205 218  
206 219 float compare(const Template &a, const Template &b) const
207 220 {
208 221 if (a.size() != b.size()) qFatal("Comparison size mismatch");
209 222  
210   - QList<float> distances;
211   - for (int i = 0; i < a.size(); i++)
212   - distances.append(distance->compare(a[i],b[i]));
  223 + QList<float> scores;
  224 + for (int i=0; i<distances.size(); i++)
  225 + scores.append(distances[i]->compare(Template(a.file, a[i]),Template(b.file, b[i])));
213 226  
214 227 switch (operation) {
215 228 case Mean:
216   - return std::accumulate(distances.begin(),distances.end(),0.0)/(float)distances.size();
  229 + return std::accumulate(scores.begin(),scores.end(),0.0)/(float)scores.size();
217 230 break;
218 231 case Sum:
219   - return std::accumulate(distances.begin(),distances.end(),0.0);
  232 + return std::accumulate(scores.begin(),scores.end(),0.0);
220 233 break;
221 234 case Min:
222   - return *std::min_element(distances.begin(),distances.end());
  235 + return *std::min_element(scores.begin(),scores.end());
223 236 break;
224 237 case Max:
225   - return *std::max_element(distances.begin(),distances.end());
  238 + return *std::max_element(scores.begin(),scores.end());
226 239 break;
227 240 default:
228 241 qFatal("Invalid operation.");
229 242 }
230 243 }
  244 +
  245 + void store(QDataStream &stream) const
  246 + {
  247 + stream << distances.size();
  248 + foreach (Distance *distance, distances)
  249 + distance->store(stream);
  250 + }
  251 +
  252 + void load(QDataStream &stream)
  253 + {
  254 + int numDistances;
  255 + stream >> numDistances;
  256 + while (distances.size() < numDistances)
  257 + distances.append(make(description));
  258 + foreach (Distance *distance, distances)
  259 + distance->load(stream);
  260 + }
231 261 };
232 262  
233   -BR_REGISTER(Distance, OperationDistance)
  263 +BR_REGISTER(Distance, FuseDistance)
234 264  
235 265 /*!
236 266 * \ingroup distances
... ...
openbr/plugins/quality.cpp
... ... @@ -211,6 +211,65 @@ protected:
211 211  
212 212 BR_REGISTER(Distance, MatchProbabilityDistance)
213 213  
  214 +class ZScoreDistance : public Distance
  215 +{
  216 + Q_OBJECT
  217 + Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false)
  218 + Q_PROPERTY(bool crossModality READ get_crossModality WRITE set_crossModality RESET reset_crossModality STORED false)
  219 + BR_PROPERTY(br::Distance*, distance, make("Dist(L2)"))
  220 + BR_PROPERTY(bool, crossModality, false)
  221 +
  222 + float min, max;
  223 + double mean, stddev;
  224 +
  225 + void train(const TemplateList &src)
  226 + {
  227 + distance->train(src);
  228 +
  229 + QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(src.size()), FileList(src.size())));
  230 + distance->compare(src, src, matrixOutput.data());
  231 +
  232 + QList<float> scores;
  233 + scores.reserve(src.size()*src.size());
  234 + for (int i=0; i<src.size(); i++) {
  235 + for (int j=0; j<i; j++) {
  236 + const float score = matrixOutput.data()->data.at<float>(i, j);
  237 + if (score == -std::numeric_limits<float>::max()) continue;
  238 + if (crossModality && src[i].file.get<QString>("MODALITY") == src[j].file.get<QString>("MODALITY")) continue;
  239 + scores.append(score);
  240 + }
  241 + }
  242 +
  243 + Common::MinMax(scores, &min, &max);
  244 + Common::MeanStdDev(scores, &mean, &stddev);
  245 +
  246 + if (stddev == 0) qFatal("Stddev is 0.");
  247 + }
  248 +
  249 + float compare(const Template &target, const Template &query) const
  250 + {
  251 + float score = distance->compare(target,query);
  252 + if (score == -std::numeric_limits<float>::max()) score = (min - mean) / stddev;
  253 + else if (score == std::numeric_limits<float>::max()) score = (max - mean) / stddev;
  254 + else score = (score - mean) / stddev;
  255 + return score;
  256 + }
  257 +
  258 + void store(QDataStream &stream) const
  259 + {
  260 + distance->store(stream);
  261 + stream << min << max << mean << stddev;
  262 + }
  263 +
  264 + void load(QDataStream &stream)
  265 + {
  266 + distance->load(stream);
  267 + stream >> min >> max >> mean >> stddev;
  268 + }
  269 +};
  270 +
  271 +BR_REGISTER(Distance, ZScoreDistance)
  272 +
214 273 /*!
215 274 * \ingroup distances
216 275 * \brief Match Probability modification for heat maps \cite klare12
... ...