Commit 6046e0b9bbbd0109c371d27003f61edb8d0b6f27
Merge pull request #148 from biometrics/score_level_fusion
Score level fusion support
Showing
5 changed files
with
107 additions
and
20 deletions
openbr/core/core.cpp
| ... | ... | @@ -56,11 +56,6 @@ struct AlgorithmCore |
| 56 | 56 | |
| 57 | 57 | TemplateList data(TemplateList::fromGallery(input)); |
| 58 | 58 | |
| 59 | - // set the Train bool metadata, in case a Transform's project | |
| 60 | - // needs to know if it's called during train or enroll | |
| 61 | - for (int i=0; i<data.size(); i++) | |
| 62 | - data[i].file.set("Train", true); | |
| 63 | - | |
| 64 | 59 | if (transform.isNull()) qFatal("Null transform."); |
| 65 | 60 | qDebug("%d Training Files", data.size()); |
| 66 | 61 | |
| ... | ... | @@ -463,7 +458,7 @@ void br::Convert(const File &fileType, const File &inputFile, const File &output |
| 463 | 458 | |
| 464 | 459 | if ((targetFiles.size() != m.cols || queryFiles.size() != m.rows) |
| 465 | 460 | && (m.cols != 1 || targetFiles.size() != m.rows || queryFiles.size() != m.rows)) |
| 466 | - qFatal("Similarity matrix and file size mismatch."); | |
| 461 | + qFatal("Similarity matrix (%d, %d) and header (%d, %d) size mismatch.", m.rows, m.cols, queryFiles.size(), targetFiles.size()); | |
| 467 | 462 | |
| 468 | 463 | QSharedPointer<Output> o(Factory<Output>::make(outputFile)); |
| 469 | 464 | o->initialize(targetFiles, queryFiles); | ... | ... |
openbr/core/fuse.cpp
| ... | ... | @@ -30,6 +30,9 @@ using namespace cv; |
| 30 | 30 | |
| 31 | 31 | static void normalizeMatrix(Mat &matrix, const Mat &mask, const QString &method) |
| 32 | 32 | { |
| 33 | + if (matrix.rows != mask.rows && matrix.cols != mask.cols) | |
| 34 | + qFatal("Similarity matrix (%d, %d) and mask (%d, %d) size mismatch.", matrix.rows, matrix.cols, mask.rows, mask.cols); | |
| 35 | + | |
| 33 | 36 | if (method == "None") return; |
| 34 | 37 | |
| 35 | 38 | QList<float> vals; vals.reserve(matrix.rows*matrix.cols); | ... | ... |
openbr/openbr_plugin.h
| ... | ... | @@ -483,7 +483,7 @@ struct TemplateList : public QList<Template> |
| 483 | 483 | } |
| 484 | 484 | |
| 485 | 485 | /*! |
| 486 | - * \brief Returns a #br::TemplateList containing templates with one matrix at the specified index \em index. | |
| 486 | + * \brief Returns a list of #br::TemplateList with each #br::Template in a given #br::TemplateList containing the number of matrices specified by \em partitionSizes. | |
| 487 | 487 | */ |
| 488 | 488 | QList<TemplateList> partition(const QList<int> &partitionSizes) const |
| 489 | 489 | { | ... | ... |
openbr/plugins/distance.cpp
| ... | ... | @@ -179,58 +179,88 @@ BR_REGISTER(Distance, PipeDistance) |
| 179 | 179 | |
| 180 | 180 | /*! |
| 181 | 181 | * \ingroup distances |
| 182 | - * \brief Computes an operation on distances across multiple matrices of compared templates | |
| 182 | + * \brief Fuses similarity scores across multiple matrices of compared templates | |
| 183 | 183 | * \author Scott Klum \cite sklum |
| 184 | 184 | * \note Operation: Mean, sum, min, max are supported. |
| 185 | 185 | */ |
| 186 | -class OperationDistance : public Distance | |
| 186 | +class FuseDistance : public Distance | |
| 187 | 187 | { |
| 188 | 188 | Q_OBJECT |
| 189 | 189 | Q_ENUMS(Operation) |
| 190 | - Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false) | |
| 190 | + Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false) | |
| 191 | 191 | Q_PROPERTY(Operation operation READ get_operation WRITE set_operation RESET reset_operation STORED false) |
| 192 | 192 | |
| 193 | + QList<br::Distance*> distances; | |
| 194 | + | |
| 193 | 195 | public: |
| 194 | 196 | /*!< */ |
| 195 | 197 | enum Operation {Mean, Sum, Max, Min}; |
| 196 | 198 | |
| 197 | 199 | private: |
| 198 | - BR_PROPERTY(br::Distance*, distance, make("Dist(L2)")) | |
| 200 | + BR_PROPERTY(QString, description, "IdenticalDistance") | |
| 199 | 201 | BR_PROPERTY(Operation, operation, Mean) |
| 200 | 202 | |
| 201 | 203 | void train(const TemplateList &src) |
| 202 | 204 | { |
| 203 | - distance->train(src); | |
| 205 | + // Partition the templates by matrix | |
| 206 | + QList<int> split; | |
| 207 | + for (int i=0; i<src.at(0).size(); i++) split.append(1); | |
| 208 | + | |
| 209 | + QList<TemplateList> partitionedSrc = src.partition(split); | |
| 210 | + | |
| 211 | + while (distances.size() < partitionedSrc.size()) | |
| 212 | + distances.append(make(description)); | |
| 213 | + | |
| 214 | + // Train on each of the partitions | |
| 215 | + for (int i=0; i<distances.size(); i++) | |
| 216 | + distances[i]->train(partitionedSrc[i]); | |
| 204 | 217 | } |
| 205 | 218 | |
| 206 | 219 | float compare(const Template &a, const Template &b) const |
| 207 | 220 | { |
| 208 | 221 | if (a.size() != b.size()) qFatal("Comparison size mismatch"); |
| 209 | 222 | |
| 210 | - QList<float> distances; | |
| 211 | - for (int i = 0; i < a.size(); i++) | |
| 212 | - distances.append(distance->compare(a[i],b[i])); | |
| 223 | + QList<float> scores; | |
| 224 | + for (int i=0; i<distances.size(); i++) | |
| 225 | + scores.append(distances[i]->compare(Template(a.file, a[i]),Template(b.file, b[i]))); | |
| 213 | 226 | |
| 214 | 227 | switch (operation) { |
| 215 | 228 | case Mean: |
| 216 | - return std::accumulate(distances.begin(),distances.end(),0.0)/(float)distances.size(); | |
| 229 | + return std::accumulate(scores.begin(),scores.end(),0.0)/(float)scores.size(); | |
| 217 | 230 | break; |
| 218 | 231 | case Sum: |
| 219 | - return std::accumulate(distances.begin(),distances.end(),0.0); | |
| 232 | + return std::accumulate(scores.begin(),scores.end(),0.0); | |
| 220 | 233 | break; |
| 221 | 234 | case Min: |
| 222 | - return *std::min_element(distances.begin(),distances.end()); | |
| 235 | + return *std::min_element(scores.begin(),scores.end()); | |
| 223 | 236 | break; |
| 224 | 237 | case Max: |
| 225 | - return *std::max_element(distances.begin(),distances.end()); | |
| 238 | + return *std::max_element(scores.begin(),scores.end()); | |
| 226 | 239 | break; |
| 227 | 240 | default: |
| 228 | 241 | qFatal("Invalid operation."); |
| 229 | 242 | } |
| 230 | 243 | } |
| 244 | + | |
| 245 | + void store(QDataStream &stream) const | |
| 246 | + { | |
| 247 | + stream << distances.size(); | |
| 248 | + foreach (Distance *distance, distances) | |
| 249 | + distance->store(stream); | |
| 250 | + } | |
| 251 | + | |
| 252 | + void load(QDataStream &stream) | |
| 253 | + { | |
| 254 | + int numDistances; | |
| 255 | + stream >> numDistances; | |
| 256 | + while (distances.size() < numDistances) | |
| 257 | + distances.append(make(description)); | |
| 258 | + foreach (Distance *distance, distances) | |
| 259 | + distance->load(stream); | |
| 260 | + } | |
| 231 | 261 | }; |
| 232 | 262 | |
| 233 | -BR_REGISTER(Distance, OperationDistance) | |
| 263 | +BR_REGISTER(Distance, FuseDistance) | |
| 234 | 264 | |
| 235 | 265 | /*! |
| 236 | 266 | * \ingroup distances | ... | ... |
openbr/plugins/quality.cpp
| ... | ... | @@ -211,6 +211,65 @@ protected: |
| 211 | 211 | |
| 212 | 212 | BR_REGISTER(Distance, MatchProbabilityDistance) |
| 213 | 213 | |
| 214 | +class ZScoreDistance : public Distance | |
| 215 | +{ | |
| 216 | + Q_OBJECT | |
| 217 | + Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false) | |
| 218 | + Q_PROPERTY(bool crossModality READ get_crossModality WRITE set_crossModality RESET reset_crossModality STORED false) | |
| 219 | + BR_PROPERTY(br::Distance*, distance, make("Dist(L2)")) | |
| 220 | + BR_PROPERTY(bool, crossModality, false) | |
| 221 | + | |
| 222 | + float min, max; | |
| 223 | + double mean, stddev; | |
| 224 | + | |
| 225 | + void train(const TemplateList &src) | |
| 226 | + { | |
| 227 | + distance->train(src); | |
| 228 | + | |
| 229 | + QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(src.size()), FileList(src.size()))); | |
| 230 | + distance->compare(src, src, matrixOutput.data()); | |
| 231 | + | |
| 232 | + QList<float> scores; | |
| 233 | + scores.reserve(src.size()*src.size()); | |
| 234 | + for (int i=0; i<src.size(); i++) { | |
| 235 | + for (int j=0; j<i; j++) { | |
| 236 | + const float score = matrixOutput.data()->data.at<float>(i, j); | |
| 237 | + if (score == -std::numeric_limits<float>::max()) continue; | |
| 238 | + if (crossModality && src[i].file.get<QString>("MODALITY") == src[j].file.get<QString>("MODALITY")) continue; | |
| 239 | + scores.append(score); | |
| 240 | + } | |
| 241 | + } | |
| 242 | + | |
| 243 | + Common::MinMax(scores, &min, &max); | |
| 244 | + Common::MeanStdDev(scores, &mean, &stddev); | |
| 245 | + | |
| 246 | + if (stddev == 0) qFatal("Stddev is 0."); | |
| 247 | + } | |
| 248 | + | |
| 249 | + float compare(const Template &target, const Template &query) const | |
| 250 | + { | |
| 251 | + float score = distance->compare(target,query); | |
| 252 | + if (score == -std::numeric_limits<float>::max()) score = (min - mean) / stddev; | |
| 253 | + else if (score == std::numeric_limits<float>::max()) score = (max - mean) / stddev; | |
| 254 | + else score = (score - mean) / stddev; | |
| 255 | + return score; | |
| 256 | + } | |
| 257 | + | |
| 258 | + void store(QDataStream &stream) const | |
| 259 | + { | |
| 260 | + distance->store(stream); | |
| 261 | + stream << min << max << mean << stddev; | |
| 262 | + } | |
| 263 | + | |
| 264 | + void load(QDataStream &stream) | |
| 265 | + { | |
| 266 | + distance->load(stream); | |
| 267 | + stream >> min >> max >> mean >> stddev; | |
| 268 | + } | |
| 269 | +}; | |
| 270 | + | |
| 271 | +BR_REGISTER(Distance, ZScoreDistance) | |
| 272 | + | |
| 214 | 273 | /*! |
| 215 | 274 | * \ingroup distances |
| 216 | 275 | * \brief Match Probability modification for heat maps \cite klare12 | ... | ... |