Commit 0a275cfb5035505c666827644e46effbbc096a8d

Authored by Scott Klum
1 parent b174c493

OperationDistance::train finished

openbr/core/core.cpp
@@ -463,7 +463,7 @@ void br::Convert(const File &fileType, const File &inputFile, const File &output @@ -463,7 +463,7 @@ void br::Convert(const File &fileType, const File &inputFile, const File &output
463 463
464 if ((targetFiles.size() != m.cols || queryFiles.size() != m.rows) 464 if ((targetFiles.size() != m.cols || queryFiles.size() != m.rows)
465 && (m.cols != 1 || targetFiles.size() != m.rows || queryFiles.size() != m.rows)) 465 && (m.cols != 1 || targetFiles.size() != m.rows || queryFiles.size() != m.rows))
466 - qFatal("Similarity matrix and file size mismatch."); 466 + qFatal("Similarity matrix (%d, %d) and header (%d, %d) size mismatch.", m.rows, m.cols, queryFiles.size(), targetFiles.size());
467 467
468 QSharedPointer<Output> o(Factory<Output>::make(outputFile)); 468 QSharedPointer<Output> o(Factory<Output>::make(outputFile));
469 o->initialize(targetFiles, queryFiles); 469 o->initialize(targetFiles, queryFiles);
openbr/core/fuse.cpp
@@ -30,6 +30,9 @@ using namespace cv; @@ -30,6 +30,9 @@ using namespace cv;
30 30
31 static void normalizeMatrix(Mat &matrix, const Mat &mask, const QString &method) 31 static void normalizeMatrix(Mat &matrix, const Mat &mask, const QString &method)
32 { 32 {
  33 + if (matrix.rows != mask.rows && matrix.cols != mask.cols)
  34 + qFatal("Similarity matrix (%d, %d) and mask (%d, %d) size mismatch.", matrix.rows, matrix.cols, mask.rows, mask.cols);
  35 +
33 if (method == "None") return; 36 if (method == "None") return;
34 37
35 QList<float> vals; vals.reserve(matrix.rows*matrix.cols); 38 QList<float> vals; vals.reserve(matrix.rows*matrix.cols);
openbr/openbr_plugin.h
@@ -483,7 +483,7 @@ struct TemplateList : public QList&lt;Template&gt; @@ -483,7 +483,7 @@ struct TemplateList : public QList&lt;Template&gt;
483 } 483 }
484 484
485 /*! 485 /*!
486 - * \brief Returns a #br::TemplateList containing templates with one matrix at the specified index \em index. 486 + * \brief Returns a list of #br::TemplateList containing templates with a each TemplateList containing the number of matrices specified by \em partitionSizes.
487 */ 487 */
488 QList<TemplateList> partition(const QList<int> &partitionSizes) const 488 QList<TemplateList> partition(const QList<int> &partitionSizes) const
489 { 489 {
openbr/plugins/distance.cpp
@@ -187,47 +187,61 @@ class OperationDistance : public Distance @@ -187,47 +187,61 @@ class OperationDistance : public Distance
187 { 187 {
188 Q_OBJECT 188 Q_OBJECT
189 Q_ENUMS(Operation) 189 Q_ENUMS(Operation)
190 - Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false) 190 + Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false)
191 Q_PROPERTY(Operation operation READ get_operation WRITE set_operation RESET reset_operation STORED false) 191 Q_PROPERTY(Operation operation READ get_operation WRITE set_operation RESET reset_operation STORED false)
  192 + Q_PROPERTY(QList<int> split READ get_split WRITE set_split RESET reset_split STORED false)
  193 +
  194 + QList<br::Distance*> distances;
192 195
193 public: 196 public:
194 /*!< */ 197 /*!< */
195 enum Operation {Mean, Sum, Max, Min}; 198 enum Operation {Mean, Sum, Max, Min};
196 199
197 private: 200 private:
198 - BR_PROPERTY(br::Distance*, distance, make("Dist(L2)")) 201 + BR_PROPERTY(QString, description, "IdenticalDistance")
199 BR_PROPERTY(Operation, operation, Mean) 202 BR_PROPERTY(Operation, operation, Mean)
  203 + BR_PROPERTY(QList<int>, split, QList<int>())
200 204
201 void train(const TemplateList &src) 205 void train(const TemplateList &src)
202 { 206 {
203 - distance->train(src); 207 + // Default is to train on each matrix
  208 + if (split.isEmpty()) for (int i=0; i<src.at(0).size(); i++) split.append(1);
  209 +
  210 + QList<TemplateList> partitionedSrc = src.partition(split);
  211 +
  212 + while (distances.size() < partitionedSrc.size())
  213 + distances.append(make(description));
  214 +
  215 + // Train on each of the partitions
  216 + for (int i=0; i<distances.size(); i++)
  217 + distances[i]->train(partitionedSrc[i]);
204 } 218 }
205 219
206 float compare(const Template &a, const Template &b) const 220 float compare(const Template &a, const Template &b) const
207 { 221 {
208 if (a.size() != b.size()) qFatal("Comparison size mismatch"); 222 if (a.size() != b.size()) qFatal("Comparison size mismatch");
209 223
210 - QList<float> distances;  
211 - for (int i = 0; i < a.size(); i++) { 224 + QList<float> scores;
  225 + for (int i=0; i<distances.size(); i++) {
212 Template ai = a.file; 226 Template ai = a.file;
213 - ai.m() = a[i].clone(); 227 + ai.m() = a[i];
214 Template bi = b.file; 228 Template bi = b.file;
215 - bi.m() = b[i].clone();  
216 - distances.append(distance->compare(ai,bi)); 229 + bi.m() = b[i];
  230 + scores.append(distances[i]->compare(ai,bi));
217 } 231 }
218 232
219 switch (operation) { 233 switch (operation) {
220 case Mean: 234 case Mean:
221 - return std::accumulate(distances.begin(),distances.end(),0.0)/(float)distances.size(); 235 + return std::accumulate(scores.begin(),scores.end(),0.0)/(float)scores.size();
222 break; 236 break;
223 case Sum: 237 case Sum:
224 - return std::accumulate(distances.begin(),distances.end(),0.0); 238 + return std::accumulate(scores.begin(),scores.end(),0.0);
225 break; 239 break;
226 case Min: 240 case Min:
227 - return *std::min_element(distances.begin(),distances.end()); 241 + return *std::min_element(scores.begin(),scores.end());
228 break; 242 break;
229 case Max: 243 case Max:
230 - return *std::max_element(distances.begin(),distances.end()); 244 + return *std::max_element(scores.begin(),scores.end());
231 break; 245 break;
232 default: 246 default:
233 qFatal("Invalid operation."); 247 qFatal("Invalid operation.");
@@ -236,12 +250,19 @@ private: @@ -236,12 +250,19 @@ private:
236 250
237 void store(QDataStream &stream) const 251 void store(QDataStream &stream) const
238 { 252 {
239 - distance->store(stream); 253 + stream << distances.size();
  254 + foreach (Distance *distance, distances)
  255 + distance->store(stream);
240 } 256 }
241 257
242 void load(QDataStream &stream) 258 void load(QDataStream &stream)
243 { 259 {
244 - distance->load(stream); 260 + int numDistances;
  261 + stream >> numDistances;
  262 + while (distances.size() < numDistances)
  263 + distances.append(make(description));
  264 + foreach (Distance *distance, distances)
  265 + distance->load(stream);
245 } 266 }
246 }; 267 };
247 268
openbr/plugins/quality.cpp
@@ -179,12 +179,7 @@ class MatchProbabilityDistance : public Distance @@ -179,12 +179,7 @@ class MatchProbabilityDistance : public Distance
179 } 179 }
180 } 180 }
181 181
182 - qDebug() << "Genuines: " << genuineScores.mid(0,5);  
183 - qDebug() << "Impostors: " << impostorScores.mid(0,5);  
184 -  
185 mp = MP(genuineScores, impostorScores); 182 mp = MP(genuineScores, impostorScores);
186 -  
187 - qDebug() << mp(-0.881882,true);  
188 } 183 }
189 184
190 float compare(const Template &target, const Template &query) const 185 float compare(const Template &target, const Template &query) const
@@ -192,7 +187,6 @@ class MatchProbabilityDistance : public Distance @@ -192,7 +187,6 @@ class MatchProbabilityDistance : public Distance
192 const float rawScore = distance->compare(target, query); 187 const float rawScore = distance->compare(target, query);
193 if (rawScore == -std::numeric_limits<float>::max()) return rawScore; 188 if (rawScore == -std::numeric_limits<float>::max()) return rawScore;
194 if (!Globals->scoreNormalization) return -log(rawScore+1); 189 if (!Globals->scoreNormalization) return -log(rawScore+1);
195 - qDebug() << mp(rawScore, gaussian) << rawScore;  
196 return mp(rawScore, gaussian); 190 return mp(rawScore, gaussian);
197 } 191 }
198 192