Commit 6046e0b9bbbd0109c371d27003f61edb8d0b6f27

Authored by sklum
2 parents 9deea8e7 f95f259e

Merge pull request #148 from biometrics/score_level_fusion

Score level fusion support
openbr/core/core.cpp
@@ -56,11 +56,6 @@ struct AlgorithmCore @@ -56,11 +56,6 @@ struct AlgorithmCore
56 56
57 TemplateList data(TemplateList::fromGallery(input)); 57 TemplateList data(TemplateList::fromGallery(input));
58 58
59 - // set the Train bool metadata, in case a Transform's project  
60 - // needs to know if it's called during train or enroll  
61 - for (int i=0; i<data.size(); i++)  
62 - data[i].file.set("Train", true);  
63 -  
64 if (transform.isNull()) qFatal("Null transform."); 59 if (transform.isNull()) qFatal("Null transform.");
65 qDebug("%d Training Files", data.size()); 60 qDebug("%d Training Files", data.size());
66 61
@@ -463,7 +458,7 @@ void br::Convert(const File &amp;fileType, const File &amp;inputFile, const File &amp;output @@ -463,7 +458,7 @@ void br::Convert(const File &amp;fileType, const File &amp;inputFile, const File &amp;output
463 458
464 if ((targetFiles.size() != m.cols || queryFiles.size() != m.rows) 459 if ((targetFiles.size() != m.cols || queryFiles.size() != m.rows)
465 && (m.cols != 1 || targetFiles.size() != m.rows || queryFiles.size() != m.rows)) 460 && (m.cols != 1 || targetFiles.size() != m.rows || queryFiles.size() != m.rows))
466 - qFatal("Similarity matrix and file size mismatch."); 461 + qFatal("Similarity matrix (%d, %d) and header (%d, %d) size mismatch.", m.rows, m.cols, queryFiles.size(), targetFiles.size());
467 462
468 QSharedPointer<Output> o(Factory<Output>::make(outputFile)); 463 QSharedPointer<Output> o(Factory<Output>::make(outputFile));
469 o->initialize(targetFiles, queryFiles); 464 o->initialize(targetFiles, queryFiles);
openbr/core/fuse.cpp
@@ -30,6 +30,9 @@ using namespace cv; @@ -30,6 +30,9 @@ using namespace cv;
30 30
31 static void normalizeMatrix(Mat &matrix, const Mat &mask, const QString &method) 31 static void normalizeMatrix(Mat &matrix, const Mat &mask, const QString &method)
32 { 32 {
  33 + if (matrix.rows != mask.rows && matrix.cols != mask.cols)
  34 + qFatal("Similarity matrix (%d, %d) and mask (%d, %d) size mismatch.", matrix.rows, matrix.cols, mask.rows, mask.cols);
  35 +
33 if (method == "None") return; 36 if (method == "None") return;
34 37
35 QList<float> vals; vals.reserve(matrix.rows*matrix.cols); 38 QList<float> vals; vals.reserve(matrix.rows*matrix.cols);
openbr/openbr_plugin.h
@@ -483,7 +483,7 @@ struct TemplateList : public QList&lt;Template&gt; @@ -483,7 +483,7 @@ struct TemplateList : public QList&lt;Template&gt;
483 } 483 }
484 484
485 /*! 485 /*!
486 - * \brief Returns a #br::TemplateList containing templates with one matrix at the specified index \em index. 486 + * \brief Returns a list of #br::TemplateList with each #br::Template in a given #br::TemplateList containing the number of matrices specified by \em partitionSizes.
487 */ 487 */
488 QList<TemplateList> partition(const QList<int> &partitionSizes) const 488 QList<TemplateList> partition(const QList<int> &partitionSizes) const
489 { 489 {
openbr/plugins/distance.cpp
@@ -179,58 +179,88 @@ BR_REGISTER(Distance, PipeDistance) @@ -179,58 +179,88 @@ BR_REGISTER(Distance, PipeDistance)
179 179
180 /*! 180 /*!
181 * \ingroup distances 181 * \ingroup distances
182 - * \brief Computes an operation on distances across multiple matrices of compared templates 182 + * \brief Fuses similarity scores across multiple matrices of compared templates
183 * \author Scott Klum \cite sklum 183 * \author Scott Klum \cite sklum
184 * \note Operation: Mean, sum, min, max are supported. 184 * \note Operation: Mean, sum, min, max are supported.
185 */ 185 */
186 -class OperationDistance : public Distance 186 +class FuseDistance : public Distance
187 { 187 {
188 Q_OBJECT 188 Q_OBJECT
189 Q_ENUMS(Operation) 189 Q_ENUMS(Operation)
190 - Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false) 190 + Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false)
191 Q_PROPERTY(Operation operation READ get_operation WRITE set_operation RESET reset_operation STORED false) 191 Q_PROPERTY(Operation operation READ get_operation WRITE set_operation RESET reset_operation STORED false)
192 192
  193 + QList<br::Distance*> distances;
  194 +
193 public: 195 public:
194 /*!< */ 196 /*!< */
195 enum Operation {Mean, Sum, Max, Min}; 197 enum Operation {Mean, Sum, Max, Min};
196 198
197 private: 199 private:
198 - BR_PROPERTY(br::Distance*, distance, make("Dist(L2)")) 200 + BR_PROPERTY(QString, description, "IdenticalDistance")
199 BR_PROPERTY(Operation, operation, Mean) 201 BR_PROPERTY(Operation, operation, Mean)
200 202
201 void train(const TemplateList &src) 203 void train(const TemplateList &src)
202 { 204 {
203 - distance->train(src); 205 + // Partition the templates by matrix
  206 + QList<int> split;
  207 + for (int i=0; i<src.at(0).size(); i++) split.append(1);
  208 +
  209 + QList<TemplateList> partitionedSrc = src.partition(split);
  210 +
  211 + while (distances.size() < partitionedSrc.size())
  212 + distances.append(make(description));
  213 +
  214 + // Train on each of the partitions
  215 + for (int i=0; i<distances.size(); i++)
  216 + distances[i]->train(partitionedSrc[i]);
204 } 217 }
205 218
206 float compare(const Template &a, const Template &b) const 219 float compare(const Template &a, const Template &b) const
207 { 220 {
208 if (a.size() != b.size()) qFatal("Comparison size mismatch"); 221 if (a.size() != b.size()) qFatal("Comparison size mismatch");
209 222
210 - QList<float> distances;  
211 - for (int i = 0; i < a.size(); i++)  
212 - distances.append(distance->compare(a[i],b[i])); 223 + QList<float> scores;
  224 + for (int i=0; i<distances.size(); i++)
  225 + scores.append(distances[i]->compare(Template(a.file, a[i]),Template(b.file, b[i])));
213 226
214 switch (operation) { 227 switch (operation) {
215 case Mean: 228 case Mean:
216 - return std::accumulate(distances.begin(),distances.end(),0.0)/(float)distances.size(); 229 + return std::accumulate(scores.begin(),scores.end(),0.0)/(float)scores.size();
217 break; 230 break;
218 case Sum: 231 case Sum:
219 - return std::accumulate(distances.begin(),distances.end(),0.0); 232 + return std::accumulate(scores.begin(),scores.end(),0.0);
220 break; 233 break;
221 case Min: 234 case Min:
222 - return *std::min_element(distances.begin(),distances.end()); 235 + return *std::min_element(scores.begin(),scores.end());
223 break; 236 break;
224 case Max: 237 case Max:
225 - return *std::max_element(distances.begin(),distances.end()); 238 + return *std::max_element(scores.begin(),scores.end());
226 break; 239 break;
227 default: 240 default:
228 qFatal("Invalid operation."); 241 qFatal("Invalid operation.");
229 } 242 }
230 } 243 }
  244 +
  245 + void store(QDataStream &stream) const
  246 + {
  247 + stream << distances.size();
  248 + foreach (Distance *distance, distances)
  249 + distance->store(stream);
  250 + }
  251 +
  252 + void load(QDataStream &stream)
  253 + {
  254 + int numDistances;
  255 + stream >> numDistances;
  256 + while (distances.size() < numDistances)
  257 + distances.append(make(description));
  258 + foreach (Distance *distance, distances)
  259 + distance->load(stream);
  260 + }
231 }; 261 };
232 262
233 -BR_REGISTER(Distance, OperationDistance) 263 +BR_REGISTER(Distance, FuseDistance)
234 264
235 /*! 265 /*!
236 * \ingroup distances 266 * \ingroup distances
openbr/plugins/quality.cpp
@@ -211,6 +211,65 @@ protected: @@ -211,6 +211,65 @@ protected:
211 211
212 BR_REGISTER(Distance, MatchProbabilityDistance) 212 BR_REGISTER(Distance, MatchProbabilityDistance)
213 213
  214 +class ZScoreDistance : public Distance
  215 +{
  216 + Q_OBJECT
  217 + Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false)
  218 + Q_PROPERTY(bool crossModality READ get_crossModality WRITE set_crossModality RESET reset_crossModality STORED false)
  219 + BR_PROPERTY(br::Distance*, distance, make("Dist(L2)"))
  220 + BR_PROPERTY(bool, crossModality, false)
  221 +
  222 + float min, max;
  223 + double mean, stddev;
  224 +
  225 + void train(const TemplateList &src)
  226 + {
  227 + distance->train(src);
  228 +
  229 + QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(src.size()), FileList(src.size())));
  230 + distance->compare(src, src, matrixOutput.data());
  231 +
  232 + QList<float> scores;
  233 + scores.reserve(src.size()*src.size());
  234 + for (int i=0; i<src.size(); i++) {
  235 + for (int j=0; j<i; j++) {
  236 + const float score = matrixOutput.data()->data.at<float>(i, j);
  237 + if (score == -std::numeric_limits<float>::max()) continue;
  238 + if (crossModality && src[i].file.get<QString>("MODALITY") == src[j].file.get<QString>("MODALITY")) continue;
  239 + scores.append(score);
  240 + }
  241 + }
  242 +
  243 + Common::MinMax(scores, &min, &max);
  244 + Common::MeanStdDev(scores, &mean, &stddev);
  245 +
  246 + if (stddev == 0) qFatal("Stddev is 0.");
  247 + }
  248 +
  249 + float compare(const Template &target, const Template &query) const
  250 + {
  251 + float score = distance->compare(target,query);
  252 + if (score == -std::numeric_limits<float>::max()) score = (min - mean) / stddev;
  253 + else if (score == std::numeric_limits<float>::max()) score = (max - mean) / stddev;
  254 + else score = (score - mean) / stddev;
  255 + return score;
  256 + }
  257 +
  258 + void store(QDataStream &stream) const
  259 + {
  260 + distance->store(stream);
  261 + stream << min << max << mean << stddev;
  262 + }
  263 +
  264 + void load(QDataStream &stream)
  265 + {
  266 + distance->load(stream);
  267 + stream >> min >> max >> mean >> stddev;
  268 + }
  269 +};
  270 +
  271 +BR_REGISTER(Distance, ZScoreDistance)
  272 +
214 /*! 273 /*!
215 * \ingroup distances 274 * \ingroup distances
216 * \brief Match Probability modification for heat maps \cite klare12 275 * \brief Match Probability modification for heat maps \cite klare12