Commit 0a275cfb5035505c666827644e46effbbc096a8d

Authored by Scott Klum
1 parent b174c493

OperationDistance::train finished

openbr/core/core.cpp
... ... @@ -463,7 +463,7 @@ void br::Convert(const File &fileType, const File &inputFile, const File &output
463 463  
464 464 if ((targetFiles.size() != m.cols || queryFiles.size() != m.rows)
465 465 && (m.cols != 1 || targetFiles.size() != m.rows || queryFiles.size() != m.rows))
466   - qFatal("Similarity matrix and file size mismatch.");
  466 + qFatal("Similarity matrix (%d, %d) and header (%d, %d) size mismatch.", m.rows, m.cols, queryFiles.size(), targetFiles.size());
467 467  
468 468 QSharedPointer<Output> o(Factory<Output>::make(outputFile));
469 469 o->initialize(targetFiles, queryFiles);
... ...
openbr/core/fuse.cpp
... ... @@ -30,6 +30,9 @@ using namespace cv;
30 30  
31 31 static void normalizeMatrix(Mat &matrix, const Mat &mask, const QString &method)
32 32 {
  33 + if (matrix.rows != mask.rows && matrix.cols != mask.cols)
  34 + qFatal("Similarity matrix (%d, %d) and mask (%d, %d) size mismatch.", matrix.rows, matrix.cols, mask.rows, mask.cols);
  35 +
33 36 if (method == "None") return;
34 37  
35 38 QList<float> vals; vals.reserve(matrix.rows*matrix.cols);
... ...
openbr/openbr_plugin.h
... ... @@ -483,7 +483,7 @@ struct TemplateList : public QList&lt;Template&gt;
483 483 }
484 484  
485 485 /*!
486   - * \brief Returns a #br::TemplateList containing templates with one matrix at the specified index \em index.
  486 + * \brief Returns a list of #br::TemplateList containing templates with a each TemplateList containing the number of matrices specified by \em partitionSizes.
487 487 */
488 488 QList<TemplateList> partition(const QList<int> &partitionSizes) const
489 489 {
... ...
openbr/plugins/distance.cpp
... ... @@ -187,47 +187,61 @@ class OperationDistance : public Distance
187 187 {
188 188 Q_OBJECT
189 189 Q_ENUMS(Operation)
190   - Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false)
  190 + Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false)
191 191 Q_PROPERTY(Operation operation READ get_operation WRITE set_operation RESET reset_operation STORED false)
  192 + Q_PROPERTY(QList<int> split READ get_split WRITE set_split RESET reset_split STORED false)
  193 +
  194 + QList<br::Distance*> distances;
192 195  
193 196 public:
194 197 /*!< */
195 198 enum Operation {Mean, Sum, Max, Min};
196 199  
197 200 private:
198   - BR_PROPERTY(br::Distance*, distance, make("Dist(L2)"))
  201 + BR_PROPERTY(QString, description, "IdenticalDistance")
199 202 BR_PROPERTY(Operation, operation, Mean)
  203 + BR_PROPERTY(QList<int>, split, QList<int>())
200 204  
201 205 void train(const TemplateList &src)
202 206 {
203   - distance->train(src);
  207 + // Default is to train on each matrix
  208 + if (split.isEmpty()) for (int i=0; i<src.at(0).size(); i++) split.append(1);
  209 +
  210 + QList<TemplateList> partitionedSrc = src.partition(split);
  211 +
  212 + while (distances.size() < partitionedSrc.size())
  213 + distances.append(make(description));
  214 +
  215 + // Train on each of the partitions
  216 + for (int i=0; i<distances.size(); i++)
  217 + distances[i]->train(partitionedSrc[i]);
204 218 }
205 219  
206 220 float compare(const Template &a, const Template &b) const
207 221 {
208 222 if (a.size() != b.size()) qFatal("Comparison size mismatch");
209 223  
210   - QList<float> distances;
211   - for (int i = 0; i < a.size(); i++) {
  224 + QList<float> scores;
  225 + for (int i=0; i<distances.size(); i++) {
212 226 Template ai = a.file;
213   - ai.m() = a[i].clone();
  227 + ai.m() = a[i];
214 228 Template bi = b.file;
215   - bi.m() = b[i].clone();
216   - distances.append(distance->compare(ai,bi));
  229 + bi.m() = b[i];
  230 + scores.append(distances[i]->compare(ai,bi));
217 231 }
218 232  
219 233 switch (operation) {
220 234 case Mean:
221   - return std::accumulate(distances.begin(),distances.end(),0.0)/(float)distances.size();
  235 + return std::accumulate(scores.begin(),scores.end(),0.0)/(float)scores.size();
222 236 break;
223 237 case Sum:
224   - return std::accumulate(distances.begin(),distances.end(),0.0);
  238 + return std::accumulate(scores.begin(),scores.end(),0.0);
225 239 break;
226 240 case Min:
227   - return *std::min_element(distances.begin(),distances.end());
  241 + return *std::min_element(scores.begin(),scores.end());
228 242 break;
229 243 case Max:
230   - return *std::max_element(distances.begin(),distances.end());
  244 + return *std::max_element(scores.begin(),scores.end());
231 245 break;
232 246 default:
233 247 qFatal("Invalid operation.");
... ... @@ -236,12 +250,19 @@ private:
236 250  
237 251 void store(QDataStream &stream) const
238 252 {
239   - distance->store(stream);
  253 + stream << distances.size();
  254 + foreach (Distance *distance, distances)
  255 + distance->store(stream);
240 256 }
241 257  
242 258 void load(QDataStream &stream)
243 259 {
244   - distance->load(stream);
  260 + int numDistances;
  261 + stream >> numDistances;
  262 + while (distances.size() < numDistances)
  263 + distances.append(make(description));
  264 + foreach (Distance *distance, distances)
  265 + distance->load(stream);
245 266 }
246 267 };
247 268  
... ...
openbr/plugins/quality.cpp
... ... @@ -179,12 +179,7 @@ class MatchProbabilityDistance : public Distance
179 179 }
180 180 }
181 181  
182   - qDebug() << "Genuines: " << genuineScores.mid(0,5);
183   - qDebug() << "Impostors: " << impostorScores.mid(0,5);
184   -
185 182 mp = MP(genuineScores, impostorScores);
186   -
187   - qDebug() << mp(-0.881882,true);
188 183 }
189 184  
190 185 float compare(const Template &target, const Template &query) const
... ... @@ -192,7 +187,6 @@ class MatchProbabilityDistance : public Distance
192 187 const float rawScore = distance->compare(target, query);
193 188 if (rawScore == -std::numeric_limits<float>::max()) return rawScore;
194 189 if (!Globals->scoreNormalization) return -log(rawScore+1);
195   - qDebug() << mp(rawScore, gaussian) << rawScore;
196 190 return mp(rawScore, gaussian);
197 191 }
198 192  
... ...