Commit c743cdf51dfedf4d3ab12c269879b40f84cb80f3
Merge branch 'master' of https://github.com/biometrics/openbr
Showing
12 changed files
with
175 additions
and
45 deletions
openbr/core/bee.cpp
| @@ -287,8 +287,8 @@ cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries, | @@ -287,8 +287,8 @@ cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries, | ||
| 287 | 287 | ||
| 288 | Mask_t val; | 288 | Mask_t val; |
| 289 | if (fileA == fileB) val = DontCare; | 289 | if (fileA == fileB) val = DontCare; |
| 290 | - else if (labelA == "-1") val = DontCare; | ||
| 291 | - else if (labelB == "-1") val = DontCare; | 290 | + else if (labelA == "-1") val = DontCare; |
| 291 | + else if (labelB == "-1") val = DontCare; | ||
| 292 | else if (partitionA != partition) val = DontCare; | 292 | else if (partitionA != partition) val = DontCare; |
| 293 | else if (partitionB == -1) val = NonMatch; | 293 | else if (partitionB == -1) val = NonMatch; |
| 294 | else if (partitionB != partition) val = DontCare; | 294 | else if (partitionB != partition) val = DontCare; |
openbr/openbr_plugin.cpp
| @@ -386,36 +386,66 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) | @@ -386,36 +386,66 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) | ||
| 386 | newTemplates = newTemplates.reduced(); | 386 | newTemplates = newTemplates.reduced(); |
| 387 | 387 | ||
| 388 | const int crossValidate = gallery.get<int>("crossValidate"); | 388 | const int crossValidate = gallery.get<int>("crossValidate"); |
| 389 | - if (crossValidate > 0) srand(0); | ||
| 390 | - | ||
| 391 | - for (int i=newTemplates.size()-1; i>=0; i--) { | ||
| 392 | - newTemplates[i].file.set("Index", i+templates.size()); | ||
| 393 | - newTemplates[i].file.set("Gallery", gallery.name); | ||
| 394 | - | ||
| 395 | - if (crossValidate > 0) { | ||
| 396 | - if (newTemplates[i].file.getBool("duplicatePartitions")) { | ||
| 397 | - // The duplicatePartitions flag is used to add target images | ||
| 398 | - // crossValidate times to the simmat/mask | ||
| 399 | - // when multiple training sets are being used | ||
| 400 | - | ||
| 401 | - // Set template to the first parition | ||
| 402 | - newTemplates[i].file.set("Partition", QVariant(0)); | ||
| 403 | - | ||
| 404 | - // Insert templates for all the other partitions | ||
| 405 | - for (int j=crossValidate-1; j>=1; j--) { | ||
| 406 | - Template allPartitionTemplate = newTemplates[i]; | ||
| 407 | - allPartitionTemplate.file.set("Partition", j); | ||
| 408 | - newTemplates.insert(i+1, allPartitionTemplate); | 389 | + |
| 390 | + if (gallery.getBool("leaveOneOut")) { | ||
| 391 | + QStringList labels; | ||
| 392 | + for (int i=newTemplates.size()-1; i>=0; i--) { | ||
| 393 | + newTemplates[i].file.set("Index", i+templates.size()); | ||
| 394 | + newTemplates[i].file.set("Gallery", gallery.name); | ||
| 395 | + | ||
| 396 | + QString label = newTemplates.at(i).file.get<QString>("Label"); | ||
| 397 | + // Have we seen this subject before? | ||
| 398 | + if (!labels.contains(label)) { | ||
| 399 | + labels.append(label); | ||
| 400 | + // Get indices belonging to this subject | ||
| 401 | + QList<int> labelIndices = newTemplates.find("Label",label); | ||
| 402 | + for (int j = 0; j < labelIndices.size(); j++) { | ||
| 403 | + // Set subject partitions | ||
| 404 | + newTemplates[labelIndices[j]].file.set("Partition",j%crossValidate); | ||
| 405 | + } | ||
| 406 | + // Extend the gallery for each partition | ||
| 407 | + for (int j=0; j<labelIndices.size(); j++) { | ||
| 408 | + for (int k=0; k<crossValidate; k++) { | ||
| 409 | + Template leaveOneOutTemplate = newTemplates[labelIndices[j]]; | ||
| 410 | + if (k!=leaveOneOutTemplate.file.get<int>("Partition")) { | ||
| 411 | + leaveOneOutTemplate.file.set("Partition", k); | ||
| 412 | + leaveOneOutTemplate.file.set("testOnly", true); | ||
| 413 | + newTemplates.insert(i+1,leaveOneOutTemplate); | ||
| 414 | + } | ||
| 415 | + } | ||
| 409 | } | 416 | } |
| 410 | - } else if (newTemplates[i].file.getBool("allPartitions")) { | ||
| 411 | - // The allPartitions flag is used to add an extended set | ||
| 412 | - // of target images to every partition | ||
| 413 | - newTemplates[i].file.set("Partition", -1); | ||
| 414 | - } else { | 417 | + } |
| 418 | + } | ||
| 419 | + } else { | ||
| 420 | + for (int i=newTemplates.size()-1; i>=0; i--) { | ||
| 421 | + newTemplates[i].file.set("Index", i+templates.size()); | ||
| 422 | + newTemplates[i].file.set("Gallery", gallery.name); | ||
| 423 | + | ||
| 424 | + if (crossValidate > 0) { | ||
| 425 | + if (newTemplates[i].file.getBool("duplicatePartitions")) { | ||
| 426 | + // The duplicatePartitions flag is used to add target images | ||
| 427 | + // crossValidate times to the simmat/mask | ||
| 428 | + // when multiple training sets are being used | ||
| 429 | + | ||
| 430 | + // Set template to the first parition | ||
| 431 | + newTemplates[i].file.set("Partition", QVariant(0)); | ||
| 432 | + | ||
| 433 | + // Insert templates for all the other partitions | ||
| 434 | + for (int j=crossValidate-1; j>0; j--) { | ||
| 435 | + Template duplicatePartitionsTemplate = newTemplates[i]; | ||
| 436 | + duplicatePartitionsTemplate.file.set("Partition", j); | ||
| 437 | + newTemplates.insert(i+1, duplicatePartitionsTemplate); | ||
| 438 | + } | ||
| 439 | + } else if (newTemplates[i].file.getBool("allPartitions")) { | ||
| 440 | + // The allPartitions flag is used to add an extended set | ||
| 441 | + // of target images to every partition | ||
| 442 | + newTemplates[i].file.set("Partition", -1); | ||
| 443 | + } else { | ||
| 415 | // Direct use of "Label" is not general -cao | 444 | // Direct use of "Label" is not general -cao |
| 416 | const QByteArray md5 = QCryptographicHash::hash(newTemplates[i].file.get<QString>("Label").toLatin1(), QCryptographicHash::Md5); | 445 | const QByteArray md5 = QCryptographicHash::hash(newTemplates[i].file.get<QString>("Label").toLatin1(), QCryptographicHash::Md5); |
| 417 | // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow | 446 | // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow |
| 418 | newTemplates[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); | 447 | newTemplates[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); |
| 448 | + } | ||
| 419 | } | 449 | } |
| 420 | } | 450 | } |
| 421 | } | 451 | } |
openbr/openbr_plugin.h
| @@ -503,6 +503,20 @@ struct TemplateList : public QList<Template> | @@ -503,6 +503,20 @@ struct TemplateList : public QList<Template> | ||
| 503 | reduced.merge(t); | 503 | reduced.merge(t); |
| 504 | return TemplateList() << reduced; | 504 | return TemplateList() << reduced; |
| 505 | } | 505 | } |
| 506 | + | ||
| 507 | + /*! | ||
| 508 | + * \brief Find the indices of templates with specified key, value pairs. | ||
| 509 | + */ | ||
| 510 | + template<typename T> | ||
| 511 | + QList<int> find(const QString& key, const T& value) | ||
| 512 | + { | ||
| 513 | + QList<int> indices; | ||
| 514 | + for (int i=0; i<size(); i++) | ||
| 515 | + if (at(i).file.contains(key)) | ||
| 516 | + if (at(i).file.get<T>(key) == value) | ||
| 517 | + indices.append(i); | ||
| 518 | + return indices; | ||
| 519 | + } | ||
| 506 | }; | 520 | }; |
| 507 | 521 | ||
| 508 | /*! | 522 | /*! |
openbr/plugins/format.cpp
| @@ -236,6 +236,7 @@ class DefaultFormat : public Format | @@ -236,6 +236,7 @@ class DefaultFormat : public Format | ||
| 236 | videoWriter.file = file; | 236 | videoWriter.file = file; |
| 237 | videoWriter.write(t); | 237 | videoWriter.write(t); |
| 238 | } else if (t.size() == 1) { | 238 | } else if (t.size() == 1) { |
| 239 | + QtUtils::touchDir(QDir(file.path())); | ||
| 239 | imwrite(file.name.toStdString(), t); | 240 | imwrite(file.name.toStdString(), t); |
| 240 | } | 241 | } |
| 241 | } | 242 | } |
openbr/plugins/gallery.cpp
| @@ -135,7 +135,7 @@ class EmptyGallery : public Gallery | @@ -135,7 +135,7 @@ class EmptyGallery : public Gallery | ||
| 135 | { | 135 | { |
| 136 | Q_OBJECT | 136 | Q_OBJECT |
| 137 | Q_PROPERTY(QString regexp READ get_regexp WRITE set_regexp RESET reset_regexp STORED false) | 137 | Q_PROPERTY(QString regexp READ get_regexp WRITE set_regexp RESET reset_regexp STORED false) |
| 138 | - BR_PROPERTY(QString, regexp, "") | 138 | + BR_PROPERTY(QString, regexp, QString()) |
| 139 | 139 | ||
| 140 | void init() | 140 | void init() |
| 141 | { | 141 | { |
| @@ -184,7 +184,10 @@ class EmptyGallery : public Gallery | @@ -184,7 +184,10 @@ class EmptyGallery : public Gallery | ||
| 184 | // Enrolling a null file is used as an idiom to initialize an algorithm | 184 | // Enrolling a null file is used as an idiom to initialize an algorithm |
| 185 | if (file.name.isEmpty()) return; | 185 | if (file.name.isEmpty()) return; |
| 186 | 186 | ||
| 187 | - const QString destination = file.name + "/" + (file.getBool("preservePath") ? t.file.name : t.file.fileName()); | 187 | + const QString newFormat = file.get<QString>("newFormat",QString()); |
| 188 | + QString destination = file.name + "/" + (file.getBool("preservePath") ? t.file.path()+"/" : QString()); | ||
| 189 | + destination += (newFormat.isEmpty() ? t.file.fileName() : t.file.baseName()+newFormat); | ||
| 190 | + | ||
| 188 | QMutexLocker diskLocker(&diskLock); // Windows prefers to crash when writing to disk in parallel | 191 | QMutexLocker diskLocker(&diskLock); // Windows prefers to crash when writing to disk in parallel |
| 189 | if (t.isNull()) { | 192 | if (t.isNull()) { |
| 190 | QtUtils::copyFile(t.file.resolved(), destination); | 193 | QtUtils::copyFile(t.file.resolved(), destination); |
openbr/plugins/landmarks.cpp
| @@ -164,6 +164,7 @@ class DelaunayTransform : public UntrainableTransform | @@ -164,6 +164,7 @@ class DelaunayTransform : public UntrainableTransform | ||
| 164 | 164 | ||
| 165 | if (points.empty() || rects.empty()) { | 165 | if (points.empty() || rects.empty()) { |
| 166 | dst = src; | 166 | dst = src; |
| 167 | + dst.file.clearRects(); | ||
| 167 | qWarning("Delauney triangulation failed because points or rects are empty."); | 168 | qWarning("Delauney triangulation failed because points or rects are empty."); |
| 168 | return; | 169 | return; |
| 169 | } | 170 | } |
| @@ -292,11 +293,57 @@ class DelaunayTransform : public UntrainableTransform | @@ -292,11 +293,57 @@ class DelaunayTransform : public UntrainableTransform | ||
| 292 | dst.file.setRects(QList<QRectF>() << OpenCVUtils::fromRect(boundingBox)); | 293 | dst.file.setRects(QList<QRectF>() << OpenCVUtils::fromRect(boundingBox)); |
| 293 | } | 294 | } |
| 294 | } | 295 | } |
| 295 | - | ||
| 296 | }; | 296 | }; |
| 297 | 297 | ||
| 298 | BR_REGISTER(Transform, DelaunayTransform) | 298 | BR_REGISTER(Transform, DelaunayTransform) |
| 299 | 299 | ||
| 300 | +/*! | ||
| 301 | + * \ingroup transforms | ||
| 302 | + * \brief Loads a set of fiduciary points from a .dat file | ||
| 303 | + * \author Scott Klum \cite sklum | ||
| 304 | + */ | ||
| 305 | +class LoadLandmarksTransform : public UntrainableTransform | ||
| 306 | +{ | ||
| 307 | + Q_OBJECT | ||
| 308 | + | ||
| 309 | + Q_PROPERTY(QString filePath READ get_filePath WRITE set_filePath RESET reset_filePath STORED false) | ||
| 310 | + BR_PROPERTY(QString, filePath, QString()) | ||
| 311 | + | ||
| 312 | + void project(const Template &src, Template &dst) const | ||
| 313 | + { | ||
| 314 | + dst = src; | ||
| 315 | + | ||
| 316 | + // Assume the fiduciary file has the same basename as src | ||
| 317 | + QString path = filePath + "/" + src.file.baseName() + ".dat"; | ||
| 318 | + | ||
| 319 | + QFile f(path); | ||
| 320 | + if (!f.open(QIODevice::ReadOnly)) qFatal("Unable to open %s for reading.", qPrintable(path)); | ||
| 321 | + | ||
| 322 | + QList<QPointF> landmarks; | ||
| 323 | + while(!f.atEnd()) { | ||
| 324 | + QByteArray line = f.readLine(); | ||
| 325 | + QString pointSet(line); | ||
| 326 | + pointSet = pointSet.simplified(); | ||
| 327 | + if (!pointSet.isEmpty()) { | ||
| 328 | + QStringList points = pointSet.split(" "); | ||
| 329 | + landmarks.append(QPointF(points[0].toFloat(),points[1].toFloat())); | ||
| 330 | + } | ||
| 331 | + } | ||
| 332 | + | ||
| 333 | + if (landmarks.size() < 35) qFatal("Unrecognized landmark set format."); | ||
| 334 | + | ||
| 335 | + dst.file.set("rightEye", landmarks[16]); | ||
| 336 | + dst.file.set("leftEye", landmarks[18]); | ||
| 337 | + | ||
| 338 | + landmarks.removeAt(18); | ||
| 339 | + landmarks.removeAt(16); | ||
| 340 | + | ||
| 341 | + dst.file.appendPoints(landmarks); | ||
| 342 | + } | ||
| 343 | +}; | ||
| 344 | + | ||
| 345 | +BR_REGISTER(Transform, LoadLandmarksTransform) | ||
| 346 | + | ||
| 300 | } // namespace br | 347 | } // namespace br |
| 301 | 348 | ||
| 302 | #include "landmarks.moc" | 349 | #include "landmarks.moc" |
openbr/plugins/output.cpp
| @@ -259,7 +259,8 @@ class rrOutput : public MatrixOutput | @@ -259,7 +259,8 @@ class rrOutput : public MatrixOutput | ||
| 259 | { | 259 | { |
| 260 | if (file.isNull() || targetFiles.isEmpty() || queryFiles.isEmpty()) return; | 260 | if (file.isNull() || targetFiles.isEmpty() || queryFiles.isEmpty()) return; |
| 261 | const int limit = file.get<int>("limit", 20); | 261 | const int limit = file.get<int>("limit", 20); |
| 262 | - const bool byLine = file.get<bool>("byLine", false); | 262 | + const bool byLine = file.getBool("byLine"); |
| 263 | + const bool simple = file.getBool("simple"); | ||
| 263 | const float threshold = file.get<float>("threshold", -std::numeric_limits<float>::max()); | 264 | const float threshold = file.get<float>("threshold", -std::numeric_limits<float>::max()); |
| 264 | 265 | ||
| 265 | QStringList lines; | 266 | QStringList lines; |
| @@ -273,7 +274,8 @@ class rrOutput : public MatrixOutput | @@ -273,7 +274,8 @@ class rrOutput : public MatrixOutput | ||
| 273 | if (pair.first < threshold) break; | 274 | if (pair.first < threshold) break; |
| 274 | File target = targetFiles[pair.second]; | 275 | File target = targetFiles[pair.second]; |
| 275 | target.set("Score", QString::number(pair.first)); | 276 | target.set("Score", QString::number(pair.first)); |
| 276 | - files.append(target.flat()); | 277 | + if (simple) files.append(target.baseName() + " " + QString::number(pair.first)); |
| 278 | + else files.append(target.flat()); | ||
| 277 | } | 279 | } |
| 278 | lines.append(files.join(byLine ? "\n" : ",")); | 280 | lines.append(files.join(byLine ? "\n" : ",")); |
| 279 | } | 281 | } |
openbr/plugins/pp5.cpp
| @@ -299,6 +299,7 @@ BR_REGISTER(Transform, PP5EnrollTransform) | @@ -299,6 +299,7 @@ BR_REGISTER(Transform, PP5EnrollTransform) | ||
| 299 | * \brief Compare templates with PP5 | 299 | * \brief Compare templates with PP5 |
| 300 | * \author Josh Klontz \cite jklontz | 300 | * \author Josh Klontz \cite jklontz |
| 301 | * \author E. Taborsky \cite mmtaborsky | 301 | * \author E. Taborsky \cite mmtaborsky |
| 302 | + * \note PP5 distance is known to be asymmetric | ||
| 302 | */ | 303 | */ |
| 303 | class PP5CompareDistance : public Distance | 304 | class PP5CompareDistance : public Distance |
| 304 | , public PP5Context | 305 | , public PP5Context |
openbr/plugins/stasm4.cpp
| @@ -36,8 +36,8 @@ class StasmInitializer : public Initializer | @@ -36,8 +36,8 @@ class StasmInitializer : public Initializer | ||
| 36 | 36 | ||
| 37 | void initialize() const | 37 | void initialize() const |
| 38 | { | 38 | { |
| 39 | - Globals->abbreviations.insert("RectFromStasmEyes","RectFromPoints([29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],0.125,6.0)+Resize(44,164)"); | ||
| 40 | - Globals->abbreviations.insert("RectFromStasmBrow","RectFromPoints([17, 18, 19, 20, 21, 22, 23, 24],0.15,6)+Resize(28,132)"); | 39 | + Globals->abbreviations.insert("RectFromStasmEyes","RectFromPoints([29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],0.125,6.0)"); |
| 40 | + Globals->abbreviations.insert("RectFromStasmBrow","RectFromPoints([16,17,18,19,20,21,22,23,24,25,26,27],0.15,5)"); | ||
| 41 | Globals->abbreviations.insert("RectFromStasmNose","RectFromPoints([48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58],0.15,1.25)"); | 41 | Globals->abbreviations.insert("RectFromStasmNose","RectFromPoints([48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58],0.15,1.25)"); |
| 42 | Globals->abbreviations.insert("RectFromStasmMouth","RectFromPoints([59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76],0.3,2.5)"); | 42 | Globals->abbreviations.insert("RectFromStasmMouth","RectFromPoints([59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76],0.3,2.5)"); |
| 43 | } | 43 | } |
openbr/plugins/stream.cpp
| @@ -1261,7 +1261,7 @@ public: | @@ -1261,7 +1261,7 @@ public: | ||
| 1261 | { | 1261 | { |
| 1262 | } | 1262 | } |
| 1263 | 1263 | ||
| 1264 | - Q_PROPERTY(br::Transform *transform READ get_transform WRITE set_transform STORED false) | 1264 | + Q_PROPERTY(br::Transform *transform READ get_transform WRITE set_transform RESET reset_transform STORED false) |
| 1265 | Q_PROPERTY(int activeFrames READ get_activeFrames WRITE set_activeFrames RESET reset_activeFrames) | 1265 | Q_PROPERTY(int activeFrames READ get_activeFrames WRITE set_activeFrames RESET reset_activeFrames) |
| 1266 | BR_PROPERTY(br::Transform *, transform, NULL) | 1266 | BR_PROPERTY(br::Transform *, transform, NULL) |
| 1267 | BR_PROPERTY(int, activeFrames, 100) | 1267 | BR_PROPERTY(int, activeFrames, 100) |
openbr/plugins/validate.cpp
| 1 | #include <QFutureSynchronizer> | 1 | #include <QFutureSynchronizer> |
| 2 | #include <QtConcurrentRun> | 2 | #include <QtConcurrentRun> |
| 3 | #include "openbr_internal.h" | 3 | #include "openbr_internal.h" |
| 4 | +#include "openbr/core/common.h" | ||
| 4 | #include <openbr/core/qtutils.h> | 5 | #include <openbr/core/qtutils.h> |
| 5 | 6 | ||
| 6 | namespace br | 7 | namespace br |
| @@ -10,6 +11,7 @@ namespace br | @@ -10,6 +11,7 @@ namespace br | ||
| 10 | * \ingroup transforms | 11 | * \ingroup transforms |
| 11 | * \brief Cross validate a trainable transform. | 12 | * \brief Cross validate a trainable transform. |
| 12 | * \author Josh Klontz \cite jklontz | 13 | * \author Josh Klontz \cite jklontz |
| 14 | + * \author Scott Klum \cite sklum | ||
| 13 | * \note To use an extended gallery, add an allPartitions="true" flag to the gallery sigset for those images that should be compared | 15 | * \note To use an extended gallery, add an allPartitions="true" flag to the gallery sigset for those images that should be compared |
| 14 | * against for all testing partitions. | 16 | * against for all testing partitions. |
| 15 | */ | 17 | */ |
| @@ -17,7 +19,9 @@ class CrossValidateTransform : public MetaTransform | @@ -17,7 +19,9 @@ class CrossValidateTransform : public MetaTransform | ||
| 17 | { | 19 | { |
| 18 | Q_OBJECT | 20 | Q_OBJECT |
| 19 | Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false) | 21 | Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false) |
| 22 | + Q_PROPERTY(bool leaveOneOut READ get_leaveOneOut WRITE set_leaveOneOut RESET reset_leaveOneOut STORED false) | ||
| 20 | BR_PROPERTY(QString, description, "Identity") | 23 | BR_PROPERTY(QString, description, "Identity") |
| 24 | + BR_PROPERTY(bool, leaveOneOut, false) | ||
| 21 | 25 | ||
| 22 | QList<br::Transform*> transforms; | 26 | QList<br::Transform*> transforms; |
| 23 | 27 | ||
| @@ -40,11 +44,44 @@ class CrossValidateTransform : public MetaTransform | @@ -40,11 +44,44 @@ class CrossValidateTransform : public MetaTransform | ||
| 40 | 44 | ||
| 41 | QFutureSynchronizer<void> futures; | 45 | QFutureSynchronizer<void> futures; |
| 42 | for (int i=0; i<numPartitions; i++) { | 46 | for (int i=0; i<numPartitions; i++) { |
| 47 | + QList<int> partitionsBuffer = partitions; | ||
| 43 | TemplateList partitionedData = data; | 48 | TemplateList partitionedData = data; |
| 44 | - for (int j=partitionedData.size()-1; j>=0; j--) | ||
| 45 | - // Remove all templates from partition i | ||
| 46 | - if (partitions[j] == i) | 49 | + int j = partitionedData.size()-1; |
| 50 | + while (j>=0) { | ||
| 51 | + // Remove all templates belonging to partition i | ||
| 52 | + // if leaveOneOut is true, | ||
| 53 | + // and i is greater than the number of images for a particular subject | ||
| 54 | + // even if the partitions are different | ||
| 55 | + if (leaveOneOut) { | ||
| 56 | + const QString label = partitionedData.at(j).file.get<QString>("Label"); | ||
| 57 | + QList<int> subjectIndices = partitionedData.find("Label",label); | ||
| 58 | + QList<int> removed; | ||
| 59 | + // Remove test only data | ||
| 60 | + for (int k=subjectIndices.size()-1; k>=0; k--) | ||
| 61 | + if (partitionedData[subjectIndices[k]].file.getBool("testOnly")) { | ||
| 62 | + removed.append(subjectIndices[k]); | ||
| 63 | + subjectIndices.removeAt(k); | ||
| 64 | + } | ||
| 65 | + // Remove template that was repeated to make the testOnly template | ||
| 66 | + if (subjectIndices.size() > 1 && subjectIndices.size() <= i) { | ||
| 67 | + removed.append(subjectIndices[i%subjectIndices.size()]); | ||
| 68 | + } | ||
| 69 | + else if (partitionsBuffer[j] == i) { | ||
| 70 | + removed.append(j); | ||
| 71 | + } | ||
| 72 | + | ||
| 73 | + if (!removed.empty()) { | ||
| 74 | + typedef QPair<int,int> Pair; | ||
| 75 | + foreach (Pair pair, Common::Sort(removed,true)) { | ||
| 76 | + partitionedData.removeAt(pair.first); partitionsBuffer.removeAt(pair.first); j--; | ||
| 77 | + } | ||
| 78 | + } else { | ||
| 79 | + j--; | ||
| 80 | + } | ||
| 81 | + } else if (partitions[j] == i) { | ||
| 47 | partitionedData.removeAt(j); | 82 | partitionedData.removeAt(j); |
| 83 | + } else j--; | ||
| 84 | + } | ||
| 48 | // Train on the remaining templates | 85 | // Train on the remaining templates |
| 49 | futures.addFuture(QtConcurrent::run(transforms[i], &Transform::train, partitionedData)); | 86 | futures.addFuture(QtConcurrent::run(transforms[i], &Transform::train, partitionedData)); |
| 50 | } | 87 | } |
| @@ -53,12 +90,7 @@ class CrossValidateTransform : public MetaTransform | @@ -53,12 +90,7 @@ class CrossValidateTransform : public MetaTransform | ||
| 53 | 90 | ||
| 54 | void project(const Template &src, Template &dst) const | 91 | void project(const Template &src, Template &dst) const |
| 55 | { | 92 | { |
| 56 | - // If the src partition is greater than the number of training partitions, | ||
| 57 | - // assume that projection should be done using the same training data for all partitions. | ||
| 58 | - int partition = src.file.get<int>("Partition", 0); | ||
| 59 | - if (partition >= transforms.size()) partition = 0; | ||
| 60 | - | ||
| 61 | - transforms[partition]->project(src, dst); | 93 | + transforms[src.file.get<int>("Partition", 0)]->project(src, dst); |
| 62 | } | 94 | } |
| 63 | 95 | ||
| 64 | void store(QDataStream &stream) const | 96 | void store(QDataStream &stream) const |