#include #include #include namespace br { /*! * \ingroup transforms * \brief Cross validate a trainable transform. * \author Josh Klontz \cite jklontz */ class CrossValidateTransform : public MetaTransform { Q_OBJECT Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false) BR_PROPERTY(QString, description, "Identity") QList transforms; void train(const TemplateList &data) { int numPartitions = 0; QList partitions; partitions.reserve(data.size()); foreach (const File &file, data.files()) { partitions.append(file.get("Cross_Validation_Partition", 0)); numPartitions = std::max(numPartitions, partitions.last()+1); } while (transforms.size() < numPartitions) transforms.append(make(description)); if (numPartitions < 2) { transforms.first()->train(data); return; } QFutureSynchronizer futures; for (int i=0; i=0; j--) if (partitions[j] == i) partitionedData.removeAt(j); if (Globals->parallelism) futures.addFuture(QtConcurrent::run(transforms[i], &Transform::train, partitionedData)); else transforms[i]->train(partitionedData); } futures.waitForFinished(); } void project(const Template &src, Template &dst) const { transforms[src.file.get("Cross_Validation_Partition", 0)]->project(src, dst); } void store(QDataStream &stream) const { stream << transforms.size(); foreach (Transform *transform, transforms) transform->store(stream); } void load(QDataStream &stream) { int numTransforms; stream >> numTransforms; while (transforms.size() < numTransforms) transforms.append(make(description)); foreach (Transform *transform, transforms) transform->load(stream); } }; BR_REGISTER(Transform, CrossValidateTransform) /*! * \ingroup distances * \brief Cross validate a distance metric. * \author Josh Klontz \cite jklontz */ class CrossValidateDistance : public Distance { Q_OBJECT float compare(const Template &a, const Template &b) const { const int partitionA = a.file.get("Cross_Validation_Partition", 0); const int partitionB = b.file.get("Cross_Validation_Partition", 0); return (partitionA != partitionB) ? -std::numeric_limits::max() : 0; } }; BR_REGISTER(Distance, CrossValidateDistance) /*! * \ingroup distances * \brief Checks target metadata against filters. * \author Josh Klontz \cite jklontz */ class FilterDistance : public Distance { Q_OBJECT float compare(const Template &a, const Template &b) const { (void) b; // Query template isn't checked foreach (const QString &key, Globals->filters.keys()) { bool keep = false; const QString metadata = a.file.get(key, ""); if (metadata.isEmpty()) continue; foreach (const QString &value, Globals->filters[key]) { if (metadata == value) { keep = true; break; } } if (!keep) return -std::numeric_limits::max(); } return 0; } }; BR_REGISTER(Distance, FilterDistance) } // namespace br #include "validate.moc"