Commit f605bccf7629fe6feed17aed2752d0951e82a70a
1 parent
4de57762
Merge changes
Showing
6 changed files
with
60 additions
and
12 deletions
openbr/core/core.cpp
| @@ -41,6 +41,7 @@ struct AlgorithmCore | @@ -41,6 +41,7 @@ struct AlgorithmCore | ||
| 41 | 41 | ||
| 42 | void train(const File &input, const QString &model) | 42 | void train(const File &input, const QString &model) |
| 43 | { | 43 | { |
| 44 | + qDebug() << input; | ||
| 44 | TemplateList data(TemplateList::fromGallery(input)); | 45 | TemplateList data(TemplateList::fromGallery(input)); |
| 45 | 46 | ||
| 46 | if (transform.isNull()) qFatal("Null transform."); | 47 | if (transform.isNull()) qFatal("Null transform."); |
| @@ -393,6 +394,7 @@ void br::Convert(const File &fileType, const File &inputFile, const File &output | @@ -393,6 +394,7 @@ void br::Convert(const File &fileType, const File &inputFile, const File &output | ||
| 393 | QSharedPointer<Output> o(Factory<Output>::make(outputFile)); | 394 | QSharedPointer<Output> o(Factory<Output>::make(outputFile)); |
| 394 | o->initialize(targetFiles, queryFiles); | 395 | o->initialize(targetFiles, queryFiles); |
| 395 | 396 | ||
| 397 | + qDebug() << m.rows << m.cols << targetFiles.size() << queryFiles.size(); | ||
| 396 | for (int i=0; i<queryFiles.size(); i++) | 398 | for (int i=0; i<queryFiles.size(); i++) |
| 397 | for (int j=0; j<targetFiles.size(); j++) | 399 | for (int j=0; j<targetFiles.size(); j++) |
| 398 | o->setRelative(m.at<float>(i,j), i, j); | 400 | o->setRelative(m.at<float>(i,j), i, j); |
openbr/openbr_plugin.cpp
| @@ -386,16 +386,27 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) | @@ -386,16 +386,27 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) | ||
| 386 | newTemplates = newTemplates.reduced(); | 386 | newTemplates = newTemplates.reduced(); |
| 387 | 387 | ||
| 388 | const int crossValidate = gallery.get<int>("crossValidate"); | 388 | const int crossValidate = gallery.get<int>("crossValidate"); |
| 389 | - if (crossValidate > 0) srand(0); | ||
| 390 | 389 | ||
| 391 | - if (gallery.getBool("leaveOneOut", 0)) { | ||
| 392 | - QStringList subjects = File::get(newTemplates.files(),"Subject","-1"); | ||
| 393 | - subjects. | ||
| 394 | - // Get QStringLists of unique subjects | ||
| 395 | - | ||
| 396 | - // For each list of unique subjects, decide randomly which to test on | ||
| 397 | - for (int i = 0; i < subjects.size(); i++) { | ||
| 398 | - if (subjects | 390 | + if (gallery.getBool("leaveOneOut")) { |
| 391 | + QStringList subjects; | ||
| 392 | + for (int i = 0; i < newTemplates.size(); i++) { | ||
| 393 | + QString subject = newTemplates.at(i).file.get<QString>("Subject"); | ||
| 394 | + // Have we seen this subject before? | ||
| 395 | + if (subjects.contains(subject)) { | ||
| 396 | + subjects.append(subject); | ||
| 397 | + // Get indices belonging to this subject | ||
| 398 | + QList<int> subjectIndices = newTemplates.find("Subject",subject); | ||
| 399 | + for (int j = 0; j < subjectIndices.size(); j++) { | ||
| 400 | + // Set subject partitions | ||
| 401 | + newTemplates[subjectIndices[j]].file.set("Partition",j); | ||
| 402 | + } | ||
| 403 | + // Generate more templates if necessary | ||
| 404 | + for (int j=0; j<crossValidate-subjectIndices.size(); j++) { | ||
| 405 | + Template leaveOneOutTemplate = newTemplates[subjectIndices[j%subjectIndices.size()]]; | ||
| 406 | + leaveOneOutTemplate.file.set("Partition", j+subjectIndices.size()); | ||
| 407 | + newTemplates.append(leaveOneOutTemplate); | ||
| 408 | + } | ||
| 409 | + } | ||
| 399 | } | 410 | } |
| 400 | } else { | 411 | } else { |
| 401 | for (int i=newTemplates.size()-1; i>=0; i--) { | 412 | for (int i=newTemplates.size()-1; i>=0; i--) { |
openbr/openbr_plugin.h
| @@ -509,6 +509,20 @@ struct TemplateList : public QList<Template> | @@ -509,6 +509,20 @@ struct TemplateList : public QList<Template> | ||
| 509 | reduced.merge(t); | 509 | reduced.merge(t); |
| 510 | return TemplateList() << reduced; | 510 | return TemplateList() << reduced; |
| 511 | } | 511 | } |
| 512 | + | ||
| 513 | + /*! | ||
| 514 | + * \brief Find the indices of templates with specified key, value pairs. | ||
| 515 | + */ | ||
| 516 | + template<typename T> | ||
| 517 | + QList<int> find(const QString& key, const T& value) | ||
| 518 | + { | ||
| 519 | + QList<int> indices; | ||
| 520 | + for (int i=0; i<size(); i++) | ||
| 521 | + if (at(i).file.contains(key)) | ||
| 522 | + if (at(i).file.get<T>(key) == value) | ||
| 523 | + indices.append(i); | ||
| 524 | + return indices; | ||
| 525 | + } | ||
| 512 | }; | 526 | }; |
| 513 | 527 | ||
| 514 | /*! | 528 | /*! |
openbr/plugins/output.cpp
| @@ -365,9 +365,12 @@ class evalOutput : public MatrixOutput | @@ -365,9 +365,12 @@ class evalOutput : public MatrixOutput | ||
| 365 | 365 | ||
| 366 | ~evalOutput() | 366 | ~evalOutput() |
| 367 | { | 367 | { |
| 368 | + qDebug() << "here"; | ||
| 369 | + | ||
| 368 | if (data.data) { | 370 | if (data.data) { |
| 369 | const QString csv = QString(file.name).replace(".eval", ".csv"); | 371 | const QString csv = QString(file.name).replace(".eval", ".csv"); |
| 370 | if ((Globals->crossValidate == 0) || (!crossValidate)) { | 372 | if ((Globals->crossValidate == 0) || (!crossValidate)) { |
| 373 | + qDebug() << "here"; | ||
| 371 | Evaluate(data, BEE::makeMask(targetFiles, queryFiles), csv); | 374 | Evaluate(data, BEE::makeMask(targetFiles, queryFiles), csv); |
| 372 | } else { | 375 | } else { |
| 373 | QFutureSynchronizer<float> futures; | 376 | QFutureSynchronizer<float> futures; |
openbr/plugins/pp5.cpp
| @@ -299,6 +299,7 @@ BR_REGISTER(Transform, PP5EnrollTransform) | @@ -299,6 +299,7 @@ BR_REGISTER(Transform, PP5EnrollTransform) | ||
| 299 | * \brief Compare templates with PP5 | 299 | * \brief Compare templates with PP5 |
| 300 | * \author Josh Klontz \cite jklontz | 300 | * \author Josh Klontz \cite jklontz |
| 301 | * \author E. Taborsky \cite mmtaborsky | 301 | * \author E. Taborsky \cite mmtaborsky |
| 302 | + * \note PP5 distance is known to be asymmetric | ||
| 302 | */ | 303 | */ |
| 303 | class PP5CompareDistance : public Distance | 304 | class PP5CompareDistance : public Distance |
| 304 | , public PP5Context | 305 | , public PP5Context |
openbr/plugins/validate.cpp
| 1 | #include <QFutureSynchronizer> | 1 | #include <QFutureSynchronizer> |
| 2 | #include <QtConcurrentRun> | 2 | #include <QtConcurrentRun> |
| 3 | #include "openbr_internal.h" | 3 | #include "openbr_internal.h" |
| 4 | +#include "openbr/core/common.h" | ||
| 4 | #include <openbr/core/qtutils.h> | 5 | #include <openbr/core/qtutils.h> |
| 5 | 6 | ||
| 6 | namespace br | 7 | namespace br |
| @@ -17,7 +18,9 @@ class CrossValidateTransform : public MetaTransform | @@ -17,7 +18,9 @@ class CrossValidateTransform : public MetaTransform | ||
| 17 | { | 18 | { |
| 18 | Q_OBJECT | 19 | Q_OBJECT |
| 19 | Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false) | 20 | Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false) |
| 21 | + Q_PROPERTY(bool leaveOneOut READ get_leaveOneOut WRITE set_leaveOneOut RESET reset_leaveOneOut STORED false) | ||
| 20 | BR_PROPERTY(QString, description, "Identity") | 22 | BR_PROPERTY(QString, description, "Identity") |
| 23 | + BR_PROPERTY(bool, leaveOneOut, false) | ||
| 21 | 24 | ||
| 22 | QList<br::Transform*> transforms; | 25 | QList<br::Transform*> transforms; |
| 23 | 26 | ||
| @@ -41,11 +44,25 @@ class CrossValidateTransform : public MetaTransform | @@ -41,11 +44,25 @@ class CrossValidateTransform : public MetaTransform | ||
| 41 | QFutureSynchronizer<void> futures; | 44 | QFutureSynchronizer<void> futures; |
| 42 | for (int i=0; i<numPartitions; i++) { | 45 | for (int i=0; i<numPartitions; i++) { |
| 43 | TemplateList partitionedData = data; | 46 | TemplateList partitionedData = data; |
| 47 | + QList<int> removed; | ||
| 44 | for (int j=partitionedData.size()-1; j>=0; j--) | 48 | for (int j=partitionedData.size()-1; j>=0; j--) |
| 45 | - // Remove all templates from partition i | ||
| 46 | - if (partitions[j] == i) | ||
| 47 | - partitionedData.removeAt(j); | 49 | + // Remove all templates belonging to partition i |
| 50 | + // if leaveOneOut is true, | ||
| 51 | + // and i is greater than the number of images for a particular subject | ||
| 52 | + // even if the partitions are different | ||
| 53 | + if (leaveOneOut) { | ||
| 54 | + QList<int> subjectIndices = partitionedData.find("Subject",partitionedData.at(j).file.get<QString>("Subject")); | ||
| 55 | + qDebug() << i << subjectIndices.size(); | ||
| 56 | + if (i > subjectIndices.size()) { | ||
| 57 | + qDebug() << i%subjectIndices.size(); | ||
| 58 | + removed.append(subjectIndices[i%subjectIndices.size()]); | ||
| 59 | + } | ||
| 60 | + } else if (partitions[j] == i) | ||
| 61 | + removed.append(j); | ||
| 62 | + typedef QPair<int,int> Pair; | ||
| 63 | + foreach (const Pair &pair, Common::Sort(removed,true)) partitionedData.removeAt(pair.first); | ||
| 48 | // Train on the remaining templates | 64 | // Train on the remaining templates |
| 65 | + foreach (const Template &t, partitionedData) qDebug() << "Remaining data for partition " << i << ": " << t.file.baseName(); | ||
| 49 | futures.addFuture(QtConcurrent::run(transforms[i], &Transform::train, partitionedData)); | 66 | futures.addFuture(QtConcurrent::run(transforms[i], &Transform::train, partitionedData)); |
| 50 | } | 67 | } |
| 51 | futures.waitForFinished(); | 68 | futures.waitForFinished(); |