Commit 0a275cfb5035505c666827644e46effbbc096a8d
1 parent
b174c493
OperationDistance::train finished
Showing
5 changed files
with
40 additions
and
22 deletions
openbr/core/core.cpp
| @@ -463,7 +463,7 @@ void br::Convert(const File &fileType, const File &inputFile, const File &output | @@ -463,7 +463,7 @@ void br::Convert(const File &fileType, const File &inputFile, const File &output | ||
| 463 | 463 | ||
| 464 | if ((targetFiles.size() != m.cols || queryFiles.size() != m.rows) | 464 | if ((targetFiles.size() != m.cols || queryFiles.size() != m.rows) |
| 465 | && (m.cols != 1 || targetFiles.size() != m.rows || queryFiles.size() != m.rows)) | 465 | && (m.cols != 1 || targetFiles.size() != m.rows || queryFiles.size() != m.rows)) |
| 466 | - qFatal("Similarity matrix and file size mismatch."); | 466 | + qFatal("Similarity matrix (%d, %d) and header (%d, %d) size mismatch.", m.rows, m.cols, queryFiles.size(), targetFiles.size()); |
| 467 | 467 | ||
| 468 | QSharedPointer<Output> o(Factory<Output>::make(outputFile)); | 468 | QSharedPointer<Output> o(Factory<Output>::make(outputFile)); |
| 469 | o->initialize(targetFiles, queryFiles); | 469 | o->initialize(targetFiles, queryFiles); |
openbr/core/fuse.cpp
| @@ -30,6 +30,9 @@ using namespace cv; | @@ -30,6 +30,9 @@ using namespace cv; | ||
| 30 | 30 | ||
| 31 | static void normalizeMatrix(Mat &matrix, const Mat &mask, const QString &method) | 31 | static void normalizeMatrix(Mat &matrix, const Mat &mask, const QString &method) |
| 32 | { | 32 | { |
| 33 | + if (matrix.rows != mask.rows && matrix.cols != mask.cols) | ||
| 34 | + qFatal("Similarity matrix (%d, %d) and mask (%d, %d) size mismatch.", matrix.rows, matrix.cols, mask.rows, mask.cols); | ||
| 35 | + | ||
| 33 | if (method == "None") return; | 36 | if (method == "None") return; |
| 34 | 37 | ||
| 35 | QList<float> vals; vals.reserve(matrix.rows*matrix.cols); | 38 | QList<float> vals; vals.reserve(matrix.rows*matrix.cols); |
openbr/openbr_plugin.h
| @@ -483,7 +483,7 @@ struct TemplateList : public QList<Template> | @@ -483,7 +483,7 @@ struct TemplateList : public QList<Template> | ||
| 483 | } | 483 | } |
| 484 | 484 | ||
| 485 | /*! | 485 | /*! |
| 486 | - * \brief Returns a #br::TemplateList containing templates with one matrix at the specified index \em index. | 486 | + * \brief Returns a list of #br::TemplateList containing templates with a each TemplateList containing the number of matrices specified by \em partitionSizes. |
| 487 | */ | 487 | */ |
| 488 | QList<TemplateList> partition(const QList<int> &partitionSizes) const | 488 | QList<TemplateList> partition(const QList<int> &partitionSizes) const |
| 489 | { | 489 | { |
openbr/plugins/distance.cpp
| @@ -187,47 +187,61 @@ class OperationDistance : public Distance | @@ -187,47 +187,61 @@ class OperationDistance : public Distance | ||
| 187 | { | 187 | { |
| 188 | Q_OBJECT | 188 | Q_OBJECT |
| 189 | Q_ENUMS(Operation) | 189 | Q_ENUMS(Operation) |
| 190 | - Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false) | 190 | + Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false) |
| 191 | Q_PROPERTY(Operation operation READ get_operation WRITE set_operation RESET reset_operation STORED false) | 191 | Q_PROPERTY(Operation operation READ get_operation WRITE set_operation RESET reset_operation STORED false) |
| 192 | + Q_PROPERTY(QList<int> split READ get_split WRITE set_split RESET reset_split STORED false) | ||
| 193 | + | ||
| 194 | + QList<br::Distance*> distances; | ||
| 192 | 195 | ||
| 193 | public: | 196 | public: |
| 194 | /*!< */ | 197 | /*!< */ |
| 195 | enum Operation {Mean, Sum, Max, Min}; | 198 | enum Operation {Mean, Sum, Max, Min}; |
| 196 | 199 | ||
| 197 | private: | 200 | private: |
| 198 | - BR_PROPERTY(br::Distance*, distance, make("Dist(L2)")) | 201 | + BR_PROPERTY(QString, description, "IdenticalDistance") |
| 199 | BR_PROPERTY(Operation, operation, Mean) | 202 | BR_PROPERTY(Operation, operation, Mean) |
| 203 | + BR_PROPERTY(QList<int>, split, QList<int>()) | ||
| 200 | 204 | ||
| 201 | void train(const TemplateList &src) | 205 | void train(const TemplateList &src) |
| 202 | { | 206 | { |
| 203 | - distance->train(src); | 207 | + // Default is to train on each matrix |
| 208 | + if (split.isEmpty()) for (int i=0; i<src.at(0).size(); i++) split.append(1); | ||
| 209 | + | ||
| 210 | + QList<TemplateList> partitionedSrc = src.partition(split); | ||
| 211 | + | ||
| 212 | + while (distances.size() < partitionedSrc.size()) | ||
| 213 | + distances.append(make(description)); | ||
| 214 | + | ||
| 215 | + // Train on each of the partitions | ||
| 216 | + for (int i=0; i<distances.size(); i++) | ||
| 217 | + distances[i]->train(partitionedSrc[i]); | ||
| 204 | } | 218 | } |
| 205 | 219 | ||
| 206 | float compare(const Template &a, const Template &b) const | 220 | float compare(const Template &a, const Template &b) const |
| 207 | { | 221 | { |
| 208 | if (a.size() != b.size()) qFatal("Comparison size mismatch"); | 222 | if (a.size() != b.size()) qFatal("Comparison size mismatch"); |
| 209 | 223 | ||
| 210 | - QList<float> distances; | ||
| 211 | - for (int i = 0; i < a.size(); i++) { | 224 | + QList<float> scores; |
| 225 | + for (int i=0; i<distances.size(); i++) { | ||
| 212 | Template ai = a.file; | 226 | Template ai = a.file; |
| 213 | - ai.m() = a[i].clone(); | 227 | + ai.m() = a[i]; |
| 214 | Template bi = b.file; | 228 | Template bi = b.file; |
| 215 | - bi.m() = b[i].clone(); | ||
| 216 | - distances.append(distance->compare(ai,bi)); | 229 | + bi.m() = b[i]; |
| 230 | + scores.append(distances[i]->compare(ai,bi)); | ||
| 217 | } | 231 | } |
| 218 | 232 | ||
| 219 | switch (operation) { | 233 | switch (operation) { |
| 220 | case Mean: | 234 | case Mean: |
| 221 | - return std::accumulate(distances.begin(),distances.end(),0.0)/(float)distances.size(); | 235 | + return std::accumulate(scores.begin(),scores.end(),0.0)/(float)scores.size(); |
| 222 | break; | 236 | break; |
| 223 | case Sum: | 237 | case Sum: |
| 224 | - return std::accumulate(distances.begin(),distances.end(),0.0); | 238 | + return std::accumulate(scores.begin(),scores.end(),0.0); |
| 225 | break; | 239 | break; |
| 226 | case Min: | 240 | case Min: |
| 227 | - return *std::min_element(distances.begin(),distances.end()); | 241 | + return *std::min_element(scores.begin(),scores.end()); |
| 228 | break; | 242 | break; |
| 229 | case Max: | 243 | case Max: |
| 230 | - return *std::max_element(distances.begin(),distances.end()); | 244 | + return *std::max_element(scores.begin(),scores.end()); |
| 231 | break; | 245 | break; |
| 232 | default: | 246 | default: |
| 233 | qFatal("Invalid operation."); | 247 | qFatal("Invalid operation."); |
| @@ -236,12 +250,19 @@ private: | @@ -236,12 +250,19 @@ private: | ||
| 236 | 250 | ||
| 237 | void store(QDataStream &stream) const | 251 | void store(QDataStream &stream) const |
| 238 | { | 252 | { |
| 239 | - distance->store(stream); | 253 | + stream << distances.size(); |
| 254 | + foreach (Distance *distance, distances) | ||
| 255 | + distance->store(stream); | ||
| 240 | } | 256 | } |
| 241 | 257 | ||
| 242 | void load(QDataStream &stream) | 258 | void load(QDataStream &stream) |
| 243 | { | 259 | { |
| 244 | - distance->load(stream); | 260 | + int numDistances; |
| 261 | + stream >> numDistances; | ||
| 262 | + while (distances.size() < numDistances) | ||
| 263 | + distances.append(make(description)); | ||
| 264 | + foreach (Distance *distance, distances) | ||
| 265 | + distance->load(stream); | ||
| 245 | } | 266 | } |
| 246 | }; | 267 | }; |
| 247 | 268 |
openbr/plugins/quality.cpp
| @@ -179,12 +179,7 @@ class MatchProbabilityDistance : public Distance | @@ -179,12 +179,7 @@ class MatchProbabilityDistance : public Distance | ||
| 179 | } | 179 | } |
| 180 | } | 180 | } |
| 181 | 181 | ||
| 182 | - qDebug() << "Genuines: " << genuineScores.mid(0,5); | ||
| 183 | - qDebug() << "Impostors: " << impostorScores.mid(0,5); | ||
| 184 | - | ||
| 185 | mp = MP(genuineScores, impostorScores); | 182 | mp = MP(genuineScores, impostorScores); |
| 186 | - | ||
| 187 | - qDebug() << mp(-0.881882,true); | ||
| 188 | } | 183 | } |
| 189 | 184 | ||
| 190 | float compare(const Template &target, const Template &query) const | 185 | float compare(const Template &target, const Template &query) const |
| @@ -192,7 +187,6 @@ class MatchProbabilityDistance : public Distance | @@ -192,7 +187,6 @@ class MatchProbabilityDistance : public Distance | ||
| 192 | const float rawScore = distance->compare(target, query); | 187 | const float rawScore = distance->compare(target, query); |
| 193 | if (rawScore == -std::numeric_limits<float>::max()) return rawScore; | 188 | if (rawScore == -std::numeric_limits<float>::max()) return rawScore; |
| 194 | if (!Globals->scoreNormalization) return -log(rawScore+1); | 189 | if (!Globals->scoreNormalization) return -log(rawScore+1); |
| 195 | - qDebug() << mp(rawScore, gaussian) << rawScore; | ||
| 196 | return mp(rawScore, gaussian); | 190 | return mp(rawScore, gaussian); |
| 197 | } | 191 | } |
| 198 | 192 |