Commit 6046e0b9bbbd0109c371d27003f61edb8d0b6f27
Merge pull request #148 from biometrics/score_level_fusion
Score level fusion support
Showing
5 changed files
with
107 additions
and
20 deletions
openbr/core/core.cpp
| @@ -56,11 +56,6 @@ struct AlgorithmCore | @@ -56,11 +56,6 @@ struct AlgorithmCore | ||
| 56 | 56 | ||
| 57 | TemplateList data(TemplateList::fromGallery(input)); | 57 | TemplateList data(TemplateList::fromGallery(input)); |
| 58 | 58 | ||
| 59 | - // set the Train bool metadata, in case a Transform's project | ||
| 60 | - // needs to know if it's called during train or enroll | ||
| 61 | - for (int i=0; i<data.size(); i++) | ||
| 62 | - data[i].file.set("Train", true); | ||
| 63 | - | ||
| 64 | if (transform.isNull()) qFatal("Null transform."); | 59 | if (transform.isNull()) qFatal("Null transform."); |
| 65 | qDebug("%d Training Files", data.size()); | 60 | qDebug("%d Training Files", data.size()); |
| 66 | 61 | ||
| @@ -463,7 +458,7 @@ void br::Convert(const File &fileType, const File &inputFile, const File &output | @@ -463,7 +458,7 @@ void br::Convert(const File &fileType, const File &inputFile, const File &output | ||
| 463 | 458 | ||
| 464 | if ((targetFiles.size() != m.cols || queryFiles.size() != m.rows) | 459 | if ((targetFiles.size() != m.cols || queryFiles.size() != m.rows) |
| 465 | && (m.cols != 1 || targetFiles.size() != m.rows || queryFiles.size() != m.rows)) | 460 | && (m.cols != 1 || targetFiles.size() != m.rows || queryFiles.size() != m.rows)) |
| 466 | - qFatal("Similarity matrix and file size mismatch."); | 461 | + qFatal("Similarity matrix (%d, %d) and header (%d, %d) size mismatch.", m.rows, m.cols, queryFiles.size(), targetFiles.size()); |
| 467 | 462 | ||
| 468 | QSharedPointer<Output> o(Factory<Output>::make(outputFile)); | 463 | QSharedPointer<Output> o(Factory<Output>::make(outputFile)); |
| 469 | o->initialize(targetFiles, queryFiles); | 464 | o->initialize(targetFiles, queryFiles); |
openbr/core/fuse.cpp
| @@ -30,6 +30,9 @@ using namespace cv; | @@ -30,6 +30,9 @@ using namespace cv; | ||
| 30 | 30 | ||
| 31 | static void normalizeMatrix(Mat &matrix, const Mat &mask, const QString &method) | 31 | static void normalizeMatrix(Mat &matrix, const Mat &mask, const QString &method) |
| 32 | { | 32 | { |
| 33 | + if (matrix.rows != mask.rows && matrix.cols != mask.cols) | ||
| 34 | + qFatal("Similarity matrix (%d, %d) and mask (%d, %d) size mismatch.", matrix.rows, matrix.cols, mask.rows, mask.cols); | ||
| 35 | + | ||
| 33 | if (method == "None") return; | 36 | if (method == "None") return; |
| 34 | 37 | ||
| 35 | QList<float> vals; vals.reserve(matrix.rows*matrix.cols); | 38 | QList<float> vals; vals.reserve(matrix.rows*matrix.cols); |
openbr/openbr_plugin.h
| @@ -483,7 +483,7 @@ struct TemplateList : public QList<Template> | @@ -483,7 +483,7 @@ struct TemplateList : public QList<Template> | ||
| 483 | } | 483 | } |
| 484 | 484 | ||
| 485 | /*! | 485 | /*! |
| 486 | - * \brief Returns a #br::TemplateList containing templates with one matrix at the specified index \em index. | 486 | + * \brief Returns a list of #br::TemplateList with each #br::Template in a given #br::TemplateList containing the number of matrices specified by \em partitionSizes. |
| 487 | */ | 487 | */ |
| 488 | QList<TemplateList> partition(const QList<int> &partitionSizes) const | 488 | QList<TemplateList> partition(const QList<int> &partitionSizes) const |
| 489 | { | 489 | { |
openbr/plugins/distance.cpp
| @@ -179,58 +179,88 @@ BR_REGISTER(Distance, PipeDistance) | @@ -179,58 +179,88 @@ BR_REGISTER(Distance, PipeDistance) | ||
| 179 | 179 | ||
| 180 | /*! | 180 | /*! |
| 181 | * \ingroup distances | 181 | * \ingroup distances |
| 182 | - * \brief Computes an operation on distances across multiple matrices of compared templates | 182 | + * \brief Fuses similarity scores across multiple matrices of compared templates |
| 183 | * \author Scott Klum \cite sklum | 183 | * \author Scott Klum \cite sklum |
| 184 | * \note Operation: Mean, sum, min, max are supported. | 184 | * \note Operation: Mean, sum, min, max are supported. |
| 185 | */ | 185 | */ |
| 186 | -class OperationDistance : public Distance | 186 | +class FuseDistance : public Distance |
| 187 | { | 187 | { |
| 188 | Q_OBJECT | 188 | Q_OBJECT |
| 189 | Q_ENUMS(Operation) | 189 | Q_ENUMS(Operation) |
| 190 | - Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false) | 190 | + Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false) |
| 191 | Q_PROPERTY(Operation operation READ get_operation WRITE set_operation RESET reset_operation STORED false) | 191 | Q_PROPERTY(Operation operation READ get_operation WRITE set_operation RESET reset_operation STORED false) |
| 192 | 192 | ||
| 193 | + QList<br::Distance*> distances; | ||
| 194 | + | ||
| 193 | public: | 195 | public: |
| 194 | /*!< */ | 196 | /*!< */ |
| 195 | enum Operation {Mean, Sum, Max, Min}; | 197 | enum Operation {Mean, Sum, Max, Min}; |
| 196 | 198 | ||
| 197 | private: | 199 | private: |
| 198 | - BR_PROPERTY(br::Distance*, distance, make("Dist(L2)")) | 200 | + BR_PROPERTY(QString, description, "IdenticalDistance") |
| 199 | BR_PROPERTY(Operation, operation, Mean) | 201 | BR_PROPERTY(Operation, operation, Mean) |
| 200 | 202 | ||
| 201 | void train(const TemplateList &src) | 203 | void train(const TemplateList &src) |
| 202 | { | 204 | { |
| 203 | - distance->train(src); | 205 | + // Partition the templates by matrix |
| 206 | + QList<int> split; | ||
| 207 | + for (int i=0; i<src.at(0).size(); i++) split.append(1); | ||
| 208 | + | ||
| 209 | + QList<TemplateList> partitionedSrc = src.partition(split); | ||
| 210 | + | ||
| 211 | + while (distances.size() < partitionedSrc.size()) | ||
| 212 | + distances.append(make(description)); | ||
| 213 | + | ||
| 214 | + // Train on each of the partitions | ||
| 215 | + for (int i=0; i<distances.size(); i++) | ||
| 216 | + distances[i]->train(partitionedSrc[i]); | ||
| 204 | } | 217 | } |
| 205 | 218 | ||
| 206 | float compare(const Template &a, const Template &b) const | 219 | float compare(const Template &a, const Template &b) const |
| 207 | { | 220 | { |
| 208 | if (a.size() != b.size()) qFatal("Comparison size mismatch"); | 221 | if (a.size() != b.size()) qFatal("Comparison size mismatch"); |
| 209 | 222 | ||
| 210 | - QList<float> distances; | ||
| 211 | - for (int i = 0; i < a.size(); i++) | ||
| 212 | - distances.append(distance->compare(a[i],b[i])); | 223 | + QList<float> scores; |
| 224 | + for (int i=0; i<distances.size(); i++) | ||
| 225 | + scores.append(distances[i]->compare(Template(a.file, a[i]),Template(b.file, b[i]))); | ||
| 213 | 226 | ||
| 214 | switch (operation) { | 227 | switch (operation) { |
| 215 | case Mean: | 228 | case Mean: |
| 216 | - return std::accumulate(distances.begin(),distances.end(),0.0)/(float)distances.size(); | 229 | + return std::accumulate(scores.begin(),scores.end(),0.0)/(float)scores.size(); |
| 217 | break; | 230 | break; |
| 218 | case Sum: | 231 | case Sum: |
| 219 | - return std::accumulate(distances.begin(),distances.end(),0.0); | 232 | + return std::accumulate(scores.begin(),scores.end(),0.0); |
| 220 | break; | 233 | break; |
| 221 | case Min: | 234 | case Min: |
| 222 | - return *std::min_element(distances.begin(),distances.end()); | 235 | + return *std::min_element(scores.begin(),scores.end()); |
| 223 | break; | 236 | break; |
| 224 | case Max: | 237 | case Max: |
| 225 | - return *std::max_element(distances.begin(),distances.end()); | 238 | + return *std::max_element(scores.begin(),scores.end()); |
| 226 | break; | 239 | break; |
| 227 | default: | 240 | default: |
| 228 | qFatal("Invalid operation."); | 241 | qFatal("Invalid operation."); |
| 229 | } | 242 | } |
| 230 | } | 243 | } |
| 244 | + | ||
| 245 | + void store(QDataStream &stream) const | ||
| 246 | + { | ||
| 247 | + stream << distances.size(); | ||
| 248 | + foreach (Distance *distance, distances) | ||
| 249 | + distance->store(stream); | ||
| 250 | + } | ||
| 251 | + | ||
| 252 | + void load(QDataStream &stream) | ||
| 253 | + { | ||
| 254 | + int numDistances; | ||
| 255 | + stream >> numDistances; | ||
| 256 | + while (distances.size() < numDistances) | ||
| 257 | + distances.append(make(description)); | ||
| 258 | + foreach (Distance *distance, distances) | ||
| 259 | + distance->load(stream); | ||
| 260 | + } | ||
| 231 | }; | 261 | }; |
| 232 | 262 | ||
| 233 | -BR_REGISTER(Distance, OperationDistance) | 263 | +BR_REGISTER(Distance, FuseDistance) |
| 234 | 264 | ||
| 235 | /*! | 265 | /*! |
| 236 | * \ingroup distances | 266 | * \ingroup distances |
openbr/plugins/quality.cpp
| @@ -211,6 +211,65 @@ protected: | @@ -211,6 +211,65 @@ protected: | ||
| 211 | 211 | ||
| 212 | BR_REGISTER(Distance, MatchProbabilityDistance) | 212 | BR_REGISTER(Distance, MatchProbabilityDistance) |
| 213 | 213 | ||
| 214 | +class ZScoreDistance : public Distance | ||
| 215 | +{ | ||
| 216 | + Q_OBJECT | ||
| 217 | + Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false) | ||
| 218 | + Q_PROPERTY(bool crossModality READ get_crossModality WRITE set_crossModality RESET reset_crossModality STORED false) | ||
| 219 | + BR_PROPERTY(br::Distance*, distance, make("Dist(L2)")) | ||
| 220 | + BR_PROPERTY(bool, crossModality, false) | ||
| 221 | + | ||
| 222 | + float min, max; | ||
| 223 | + double mean, stddev; | ||
| 224 | + | ||
| 225 | + void train(const TemplateList &src) | ||
| 226 | + { | ||
| 227 | + distance->train(src); | ||
| 228 | + | ||
| 229 | + QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(src.size()), FileList(src.size()))); | ||
| 230 | + distance->compare(src, src, matrixOutput.data()); | ||
| 231 | + | ||
| 232 | + QList<float> scores; | ||
| 233 | + scores.reserve(src.size()*src.size()); | ||
| 234 | + for (int i=0; i<src.size(); i++) { | ||
| 235 | + for (int j=0; j<i; j++) { | ||
| 236 | + const float score = matrixOutput.data()->data.at<float>(i, j); | ||
| 237 | + if (score == -std::numeric_limits<float>::max()) continue; | ||
| 238 | + if (crossModality && src[i].file.get<QString>("MODALITY") == src[j].file.get<QString>("MODALITY")) continue; | ||
| 239 | + scores.append(score); | ||
| 240 | + } | ||
| 241 | + } | ||
| 242 | + | ||
| 243 | + Common::MinMax(scores, &min, &max); | ||
| 244 | + Common::MeanStdDev(scores, &mean, &stddev); | ||
| 245 | + | ||
| 246 | + if (stddev == 0) qFatal("Stddev is 0."); | ||
| 247 | + } | ||
| 248 | + | ||
| 249 | + float compare(const Template &target, const Template &query) const | ||
| 250 | + { | ||
| 251 | + float score = distance->compare(target,query); | ||
| 252 | + if (score == -std::numeric_limits<float>::max()) score = (min - mean) / stddev; | ||
| 253 | + else if (score == std::numeric_limits<float>::max()) score = (max - mean) / stddev; | ||
| 254 | + else score = (score - mean) / stddev; | ||
| 255 | + return score; | ||
| 256 | + } | ||
| 257 | + | ||
| 258 | + void store(QDataStream &stream) const | ||
| 259 | + { | ||
| 260 | + distance->store(stream); | ||
| 261 | + stream << min << max << mean << stddev; | ||
| 262 | + } | ||
| 263 | + | ||
| 264 | + void load(QDataStream &stream) | ||
| 265 | + { | ||
| 266 | + distance->load(stream); | ||
| 267 | + stream >> min >> max >> mean >> stddev; | ||
| 268 | + } | ||
| 269 | +}; | ||
| 270 | + | ||
| 271 | +BR_REGISTER(Distance, ZScoreDistance) | ||
| 272 | + | ||
| 214 | /*! | 273 | /*! |
| 215 | * \ingroup distances | 274 | * \ingroup distances |
| 216 | * \brief Match Probability modification for heat maps \cite klare12 | 275 | * \brief Match Probability modification for heat maps \cite klare12 |