/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * Copyright 2012 The MITRE Corporation * * * * Licensed under the Apache License, Version 2.0 (the "License"); * * you may not use this file except in compliance with the License. * * You may obtain a copy of the License at * * * * http://www.apache.org/licenses/LICENSE-2.0 * * * * Unless required by applicable law or agreed to in writing, software * * distributed under the License is distributed on an "AS IS" BASIS, * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * * See the License for the specific language governing permissions and * * limitations under the License. * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ #include #include "core/common.h" #include "core/qtutils.h" using namespace br; /**** ALGORITHM_CORE ****/ struct AlgorithmCore { QSharedPointer transform; QSharedPointer distance; AlgorithmCore(const QString &name) { this->name = name; init(name); } bool isClassifier() const { return distance.isNull(); } void train(const File &input, const QString &model) { TemplateList data(TemplateList::fromGallery(input)); if (transform.isNull()) qFatal("Null transform."); qDebug("%d training files", data.size()); QTime time; time.start(); qDebug("Training Enrollment"); transform->train(data); if (!distance.isNull()) { qDebug("Projecting Enrollment"); data >> *transform; qDebug("Training Comparison"); distance->train(data); } if (!model.isEmpty()) { qDebug("Storing %s", qPrintable(QFileInfo(model).fileName())); store(model); } qDebug("Training Time (sec): %d", time.elapsed()/1000); } void store(const QString &model) const { // Create stream QByteArray data; QDataStream out(&data, QFile::WriteOnly); // Serialize algorithm to stream out << name; transform->store(out); const bool hasComparer = !distance.isNull(); out << hasComparer; if (hasComparer) distance->store(out); out << Globals->classes; // Compress and save to file QtUtils::writeFile(model, data, -1); } void load(const QString &model) { // Load from file and decompress QByteArray data; QtUtils::readFile(model, data, true); // Create stream QDataStream in(&data, QFile::ReadOnly); // Load algorithm in >> name; init(Globals->abbreviations.contains(name) ? Globals->abbreviations[name] : name); transform->load(in); bool hasDistance; in >> hasDistance; if (hasDistance) distance->load(in); in >> Globals->classes; } File getMemoryGallery(const File &file) const { return name + file.baseName() + file.hash() + ".mem"; } FileList enroll(File input, File gallery = File()) { FileList fileList; if (gallery.isNull()) gallery = getMemoryGallery(input); QScopedPointer g(Gallery::make(gallery)); if (g.isNull()) return FileList(); if (gallery.contains("read") || gallery.contains("cache")) fileList = g->files(); if (!fileList.isEmpty() && gallery.contains("cache")) return fileList; const TemplateList i(TemplateList::fromGallery(input)); if (i.isEmpty()) return fileList; // Nothing to enroll if (transform.isNull()) qFatal("Null transform."); const int blocks = Globals->blocks(i.size()); Globals->currentStep = 0; Globals->totalSteps = i.size(); Globals->startTime.start(); const bool noDuplicates = gallery.contains("noDuplicates"); QStringList fileNames = noDuplicates ? fileList.names() : QStringList(); const int subBlockSize = 4*std::max(1, Globals->parallelism); const int numSubBlocks = ceil(1.0*Globals->blockSize/subBlockSize); int totalCount = 0, failureCount = 0; double totalBytes = 0; for (int block=0; blockblockSize + subBlock*subBlockSize, subBlockSize); if (data.isEmpty()) break; if (noDuplicates) for (int i=data.size()-1; i>=0; i--) if (fileNames.contains(data[i].file.name)) data.removeAt(i); const int numFiles = data.size(); if (Globals->backProject) { TemplateList backProjectedData; transform->backProject(data, backProjectedData); data = backProjectedData; } else { data >> *transform; } g->writeBlock(data); const FileList newFiles = data.files(); fileList.append(newFiles); totalCount += newFiles.size(); failureCount += newFiles.failures(); totalBytes += data.bytes(); Globals->currentStep += numFiles; Globals->printStatus(); } } const float speed = 1000 * Globals->totalSteps / Globals->startTime.elapsed() / std::max(1, abs(Globals->parallelism)); if (!Globals->quiet && (Globals->totalSteps > 1)) fprintf(stderr, "\rSPEED=%.1e SIZE=%.4g FAILURES=%d/%d \n", speed, totalBytes/totalCount, failureCount, totalCount); Globals->totalSteps = 0; return fileList; } void retrieveOrEnroll(const File &file, QScopedPointer &gallery, FileList &galleryFiles) { gallery.reset(Gallery::make(file)); if ((file.suffix() != "gal") && (file.suffix() != "mem")) { enroll(file); gallery.reset(Gallery::make(getMemoryGallery(file))); galleryFiles = gallery->files(); } else { galleryFiles = gallery->files(); } } void compare(File targetGallery, File queryGallery, File output) { if (output.exists() && output.getBool("cache")) return; if (queryGallery == ".") queryGallery = targetGallery; QScopedPointer t, q; FileList targetFiles, queryFiles; retrieveOrEnroll(targetGallery, t, targetFiles); retrieveOrEnroll(queryGallery, q, queryFiles); QScopedPointer o(Output::make(output, targetFiles, queryFiles)); if (distance.isNull()) qFatal("Null distance."); Globals->currentStep = 0; Globals->totalSteps = double(targetFiles.size()) * double(queryFiles.size()); Globals->startTime.start(); int queryBlock = -1; bool queryDone = false; while (!queryDone) { queryBlock++; TemplateList queries = q->readBlock(&queryDone); int targetBlock = -1; bool targetDone = false; while (!targetDone) { targetBlock++; TemplateList targets = t->readBlock(&targetDone); o->setBlock(queryBlock, targetBlock); distance->compare(targets, queries, o.data()); Globals->currentStep += double(targets.size()) * double(queries.size()); Globals->printStatus(); } } const float speed = 1000 * Globals->totalSteps / Globals->startTime.elapsed() / std::max(1, abs(Globals->parallelism)); if (!Globals->quiet && (Globals->totalSteps > 1)) fprintf(stderr, "\rSPEED=%.1e \n", speed); Globals->totalSteps = 0; } private: QString name; QString getFileName(const QString &description) const { const QString file = Globals->sdkPath + "/share/openbr/models/algorithms/" + description; return QFileInfo(file).exists() ? file : QString(); } void init(QString description) { // Check if a trained binary already exists for this algorithm const QString file = getFileName(description); if (!file.isEmpty()) description = file; if (QFileInfo(description).exists()) { qDebug("Loading %s", qPrintable(QFileInfo(description).fileName())); load(description); return; } // Expand abbreviated algorithms to their full strings if (Globals->abbreviations.contains(description)) return init(Globals->abbreviations[description]); QStringList words = description.split(':'); if (words.size() > 2) qFatal("Invalid algorithm format."); transform = QSharedPointer(Transform::make(words[0], NULL)); if (words.size() > 1) distance = QSharedPointer(Distance::make(words[1], NULL)); } }; class AlgorithmManager : public Initializer { Q_OBJECT public: static QHash > algorithms; static QMutex algorithmsLock; void initialize() const {} void finalize() const { algorithms.clear(); } static QSharedPointer getAlgorithm(const QString &algorithm) { if (algorithm.isEmpty()) qFatal("No default algorithm set."); if (!algorithms.contains(algorithm)) { // Some algorithms are recursive, so we need to construct them outside the lock. QSharedPointer algorithmCore(new AlgorithmCore(algorithm)); algorithmsLock.lock(); if (!algorithms.contains(algorithm)) algorithms.insert(algorithm, algorithmCore); algorithmsLock.unlock(); } return algorithms[algorithm]; } }; QHash > AlgorithmManager::algorithms; QMutex AlgorithmManager::algorithmsLock; BR_REGISTER(Initializer, AlgorithmManager) bool br::IsClassifier(const QString &algorithm) { qDebug("Checking if %s is a classifier", qPrintable(algorithm)); return AlgorithmManager::getAlgorithm(algorithm)->isClassifier(); } void br::Train(const File &input, const File &model) { qDebug("Training on %s%s", qPrintable(input.flat()), model.isNull() ? "" : qPrintable(" to " + model.flat())); AlgorithmManager::getAlgorithm(model.getString("algorithm"))->train(input, model); } FileList br::Enroll(const File &input, const File &gallery) { qDebug("Enrolling %s%s", qPrintable(input.flat()), gallery.isNull() ? "" : qPrintable(" to " + gallery.flat())); return AlgorithmManager::getAlgorithm(gallery.getString("algorithm"))->enroll(input, gallery); } void br::Compare(const File &targetGallery, const File &queryGallery, const File &output) { qDebug("Comparing %s and %s%s", qPrintable(targetGallery.flat()), qPrintable(queryGallery.flat()), output.isNull() ? "" : qPrintable(" to " + output.flat())); AlgorithmManager::getAlgorithm(output.getString("algorithm"))->compare(targetGallery, queryGallery, output); } void br::Convert(const File &src, const File &dst) { qDebug("Converting %s to %s", qPrintable(src.flat()), qPrintable(dst.flat())); QScopedPointer before(Factory::make(src)); QScopedPointer after(Factory::make(dst)); after->write(before->read()); } void br::Cat(const QStringList &inputGalleries, const QString &outputGallery) { qDebug("Concatenating %d galleries to %s", inputGalleries.size(), qPrintable(outputGallery)); foreach (const QString &inputGallery, inputGalleries) if (inputGallery == outputGallery) qFatal("outputGallery must not be in inputGalleries."); QScopedPointer og(Gallery::make(outputGallery)); foreach (const QString &inputGallery, inputGalleries) { QScopedPointer ig(Gallery::make(inputGallery)); bool done = false; while (!done) og->writeBlock(ig->readBlock(&done)); } } QSharedPointer br::Transform::fromAlgorithm(const QString &algorithm) { return AlgorithmManager::getAlgorithm(algorithm)->transform; } QSharedPointer br::Distance::fromAlgorithm(const QString &algorithm) { return AlgorithmManager::getAlgorithm(algorithm)->distance; } #include "core.moc"