Commit 0a275cfb5035505c666827644e46effbbc096a8d
1 parent
b174c493
OperationDistance::train finished
Showing
5 changed files
with
40 additions
and
22 deletions
openbr/core/core.cpp
| ... | ... | @@ -463,7 +463,7 @@ void br::Convert(const File &fileType, const File &inputFile, const File &output |
| 463 | 463 | |
| 464 | 464 | if ((targetFiles.size() != m.cols || queryFiles.size() != m.rows) |
| 465 | 465 | && (m.cols != 1 || targetFiles.size() != m.rows || queryFiles.size() != m.rows)) |
| 466 | - qFatal("Similarity matrix and file size mismatch."); | |
| 466 | + qFatal("Similarity matrix (%d, %d) and header (%d, %d) size mismatch.", m.rows, m.cols, queryFiles.size(), targetFiles.size()); | |
| 467 | 467 | |
| 468 | 468 | QSharedPointer<Output> o(Factory<Output>::make(outputFile)); |
| 469 | 469 | o->initialize(targetFiles, queryFiles); | ... | ... |
openbr/core/fuse.cpp
| ... | ... | @@ -30,6 +30,9 @@ using namespace cv; |
| 30 | 30 | |
| 31 | 31 | static void normalizeMatrix(Mat &matrix, const Mat &mask, const QString &method) |
| 32 | 32 | { |
| 33 | + if (matrix.rows != mask.rows && matrix.cols != mask.cols) | |
| 34 | + qFatal("Similarity matrix (%d, %d) and mask (%d, %d) size mismatch.", matrix.rows, matrix.cols, mask.rows, mask.cols); | |
| 35 | + | |
| 33 | 36 | if (method == "None") return; |
| 34 | 37 | |
| 35 | 38 | QList<float> vals; vals.reserve(matrix.rows*matrix.cols); | ... | ... |
openbr/openbr_plugin.h
| ... | ... | @@ -483,7 +483,7 @@ struct TemplateList : public QList<Template> |
| 483 | 483 | } |
| 484 | 484 | |
| 485 | 485 | /*! |
| 486 | - * \brief Returns a #br::TemplateList containing templates with one matrix at the specified index \em index. | |
| 486 | + * \brief Returns a list of #br::TemplateList containing templates with a each TemplateList containing the number of matrices specified by \em partitionSizes. | |
| 487 | 487 | */ |
| 488 | 488 | QList<TemplateList> partition(const QList<int> &partitionSizes) const |
| 489 | 489 | { | ... | ... |
openbr/plugins/distance.cpp
| ... | ... | @@ -187,47 +187,61 @@ class OperationDistance : public Distance |
| 187 | 187 | { |
| 188 | 188 | Q_OBJECT |
| 189 | 189 | Q_ENUMS(Operation) |
| 190 | - Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false) | |
| 190 | + Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false) | |
| 191 | 191 | Q_PROPERTY(Operation operation READ get_operation WRITE set_operation RESET reset_operation STORED false) |
| 192 | + Q_PROPERTY(QList<int> split READ get_split WRITE set_split RESET reset_split STORED false) | |
| 193 | + | |
| 194 | + QList<br::Distance*> distances; | |
| 192 | 195 | |
| 193 | 196 | public: |
| 194 | 197 | /*!< */ |
| 195 | 198 | enum Operation {Mean, Sum, Max, Min}; |
| 196 | 199 | |
| 197 | 200 | private: |
| 198 | - BR_PROPERTY(br::Distance*, distance, make("Dist(L2)")) | |
| 201 | + BR_PROPERTY(QString, description, "IdenticalDistance") | |
| 199 | 202 | BR_PROPERTY(Operation, operation, Mean) |
| 203 | + BR_PROPERTY(QList<int>, split, QList<int>()) | |
| 200 | 204 | |
| 201 | 205 | void train(const TemplateList &src) |
| 202 | 206 | { |
| 203 | - distance->train(src); | |
| 207 | + // Default is to train on each matrix | |
| 208 | + if (split.isEmpty()) for (int i=0; i<src.at(0).size(); i++) split.append(1); | |
| 209 | + | |
| 210 | + QList<TemplateList> partitionedSrc = src.partition(split); | |
| 211 | + | |
| 212 | + while (distances.size() < partitionedSrc.size()) | |
| 213 | + distances.append(make(description)); | |
| 214 | + | |
| 215 | + // Train on each of the partitions | |
| 216 | + for (int i=0; i<distances.size(); i++) | |
| 217 | + distances[i]->train(partitionedSrc[i]); | |
| 204 | 218 | } |
| 205 | 219 | |
| 206 | 220 | float compare(const Template &a, const Template &b) const |
| 207 | 221 | { |
| 208 | 222 | if (a.size() != b.size()) qFatal("Comparison size mismatch"); |
| 209 | 223 | |
| 210 | - QList<float> distances; | |
| 211 | - for (int i = 0; i < a.size(); i++) { | |
| 224 | + QList<float> scores; | |
| 225 | + for (int i=0; i<distances.size(); i++) { | |
| 212 | 226 | Template ai = a.file; |
| 213 | - ai.m() = a[i].clone(); | |
| 227 | + ai.m() = a[i]; | |
| 214 | 228 | Template bi = b.file; |
| 215 | - bi.m() = b[i].clone(); | |
| 216 | - distances.append(distance->compare(ai,bi)); | |
| 229 | + bi.m() = b[i]; | |
| 230 | + scores.append(distances[i]->compare(ai,bi)); | |
| 217 | 231 | } |
| 218 | 232 | |
| 219 | 233 | switch (operation) { |
| 220 | 234 | case Mean: |
| 221 | - return std::accumulate(distances.begin(),distances.end(),0.0)/(float)distances.size(); | |
| 235 | + return std::accumulate(scores.begin(),scores.end(),0.0)/(float)scores.size(); | |
| 222 | 236 | break; |
| 223 | 237 | case Sum: |
| 224 | - return std::accumulate(distances.begin(),distances.end(),0.0); | |
| 238 | + return std::accumulate(scores.begin(),scores.end(),0.0); | |
| 225 | 239 | break; |
| 226 | 240 | case Min: |
| 227 | - return *std::min_element(distances.begin(),distances.end()); | |
| 241 | + return *std::min_element(scores.begin(),scores.end()); | |
| 228 | 242 | break; |
| 229 | 243 | case Max: |
| 230 | - return *std::max_element(distances.begin(),distances.end()); | |
| 244 | + return *std::max_element(scores.begin(),scores.end()); | |
| 231 | 245 | break; |
| 232 | 246 | default: |
| 233 | 247 | qFatal("Invalid operation."); |
| ... | ... | @@ -236,12 +250,19 @@ private: |
| 236 | 250 | |
| 237 | 251 | void store(QDataStream &stream) const |
| 238 | 252 | { |
| 239 | - distance->store(stream); | |
| 253 | + stream << distances.size(); | |
| 254 | + foreach (Distance *distance, distances) | |
| 255 | + distance->store(stream); | |
| 240 | 256 | } |
| 241 | 257 | |
| 242 | 258 | void load(QDataStream &stream) |
| 243 | 259 | { |
| 244 | - distance->load(stream); | |
| 260 | + int numDistances; | |
| 261 | + stream >> numDistances; | |
| 262 | + while (distances.size() < numDistances) | |
| 263 | + distances.append(make(description)); | |
| 264 | + foreach (Distance *distance, distances) | |
| 265 | + distance->load(stream); | |
| 245 | 266 | } |
| 246 | 267 | }; |
| 247 | 268 | ... | ... |
openbr/plugins/quality.cpp
| ... | ... | @@ -179,12 +179,7 @@ class MatchProbabilityDistance : public Distance |
| 179 | 179 | } |
| 180 | 180 | } |
| 181 | 181 | |
| 182 | - qDebug() << "Genuines: " << genuineScores.mid(0,5); | |
| 183 | - qDebug() << "Impostors: " << impostorScores.mid(0,5); | |
| 184 | - | |
| 185 | 182 | mp = MP(genuineScores, impostorScores); |
| 186 | - | |
| 187 | - qDebug() << mp(-0.881882,true); | |
| 188 | 183 | } |
| 189 | 184 | |
| 190 | 185 | float compare(const Template &target, const Template &query) const |
| ... | ... | @@ -192,7 +187,6 @@ class MatchProbabilityDistance : public Distance |
| 192 | 187 | const float rawScore = distance->compare(target, query); |
| 193 | 188 | if (rawScore == -std::numeric_limits<float>::max()) return rawScore; |
| 194 | 189 | if (!Globals->scoreNormalization) return -log(rawScore+1); |
| 195 | - qDebug() << mp(rawScore, gaussian) << rawScore; | |
| 196 | 190 | return mp(rawScore, gaussian); |
| 197 | 191 | } |
| 198 | 192 | ... | ... |