diff --git a/sdk/core/bee.cpp b/sdk/core/bee.cpp index bf26fbc..d5f70d6 100644 --- a/sdk/core/bee.cpp +++ b/sdk/core/bee.cpp @@ -47,6 +47,9 @@ FileList BEE::readSigset(QString sigset, bool ignoreMetadata) file.close(); QDomElement docElem = doc.documentElement(); + if (docElem.nodeName() != "biometric-signature-set") + return fileList; + QDomNode subject = docElem.firstChild(); while (!subject.isNull()) { // Looping through subjects diff --git a/sdk/core/plot.cpp b/sdk/core/plot.cpp index 8ee8fd2..1499138 100644 --- a/sdk/core/plot.cpp +++ b/sdk/core/plot.cpp @@ -431,7 +431,7 @@ bool br::Plot(const QStringList &files, const QString &destination, bool show) p.file.write(qPrintable(QString("qplot(X, Y, data=DET, geom=\"line\"") + (p.majorSize > 1 ? QString(", colour=factor(%1)").arg(p.majorHeader) : QString()) + (p.minorSize > 1 ? QString(", linetype=factor(%1)").arg(p.minorHeader) : QString()) + - QString(", xlab=\"False Accept Rate\", ylab=\"False Reject Rate\") + geom_abline(alpha=0.5, colour=\"grey\", linetype=\"dashed\") + theme_bw()") + + QString(", xlab=\"False Accept Rate\", ylab=\"False Reject Rate\") + geom_abline(alpha=0.5, colour=\"grey\", linetype=\"dashed\") + theme_minimal()") + (p.majorSize > 1 ? getScale("colour", p.majorHeader, p.majorSize) : QString()) + (p.minorSize > 1 ? QString(" + scale_linetype_discrete(\"%1\")").arg(p.minorHeader) : QString()) + QString(" + scale_x_continuous(trans=\"log10\") + scale_y_continuous(trans=\"log10\")") + @@ -440,7 +440,7 @@ bool br::Plot(const QStringList &files, const QString &destination, bool show) p.file.write(qPrintable(QString("qplot(X, 1-Y, data=DET, geom=\"line\"") + (p.majorSize > 1 ? QString(", colour=factor(%1)").arg(p.majorHeader) : QString()) + (p.minorSize > 1 ? QString(", linetype=factor(%1)").arg(p.minorHeader) : QString()) + - QString(", xlab=\"False Accept Rate\", ylab=\"True Accept Rate\") + theme_bw()") + + QString(", xlab=\"False Accept Rate\", ylab=\"True Accept Rate\") + theme_minimal()") + (p.majorSize > 1 ? getScale("colour", p.majorHeader, p.majorSize) : QString()) + (p.minorSize > 1 ? QString(" + scale_linetype_discrete(\"%1\")").arg(p.minorHeader) : QString()) + QString(" + scale_x_continuous(trans=\"log10\") + scale_y_continuous(labels=percent)") + @@ -449,15 +449,15 @@ bool br::Plot(const QStringList &files, const QString &destination, bool show) p.file.write(qPrintable(QString("qplot(X, data=SD, geom=\"histogram\", fill=Y, position=\"identity\", alpha=I(1/2)") + QString(", xlab=\"Score%1\"").arg((p.flip ? p.majorSize : p.minorSize) > 1 ? " / " + (p.flip ? p.majorHeader : p.minorHeader) : QString()) + QString(", ylab=\"Frequency%1\"").arg((p.flip ? p.minorSize : p.majorSize) > 1 ? " / " + (p.flip ? p.minorHeader : p.majorHeader) : QString()) + - QString(") + scale_fill_manual(\"Ground Truth\", values=c(\"blue\", \"red\")) + theme_bw() + scale_x_continuous(minor_breaks=NULL) + scale_y_continuous(minor_breaks=NULL) + opts(axis.text.y=theme_blank(), axis.ticks=theme_blank(), axis.text.x=theme_text(angle=-90, hjust=0))") + + QString(") + scale_fill_manual(\"Ground Truth\", values=c(\"blue\", \"red\")) + theme_minimal() + scale_x_continuous(minor_breaks=NULL) + scale_y_continuous(minor_breaks=NULL) + theme(axis.text.y=element_blank(), axis.ticks=element_blank(), axis.text.x=element_text(angle=-90, hjust=0))") + (p.majorSize > 1 ? (p.minorSize > 1 ? QString(" + facet_grid(%2 ~ %1, scales=\"free\")").arg((p.flip ? p.majorHeader : p.minorHeader), (p.flip ? p.minorHeader : p.majorHeader)) : QString(" + facet_wrap(~ %1, scales = \"free\")").arg(p.majorHeader)) : QString()) + - QString(" + opts(aspect.ratio=1)") + + QString(" + theme(aspect.ratio=1)") + QString("\nggsave(\"%1\")\n").arg(p.subfile("SD")))); p.file.write(qPrintable(QString("qplot(X, Y, data=CMC, geom=\"line\", xlab=\"Rank\", ylab=\"Retrieval Rate\"") + (p.majorSize > 1 ? QString(", colour=factor(%1)").arg(p.majorHeader) : QString()) + (p.minorSize > 1 ? QString(", linetype=factor(%1)").arg(p.minorHeader) : QString()) + - QString(") + theme_bw() + scale_x_continuous(limits = c(1,25), breaks = c(1,5,10,25))") + + QString(") + theme_minimal() + scale_x_continuous(limits = c(1,25), breaks = c(1,5,10,25))") + (p.majorSize > 1 ? getScale("colour", p.majorHeader, p.majorSize) : QString()) + (p.minorSize > 1 ? QString(" + scale_linetype_discrete(\"%1\")").arg(p.minorHeader) : QString()) + QString(" + scale_y_continuous(labels=percent)") + @@ -466,28 +466,28 @@ bool br::Plot(const QStringList &files, const QString &destination, bool show) p.file.write(qPrintable(QString("qplot(factor(%1), data=BC, geom=\"bar\", position=\"dodge\", weight=Y").arg(p.majorHeader) + (p.majorSize > 1 ? QString(", fill=factor(%1)").arg(p.majorHeader) : QString()) + QString(", xlab=\"%1False Accept Rate\"").arg(p.majorSize > 1 ? p.majorHeader + " / " : QString()) + - QString(", ylab=\"True Accept Rate%1\") + theme_bw()").arg(p.minorSize > 1 ? " / " + p.minorHeader : QString()) + + QString(", ylab=\"True Accept Rate%1\") + theme_minimal()").arg(p.minorSize > 1 ? " / " + p.minorHeader : QString()) + (p.majorSize > 1 ? getScale("fill", p.majorHeader, p.majorSize) : QString()) + (p.minorSize > 1 ? QString(" + facet_grid(%2 ~ X)").arg(p.minorHeader) : QString(" + facet_wrap(~ X)")) + - QString(" + opts(legend.position=\"none\", axis.text.x=theme_text(angle=-90, hjust=0)) + geom_text(data=BC, aes(label=Y, y=0.05))") + + QString(" + theme(legend.position=\"none\", axis.text.x=element_text(angle=-90, hjust=0)) + geom_text(data=BC, aes(label=Y, y=0.05))") + QString("\nggsave(\"%1\")\n").arg(p.subfile("BC")))); p.file.write(qPrintable(QString("qplot(X, Y, data=FAR, geom=\"line\"") + ((p.flip ? p.majorSize : p.minorSize) > 1 ? QString(", colour=factor(%1)").arg(p.flip ? p.majorHeader : p.minorHeader) : QString()) + - QString(", xlab=\"Score%1\", ylab=\"False Accept Rate\") + theme_bw()").arg((p.flip ? p.minorSize : p.majorSize) > 1 ? " / " + (p.flip ? p.minorHeader : p.majorHeader) : QString()) + + QString(", xlab=\"Score%1\", ylab=\"False Accept Rate\") + theme_minimal()").arg((p.flip ? p.minorSize : p.majorSize) > 1 ? " / " + (p.flip ? p.minorHeader : p.majorHeader) : QString()) + ((p.flip ? p.majorSize : p.minorSize) > 1 ? getScale("colour", p.flip ? p.majorHeader : p.minorHeader, p.flip ? p.majorSize : p.minorSize) : QString()) + QString(" + scale_y_continuous(trans=\"log10\")") + ((p.flip ? p.minorSize : p.majorSize) > 1 ? QString(" + facet_wrap(~ %1, scales=\"free_x\")").arg(p.flip ? p.minorHeader : p.majorHeader) : QString()) + - QString(" + opts(aspect.ratio=1)") + + QString(" + theme(aspect.ratio=1)") + QString("\nggsave(\"%1\")\n").arg(p.subfile("FAR")))); p.file.write(qPrintable(QString("qplot(X, Y, data=FRR, geom=\"line\"") + ((p.flip ? p.majorSize : p.minorSize) > 1 ? QString(", colour=factor(%1)").arg(p.flip ? p.majorHeader : p.minorHeader) : QString()) + - QString(", xlab=\"Score%1\", ylab=\"False Reject Rate\") + theme_bw()").arg((p.flip ? p.minorSize : p.majorSize) > 1 ? " / " + (p.flip ? p.minorHeader : p.majorHeader) : QString()) + + QString(", xlab=\"Score%1\", ylab=\"False Reject Rate\") + theme_minimal()").arg((p.flip ? p.minorSize : p.majorSize) > 1 ? " / " + (p.flip ? p.minorHeader : p.majorHeader) : QString()) + ((p.flip ? p.majorSize : p.minorSize) > 1 ? getScale("colour", p.flip ? p.majorHeader : p.minorHeader, p.flip ? p.majorSize : p.minorSize) : QString()) + QString(" + scale_y_continuous(trans=\"log10\")") + ((p.flip ? p.minorSize : p.majorSize) > 1 ? QString(" + facet_wrap(~ %1, scales=\"free_x\")").arg(p.flip ? p.minorHeader : p.majorHeader) : QString()) + - QString(" + opts(aspect.ratio=1)") + + QString(" + theme(aspect.ratio=1)") + QString("\nggsave(\"%1\")\n").arg(p.subfile("FRR")))); return p.finalize(show); @@ -499,6 +499,6 @@ bool br::PlotMetadata(const QStringList &files, const QString &columns, bool sho RPlot p(files, "PlotMetadata", false); foreach (const QString &column, columns.split(";")) - p.file.write(qPrintable(QString("qplot(%1, %2, data=data, geom=\"violin\", fill=%1) + coord_flip() + theme_bw()\nggsave(\"%2.pdf\")\n").arg(p.majorHeader, column))); + p.file.write(qPrintable(QString("qplot(%1, %2, data=data, geom=\"violin\", fill=%1) + coord_flip() + theme_minimal()\nggsave(\"%2.pdf\")\n").arg(p.majorHeader, column))); return p.finalize(show); } diff --git a/sdk/openbr_plugin.cpp b/sdk/openbr_plugin.cpp index ce4cb8f..3941a46 100644 --- a/sdk/openbr_plugin.cpp +++ b/sdk/openbr_plugin.cpp @@ -112,13 +112,12 @@ float File::label() const const QVariant variant = value("Label"); if (variant.isNull()) return -1; - if (variant.canConvert(QVariant::Double)) { - bool ok; - float val = variant.toFloat(&ok); - if (ok) return val; - } + if (Globals->classes.contains(variant.toString())) + return Globals->classes.value(variant.toString()); - return Globals->classes.value(variant.toString(), -1); + bool ok; + const float val = variant.toFloat(&ok); + return ok ? val : -1; } void File::remove(const QString &key) @@ -130,10 +129,17 @@ void File::set(const QString &key, const QVariant &value) { if (key == "Label") { bool ok = false; - if (value.canConvert(QVariant::Double)) + const QString valueString = value.toString(); + + /* We assume that if the value starts with '0' + then it was probably intended to to be a string UID + and that it's numerical value is not relevant. */ + if (value.canConvert(QVariant::Double) && + (!valueString.startsWith('0') || (valueString == "0"))) value.toFloat(&ok); - if (!ok && !Globals->classes.contains(value.toString())) - Globals->classes.insert(value.toString(), Globals->classes.size()); + + if (!ok && !Globals->classes.contains(valueString)) + Globals->classes.insert(valueString, Globals->classes.size()); } m_metadata.insert(key, value); @@ -377,6 +383,10 @@ TemplateList TemplateList::fromInput(const br::File &input) QScopedPointer i(Gallery::make(file)); TemplateList newTemplates = i->read(); + // If input is a Format not a Gallery + if (newTemplates.isEmpty()) + newTemplates.append(input); + // Propogate metadata for (int i=0; i()->description(); + } else if (type == "QStringList") { + return "[" + variant.toStringList().join(",") + "]"; } return variant.toString(); @@ -484,6 +496,10 @@ void Object::store(QDataStream &stream) const stream << property.read(this).toFloat(); } else if (type == "double") { stream << property.read(this).toDouble(); + } else if (type == "QString") { + stream << property.read(this).toString(); + } else if (type == "QStringList") { + stream << property.read(this).toStringList(); } else { qFatal("Can't serialize value of type: %s", qPrintable(type)); } @@ -520,6 +536,14 @@ void Object::load(QDataStream &stream) double value; stream >> value; property.write(this, value); + } else if (type == "QString") { + QString value; + stream >> value; + property.write(this, value); + } else if (type == "QStringList") { + QStringList value; + stream >> value; + property.write(this, value); } else { qFatal("Can't serialize value of type: %s", qPrintable(type)); } @@ -560,6 +584,8 @@ void Object::setProperty(const QString &name, const QString &value) } } else if (type == "br::Transform*") { variant.setValue(Transform::make(value, this)); + } else if (type == "QStringList") { + variant.setValue(parse(value.mid(1, value.size()-2))); } else if (type == "bool") { if (value.isEmpty()) variant = true; else if (value == "false") variant = false; @@ -1218,6 +1244,29 @@ void Distance::compare(const TemplateList &target, const TemplateList &query, Ou if (Globals->parallelism) Globals->trackFutures(futures); } +float Distance::compare(const Template &target, const Template &query) const +{ + if (!Globals->demographicFilters.isEmpty()) { + // The if statement is a faster check then iterating over an empty list of filters + foreach (const QString &filter, Globals->demographicFilters) { + const QString targetMetadata = target.file.getString(filter, ""); + const QString queryMetadata = query.file.getString(filter, ""); + if (targetMetadata.isEmpty() || queryMetadata.isEmpty()) continue; + if (targetMetadata != queryMetadata) return -std::numeric_limits::max(); + } + } + + if (Globals->ageDelta < std::numeric_limits::max()) { + const float targetAge = target.file.getFloat("Age", -1); + const float queryAge = target.file.getFloat("Age", -1); + if ((targetAge != -1) && (queryAge != -1) && (abs(targetAge - queryAge) > Globals->ageDelta)) + return -std::numeric_limits::max(); + } + + return a * (_compare(target, query) - b); +} + +/* Distance - private methods */ void Distance::compareBlock(const TemplateList &target, const TemplateList &query, Output *output, int targetOffset, int queryOffset) const { for (int i=0; i::max()) + QHash abbreviations; /*!< \brief Used by br::Transform::make() to expand abbreviated algorithms into their complete definitions. */ QHash classes; /*!< \brief Used by classifiers to associate text class labels with unique integers IDs. */ QTime startTime; /*!< \brief Used to estimate timeRemaining(). */ @@ -821,9 +833,9 @@ private: /*! * \ingroup formats - * \brief Plugin base class for reading matrices from disk. + * \brief Plugin base class for reading a template from disk. * - * A \em format is a br::File representing a matrix (ex. jpg image) on disk. + * A \em format is a br::File representing a template (ex. jpg image) on disk. * br::File::suffix() is used to determine which plugin should handle the format. */ class BR_EXPORT Format : public Object @@ -832,7 +844,7 @@ class BR_EXPORT Format : public Object public: virtual ~Format() {} - virtual QList read() const = 0; /*!< \brief Returns a list of matrices created by reading #br::Object::file. */ + virtual Template read() const = 0; /*!< \brief Returns a br::Template created by reading #br::Object::file. */ }; /*! @@ -1034,7 +1046,7 @@ public: static QSharedPointer fromAlgorithm(const QString &algorithm); /*!< \brief Retrieve an algorithm's distance. */ virtual void train(const TemplateList &src); /*!< \brief Train the distance. */ virtual void compare(const TemplateList &target, const TemplateList &query, Output *output) const; /*!< \brief Compare two template lists. */ - inline float compare(const Template &target, const Template &query) const { return a * (_compare(target, query) - b); } /*!< \brief Compute the normalized distance between two templates. */ + float compare(const Template &target, const Template &query) const; /*!< \brief Compute the normalized distance between two templates. */ private: virtual void compareBlock(const TemplateList &target, const TemplateList &query, Output *output, int targetOffset, int queryOffset) const; diff --git a/sdk/plugins/compare.cpp b/sdk/plugins/compare.cpp index 62b52d5..5deb254 100644 --- a/sdk/plugins/compare.cpp +++ b/sdk/plugins/compare.cpp @@ -42,7 +42,7 @@ public: INF, L1, L2, - CosineSimilarity }; + Cosine }; private: BR_PROPERTY(Metric, metric, L2) @@ -76,8 +76,8 @@ private: case L2: result = norm(a, b, NORM_L2); break; - case CosineSimilarity: - result = cosineSimilarity(a, b); + case Cosine: + result = cosine(a, b); break; default: qFatal("Invalid metric"); @@ -89,33 +89,29 @@ private: return -log(result+1); } - static float cosineSimilarity(const Mat &a, const Mat &b) + static float cosine(const Mat &a, const Mat &b) { - assert((a.type() == CV_32FC1) && (b.type() == CV_32FC1)); - assert((a.rows == b.rows) && (a.cols == b.cols)); - - float denom = 0; - float tnum = 0; - float qnum = 0; + float dot = 0; + float magA = 0; + float magB = 0; for (int row=0; row(row,col); - float query = b.at(row,col); + const float target = a.at(row,col); + const float query = b.at(row,col); - denom += target * query; - tnum += target * target; - qnum += query * query; + dot += target * query; + magA += target * target; + magB += query * query; } } - return denom / (sqrt(tnum)*sqrt(qnum)); + return dot / (sqrt(magA)*sqrt(magB)); } }; BR_REGISTER(Distance, Dist) - /*! * \ingroup distances * \brief Fast 8-bit L1 distance diff --git a/sdk/plugins/eigen3.cpp b/sdk/plugins/eigen3.cpp index d9c2c99..e5a3154 100644 --- a/sdk/plugins/eigen3.cpp +++ b/sdk/plugins/eigen3.cpp @@ -276,10 +276,10 @@ class LDA : public Transform int dimsIn = ldaTrainingSet.first().m().rows * ldaTrainingSet.first().m().cols; - // MM ensures that classes values range from 0 to numClasses-1. - QList classes = trainingSet.labels(); // PCA doesn't project metadata + // OpenBR ensures that class values range from 0 to numClasses-1. + QList classes = trainingSet.labels(); QMap classCounts = trainingSet.labelCounts(); - int numClasses = classCounts.size(); + const int numClasses = classCounts.size(); // Map Eigen into OpenCV Eigen::MatrixXd data = Eigen::MatrixXd(dimsIn, instances); diff --git a/sdk/plugins/filter.cpp b/sdk/plugins/filter.cpp index 15fe345..1e28ba5 100644 --- a/sdk/plugins/filter.cpp +++ b/sdk/plugins/filter.cpp @@ -153,6 +153,7 @@ class CSDN : public UntrainableTransform } } + m.convertTo(m, CV_8UC1); dst = m; } diff --git a/sdk/plugins/format.cpp b/sdk/plugins/format.cpp index ab8960d..9e1e7b2 100644 --- a/sdk/plugins/format.cpp +++ b/sdk/plugins/format.cpp @@ -14,9 +14,11 @@ * limitations under the License. * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ +#include #ifndef BR_EMBEDDED #include #include +#include #endif // BR_EMBEDDED #include #include @@ -33,7 +35,7 @@ class csvFormat : public Format { Q_OBJECT - QList read() const + Template read() const { QFile f(file.name); f.open(QFile::ReadOnly); @@ -59,9 +61,7 @@ class csvFormat : public Format } } - QList mats; - mats.append(m); - return mats; + return Template(m); } }; @@ -76,9 +76,9 @@ class DefaultFormat : public Format { Q_OBJECT - QList read() const + Template read() const { - QList mats; + Template t; if (file.name.startsWith("http://") || file.name.startsWith("www.")) { #ifndef BR_EMBEDDED @@ -94,16 +94,16 @@ class DefaultFormat : public Format delete reply; Mat m = imdecode(Mat(1, data.size(), CV_8UC1, data.data()), 1); - if (m.data) mats.append(m); + if (m.data) t.append(m); #endif // BR_EMBEDDED } else { QString prefix = ""; if (!QFileInfo(file.name).exists()) prefix = file.getString("path") + "/"; Mat m = imread((prefix+file.name).toStdString()); - if (m.data) mats.append(m); + if (m.data) t.append(m); } - return mats; + return t; } }; @@ -118,7 +118,7 @@ class webcamFormat : public Format { Q_OBJECT - QList read() const + Template read() const { static QScopedPointer videoCapture; @@ -127,11 +127,73 @@ class webcamFormat : public Format Mat m; videoCapture->read(m); - - return QList() << m; + return Template(m); } }; BR_REGISTER(Format, webcamFormat) +#ifndef BR_EMBEDDED +/*! + * \ingroup formats + * \brief Decodes images from Base64 xml + * \author Scott Klum \cite sklum + * \author Josh Klontz \cite jklontz + */ +class xmlFormat : public Format +{ + Q_OBJECT + + Template read() const + { + QDomDocument doc(file); + QFile f(file); + if (!f.open(QIODevice::ReadOnly)) qFatal("xmlFormat::read unable to open %s for reading.", qPrintable(file.flat())); + if (!doc.setContent(&f)) qFatal("xmlFormat::read unable to parse %s.", qPrintable(file.flat())); + f.close(); + + Template t; + QDomElement docElem = doc.documentElement(); + QDomNode subject = docElem.firstChild(); + while (!subject.isNull()) { + QDomNode fileNode = subject.firstChild(); + + while (!fileNode.isNull()) { + QDomElement e = fileNode.toElement(); + + if (e.tagName() == "FORMAL_IMG") { + QByteArray byteArray = QByteArray::fromBase64(qPrintable(e.text())); + Mat m = imdecode(Mat(1, byteArray.size(), CV_8UC1, byteArray.data()), CV_LOAD_IMAGE_ANYDEPTH); + if (!m.data) qWarning("xmlFormat::read failed to decode image data."); + t.append(m); + } else if ((e.tagName() == "RELEASE_IMG") || + (e.tagName() == "PREBOOK_IMG") || + (e.tagName() == "LPROFILE") || + (e.tagName() == "RPROFILE")) { + // Ignore these other image fields for now + } else { + t.file.insert(e.tagName(), e.text()); + } + + fileNode = fileNode.nextSibling(); + } + subject = subject.nextSibling(); + } + + // Calculate age + if (t.file.contains("DOB")) { + const QDate dob = QDate::fromString(t.file.getString("DOB").left(10), "yyyy-MM-dd"); + const QDate current = QDate::currentDate(); + int age = current.year() - dob.year(); + if (current.month() < dob.month()) age--; + t.file.insert("Age", age); + } + + return t; + } +}; + +BR_REGISTER(Format, xmlFormat) +#endif // BR_EMBEDDED + #include "format.moc" diff --git a/sdk/plugins/llvm.cpp b/sdk/plugins/llvm.cpp index 19cfdce..6cdd96f 100644 --- a/sdk/plugins/llvm.cpp +++ b/sdk/plugins/llvm.cpp @@ -65,6 +65,22 @@ static Matrix MatrixFromMat(const cv::Mat &mat) return m; } +static Mat MatFromMatrix(const Matrix &m) +{ + int depth = -1; + switch (m.type()) { + case Matrix::u8: depth = CV_8U; break; + case Matrix::s8: depth = CV_8S; break; + case Matrix::u16: depth = CV_16U; break; + case Matrix::s16: depth = CV_16S; break; + case Matrix::s32: depth = CV_32S; break; + case Matrix::f32: depth = CV_32F; break; + case Matrix::f64: depth = CV_64F; break; + default: qFatal("Unrecognized matrix depth."); + } + return Mat(m.rows, m.columns, CV_MAKETYPE(depth, m.channels), m.data).clone(); +} + static void AllocateMatrixFromMat(Matrix &m, cv::Mat &mat) { int cvType = -1; @@ -230,26 +246,29 @@ struct MatrixBuilder : public Matrix Value *compareLT(Value *i, Value *j) const { return isFloating() ? b->CreateFCmpOLT(i, j) : (isSigned() ? b->CreateICmpSLT(i, j) : b->CreateICmpULT(i, j)); } Value *compareGT(Value *i, Value *j) const { return isFloating() ? b->CreateFCmpOGT(i, j) : (isSigned() ? b->CreateICmpSGT(i, j) : b->CreateICmpUGT(i, j)); } - static PHINode *beginLoop(IRBuilder<> &builder, Function *function, BasicBlock *parent, BasicBlock **current, const Twine &name = "") { - *current = BasicBlock::Create(getGlobalContext(), "loop_"+name, function); - builder.CreateBr(*current); - builder.SetInsertPoint(*current); - PHINode *j = builder.CreatePHI(Type::getInt32Ty(getGlobalContext()), 2, name); - j->addIncoming(MatrixBuilder::zero(), parent); - return j; - } - PHINode *beginLoop(BasicBlock *parent, BasicBlock **current, const Twine &name = "") const { return beginLoop(*b, f, parent, current, name); } - static void endLoop(IRBuilder<> &builder, Function *function, BasicBlock *current, PHINode *j, Value *end, const Twine &name = "") { - BasicBlock *loop = BasicBlock::Create(getGlobalContext(), "loop_"+name+"_end", function); + static PHINode *beginLoop(IRBuilder<> &builder, Function *function, BasicBlock *entry, BasicBlock *&loop, BasicBlock *&exit, Value *stop, const Twine &name = "") { + loop = BasicBlock::Create(getGlobalContext(), "loop_"+name, function); builder.CreateBr(loop); builder.SetInsertPoint(loop); - Value *increment = builder.CreateAdd(j, MatrixBuilder::one(), "increment_"+name); - j->addIncoming(increment, loop); - BasicBlock *exit = BasicBlock::Create(getGlobalContext(), "loop_"+name+"_exit", function); - builder.CreateCondBr(builder.CreateICmpNE(increment, end, "loop_"+name+"_test"), current, exit); + + PHINode *i = builder.CreatePHI(Type::getInt32Ty(getGlobalContext()), 2, name); + i->addIncoming(MatrixBuilder::zero(), entry); + Value *increment = builder.CreateAdd(i, MatrixBuilder::one(), "increment_"+name); + BasicBlock *body = BasicBlock::Create(getGlobalContext(), "loop_"+name+"_body", function); + i->addIncoming(increment, body); + + exit = BasicBlock::Create(getGlobalContext(), "loop_"+name+"_exit", function); + builder.CreateCondBr(builder.CreateICmpEQ(i, stop, "loop_"+name+"_test"), exit, body); + builder.SetInsertPoint(body); + return i; + } + PHINode *beginLoop(BasicBlock *entry, BasicBlock *&loop, BasicBlock *&exit, Value *stop, const Twine &name = "") const { return beginLoop(*b, f, entry, loop, exit, stop, name); } + + static void endLoop(IRBuilder<> &builder, BasicBlock *loop, BasicBlock *exit) { + builder.CreateBr(loop); builder.SetInsertPoint(exit); } - void endLoop(BasicBlock *current, PHINode *j, Value *end, const Twine &name = "") const { endLoop(*b, f, current, j, end, name); } + void endLoop(BasicBlock *loop, BasicBlock *exit) const { endLoop(*b, loop, exit); } template inline static std::vector toVector(T value) { std::vector vector; vector.push_back(value); return vector; } @@ -440,14 +459,14 @@ private: BasicBlock *entry = BasicBlock::Create(getGlobalContext(), "entry", function); IRBuilder<> builder(entry); - BasicBlock *kernel; - PHINode *i = MatrixBuilder::beginLoop(builder, function, entry, &kernel, "i"); + BasicBlock *loop, *exit; + PHINode *i = MatrixBuilder::beginLoop(builder, function, entry, loop, exit, len, "i"); Matrix n; preallocate(m, n); build(MatrixBuilder(m, src, &builder, function, "src"), MatrixBuilder(n, dst, &builder, function, "dst"), i); - MatrixBuilder::endLoop(builder, function, kernel, i, len, "i"); + MatrixBuilder::endLoop(builder, loop, exit); builder.CreateRetVoid(); return function; @@ -533,14 +552,14 @@ private: BasicBlock *entry = BasicBlock::Create(getGlobalContext(), "entry", function); IRBuilder<> builder(entry); - BasicBlock *kernel; - PHINode *i = MatrixBuilder::beginLoop(builder, function, entry, &kernel, "i"); + BasicBlock *loop, *exit; + PHINode *i = MatrixBuilder::beginLoop(builder, function, entry, loop, exit, len, "i"); Matrix o; preallocate(m, n, o); build(MatrixBuilder(m, srcA, &builder, function, "srcA"), MatrixBuilder(n, srcB, &builder, function, "srcB"), MatrixBuilder(o, dst, &builder, function, "dst"), i); - MatrixBuilder::endLoop(builder, function, kernel, i, len, "i"); + MatrixBuilder::endLoop(builder, loop, exit); builder.CreateRetVoid(); return function; @@ -744,53 +763,52 @@ class sumTransform : public UnaryKernel dst.deindex(i, &c, &x, &y, &t); AllocaInst *sum = dst.autoAlloca(0, "sum"); - QList loops; - QList blocks; - blocks.push_back(i->getParent()); + QList loops, exits; + loops.push_back(i->getParent()); Value *src_c, *src_x, *src_y, *src_t; if (frames && !src.singleFrame()) { - BasicBlock *block; - loops.append(dst.beginLoop(blocks.last(), &block, "src_t")); - blocks.append(block); - src_t = loops.last(); + BasicBlock *loop, *exit; + src_t = dst.beginLoop(loops.last(), loop, exit, src.getFrames(), "src_t"); + loops.append(loop); + exits.append(exit); } else { src_t = t; } if (rows && !src.singleRow()) { - BasicBlock *block; - loops.append(dst.beginLoop(blocks.last(), &block, "src_y")); - blocks.append(block); - src_y = loops.last(); + BasicBlock *loop, *exit; + src_y = dst.beginLoop(loops.last(), loop, exit, src.getRows(), "src_y"); + loops.append(loop); + exits.append(exit); } else { src_y = y; } if (columns && !src.singleColumn()) { - BasicBlock *block; - loops.append(dst.beginLoop(blocks.last(), &block, "src_x")); - blocks.append(block); - src_x = loops.last(); + BasicBlock *loop, *exit; + src_x = dst.beginLoop(loops.last(), loop, exit, src.getColumns(), "src_x"); + loops.append(loop); + exits.append(exit); } else { src_x = x; } if (channels && !src.singleChannel()) { - BasicBlock *block; - loops.append(dst.beginLoop(blocks.last(), &block, "src_c")); - blocks.append(block); - src_c = loops.last(); + BasicBlock *loop, *exit; + src_c = dst.beginLoop(loops.last(), loop, exit, src.getChannels(), "src_c"); + loops.append(loop); + exits.append(exit); } else { src_c = c; } dst.b->CreateStore(dst.add(dst.b->CreateLoad(sum), src.cast(src.load(src.aliasIndex(dst, src_c, src_x, src_y, src_t)), dst), "accumulate"), sum); - if (channels && !src.singleChannel()) dst.endLoop(blocks.takeLast(), loops.takeLast(), src.getChannels(), "src_c"); - if (columns && !src.singleColumn()) dst.endLoop(blocks.takeLast(), loops.takeLast(), src.getColumns(), "src_x"); - if (rows && !src.singleRow()) dst.endLoop(blocks.takeLast(), loops.takeLast(), src.getRows(), "src_y"); - if (frames && !src.singleFrame()) dst.endLoop(blocks.takeLast(), loops.takeLast(), src.getFrames(), "src_t"); + if (channels && !src.singleChannel()) dst.endLoop(loops.takeLast(), exits.takeLast()); + if (columns && !src.singleColumn()) dst.endLoop(loops.takeLast(), exits.takeLast()); + if (rows && !src.singleRow()) dst.endLoop(loops.takeLast(), exits.takeLast()); + if (frames && !src.singleFrame()) dst.endLoop(loops.takeLast(), exits.takeLast()); dst.store(i, dst.b->CreateLoad(sum)); } diff --git a/sdk/plugins/meta.cpp b/sdk/plugins/meta.cpp index f6240bb..84f95b8 100644 --- a/sdk/plugins/meta.cpp +++ b/sdk/plugins/meta.cpp @@ -329,12 +329,15 @@ class LoadStoreTransform : public MetaTransform { Q_OBJECT Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false) - Q_PROPERTY(br::Transform *transform READ get_transform WRITE set_transform RESET reset_transform STORED false) BR_PROPERTY(QString, description, "Identity") - BR_PROPERTY(br::Transform*, transform, NULL) + Transform *transform; QString baseName; +public: + LoadStoreTransform() : transform(NULL) {} + +private: void init() { if (transform != NULL) return; @@ -354,7 +357,7 @@ class LoadStoreTransform : public MetaTransform QDataStream stream(&byteArray, QFile::WriteOnly); stream << description; transform->store(stream); - QtUtils::writeFile(getFileName(), byteArray); + QtUtils::writeFile(baseName, byteArray); } void project(const Template &src, Template &dst) const @@ -369,6 +372,7 @@ class LoadStoreTransform : public MetaTransform QString getFileName() const { + if (QFileInfo(baseName).exists()) return baseName; const QString file = Globals->sdkPath + "/share/openbr/models/transforms/" + baseName; return QFileInfo(file).exists() ? file : QString(); } diff --git a/sdk/plugins/misc.cpp b/sdk/plugins/misc.cpp index c480680..bfef00f 100644 --- a/sdk/plugins/misc.cpp +++ b/sdk/plugins/misc.cpp @@ -34,18 +34,15 @@ class OpenTransform : public UntrainableMetaTransform void project(const Template &src, Template &dst) const { if (Globals->verbose) qDebug("Opening %s", qPrintable(src.file.flat())); - bool fto = false; + dst.file = src.file; foreach (const File &file, src.file.split()) { QScopedPointer format(Factory::make(file)); - QList mats = format->read(); - if (mats.isEmpty()) { - qWarning("Can't open %s", qPrintable(file.flat())); - fto = true; - } - dst += mats; + Template t = format->read(); + if (t.isEmpty()) qWarning("Can't open %s", qPrintable(file.flat())); + dst.append(t); + dst.file.append(t.file.localMetadata()); } - dst.file = src.file; - dst.file.insert("FTO", fto); + dst.file.insert("FTO", dst.isEmpty()); } }; @@ -93,6 +90,27 @@ BR_REGISTER(Transform, ShowTransform) /*! * \ingroup transforms + * \brief Prints the template's file to stdout or stderr. + * \author Josh Klontz \cite jklontz + */ +class PrintTransform : public UntrainableMetaTransform +{ + Q_OBJECT + Q_PROPERTY(bool error READ get_error WRITE set_error RESET reset_error) + BR_PROPERTY(bool, error, false) + + void project(const Template &src, Template &dst) const + { + dst = src; + if (error) qDebug("%s\n", qPrintable(src.file.flat())); + else printf("%s\n", qPrintable(src.file.flat())); + } +}; + +BR_REGISTER(Transform, PrintTransform) + +/*! + * \ingroup transforms * \brief Sets the template's matrix data to the br::File::name. * \author Josh Klontz \cite jklontz */ diff --git a/sdk/plugins/regions.cpp b/sdk/plugins/regions.cpp index 9d9086e..561fc74 100644 --- a/sdk/plugins/regions.cpp +++ b/sdk/plugins/regions.cpp @@ -79,23 +79,29 @@ BR_REGISTER(Transform, ByRow) class Cat : public UntrainableMetaTransform { Q_OBJECT + Q_PROPERTY(int partitions READ get_partitions WRITE set_partitions RESET reset_partitions) + BR_PROPERTY(int, partitions, 1) void project(const Template &src, Template &dst) const { - int vals = 0; - foreach (const cv::Mat &m, src) - vals += m.total() * m.channels(); - - Mat cat(1, (int)vals, CV_32FC1); - int offset = 0; - foreach (const cv::Mat &m, src) { - size_t size = m.total() * m.elemSize(); - memcpy(&cat.data[offset], m.ptr(), size); - offset += size; - } - dst.file = src.file; - dst = cat; + + if (src.size() % partitions != 0) + qFatal("Cat %d partitions does not evenly divide %d matrices.", partitions, src.size()); + QVector sizes(partitions, 0); + for (int i=0; i offsets(partitions, 0); + for (int i=0; i