Commit 32f3dd6379dd4875791994d63c44bed96097cbb0

Authored by JordanCheney
2 parents f17baf81 b5c16f94

Merge pull request #389 from biometrics/caffe

Caffe
openbr/plugins/classification/caffe.cpp 0 → 100644
  1 +#include <openbr/plugins/openbr_internal.h>
  2 +#include <openbr/core/opencvutils.h>
  3 +#include <openbr/core/qtutils.h>
  4 +
  5 +#include <opencv2/imgproc/imgproc.hpp>
  6 +#include <caffe/caffe.hpp>
  7 +
  8 +using caffe::Caffe;
  9 +using caffe::Net;
  10 +using caffe::MemoryDataLayer;
  11 +using caffe::Blob;
  12 +using caffe::shared_ptr;
  13 +
  14 +using namespace cv;
  15 +
  16 +namespace br
  17 +{
  18 +
  19 +// Net doesn't expose a default constructor which is expected by the default resource allocator.
  20 +// To get around that we make this custom stub class which has a default constructor that passes
  21 +// empty values to the Net constructor.
  22 +class CaffeNet : public Net<float>
  23 +{
  24 +public:
  25 + CaffeNet() : Net<float>("", caffe::TEST) {}
  26 + CaffeNet(const QString &model, caffe::Phase phase) : Net<float>(model.toStdString(), phase) {}
  27 +};
  28 +
  29 +class CaffeResourceMaker : public ResourceMaker<CaffeNet>
  30 +{
  31 + QString model;
  32 + QString weights;
  33 + int gpuDevice;
  34 +
  35 +public:
  36 + CaffeResourceMaker(const QString &model, const QString &weights, int gpuDevice) : model(model), weights(weights), gpuDevice(gpuDevice) {}
  37 +
  38 +private:
  39 + CaffeNet *make() const
  40 + {
  41 + if (gpuDevice >= 0) {
  42 + Caffe::SetDevice(gpuDevice);
  43 + Caffe::set_mode(Caffe::GPU);
  44 + } else {
  45 + Caffe::set_mode(Caffe::CPU);
  46 + }
  47 +
  48 + CaffeNet *net = new CaffeNet(model, caffe::TEST);
  49 + net->CopyTrainedLayersFrom(weights.toStdString());
  50 + return net;
  51 + }
  52 +};
  53 +
  54 +/*!
  55 + * \brief A transform that wraps the Caffe deep learning library. This transform expects the input to a given Caffe model to be a MemoryDataLayer.
  56 + * The output of the Caffe network is treated as a feature vector and is stored in dst. Batch processing is possible. For a given batch size set in
  57 + * the memory data layer, src is expected to have an equal number of mats. Dst will always have the same size (number of mats) as src and the ordering
  58 + * will be preserved, so dst[1] is the output of src[1] after it passes through the neural net.
  59 + * \author Jordan Cheney \cite jcheney
  60 + * \br_property QString model path to prototxt model file
  61 + * \br_property QString weights path to caffemodel file
  62 + * \br_property int gpuDevice ID of GPU to use. gpuDevice < 0 runs on the CPU only.
  63 + * \br_link Caffe Integration Tutorial ../tutorials.md#caffe
  64 + * \br_link Caffe website http://caffe.berkeleyvision.org
  65 + */
  66 +class CaffeFVTransform : public UntrainableTransform
  67 +{
  68 + Q_OBJECT
  69 +
  70 + Q_PROPERTY(QString model READ get_model WRITE set_model RESET reset_model STORED false)
  71 + Q_PROPERTY(QString weights READ get_weights WRITE set_weights RESET reset_weights STORED false)
  72 + Q_PROPERTY(int gpuDevice READ get_gpuDevice WRITE set_gpuDevice RESET reset_gpuDevice STORED false)
  73 + BR_PROPERTY(QString, model, "")
  74 + BR_PROPERTY(QString, weights, "")
  75 + BR_PROPERTY(int, gpuDevice, -1)
  76 +
  77 + Resource<CaffeNet> caffeResource;
  78 +
  79 + void init()
  80 + {
  81 + caffeResource.setResourceMaker(new CaffeResourceMaker(model, weights, gpuDevice));
  82 + }
  83 +
  84 + bool timeVarying() const
  85 + {
  86 + return gpuDevice < 0 ? false : true;
  87 + }
  88 +
  89 + void project(const Template &src, Template &dst) const
  90 + {
  91 + CaffeNet *net = caffeResource.acquire();
  92 +
  93 + MemoryDataLayer<float> *data_layer = static_cast<MemoryDataLayer<float> *>(net->layers()[0].get());
  94 +
  95 + if (src.size() != data_layer->batch_size())
  96 + qFatal("src should have %d (batch size) mats. It has %d mats.", data_layer->batch_size(), src.size());
  97 +
  98 + dst.file = src.file;
  99 +
  100 + data_layer->AddMatVector(src.toVector().toStdVector(), std::vector<int>(src.size(), 0));
  101 +
  102 + Blob<float> *output = net->ForwardPrefilled()[1]; // index 0 is the labels from the data layer (in this case the 0 array we passed in above).
  103 + // index 1 is the ouput of the final layer, which is what we want
  104 + int dim_features = output->count() / data_layer->batch_size();
  105 + for (int n = 0; n < data_layer->batch_size(); n++)
  106 + dst += Mat(1, dim_features, CV_32FC1, output->mutable_cpu_data() + output->offset(n));
  107 +
  108 + caffeResource.release(net);
  109 + }
  110 +};
  111 +
  112 +BR_REGISTER(Transform, CaffeFVTransform)
  113 +
  114 +} // namespace br
  115 +
  116 +#include "classification/caffe.moc"
openbr/plugins/cmake/caffe.cmake 0 → 100644
  1 +set(BR_WITH_CAFFE OFF CACHE BOOL "Build with Caffe")
  2 +
  3 +if(${BR_WITH_CAFFE})
  4 + find_package(Caffe)
  5 + include_directories(${Caffe_INCLUDE_DIRS})
  6 + add_definitions(${Caffe_DEFINITIONS})
  7 + set(BR_THIRDPARTY_LIBS ${BR_THIRDPARTY_LIBS} ${Caffe_LIBRARIES})
  8 +else()
  9 + set(BR_EXCLUDED_PLUGINS ${BR_EXCLUDED_PLUGINS} plugins/classification/caffe.cpp)
  10 + set(BR_EXCLUDED_PLUGINS ${BR_EXCLUDED_PLUGINS} plugins/gallery/lmdb.cpp)
  11 +endif()
openbr/plugins/gallery/lmdb.cpp 0 → 100644
  1 +#include <openbr/openbr_plugin.h>
  2 +#include "openbr/plugins/openbr_internal.h"
  3 +#include <openbr/core/qtutils.h>
  4 +
  5 +#include <QFutureSynchronizer>
  6 +#include <QtConcurrentRun>
  7 +#include <QMutexLocker>
  8 +#include <QWaitCondition>
  9 +
  10 +#include "caffe/util/db.hpp"
  11 +#include "caffe/util/io.hpp"
  12 +
  13 +using namespace br;
  14 +
  15 +class lmdbGallery : public Gallery
  16 +{
  17 + Q_OBJECT
  18 +
  19 + TemplateList readBlock(bool *done)
  20 + {
  21 + *done = false;
  22 + if (!initialized) {
  23 + db = QSharedPointer<caffe::db::DB>(caffe::db::GetDB("lmdb"));
  24 + db->Open(file.name.toStdString(),caffe::db::READ);
  25 + cursor = QSharedPointer<caffe::db::Cursor>(db->NewCursor());
  26 + initialized = true;
  27 + }
  28 +
  29 + caffe::Datum datum;
  30 + datum.ParseFromString(cursor->value());
  31 +
  32 + cv::Mat img;
  33 + if (datum.encoded()) {
  34 + img = caffe::DecodeDatumToCVMatNative(datum);
  35 + }
  36 + else {
  37 + // create output image of appropriate size
  38 + img.create(datum.height(), datum.width(), CV_MAKETYPE(CV_8U, datum.channels()));
  39 + // copy matrix data from datum.
  40 + for (int h = 0; h < datum.height(); ++h) {
  41 + uchar* ptr = img.ptr<uchar>(h);
  42 + int img_index = 0;
  43 + for (int w = 0; w < datum.width(); ++w) {
  44 + for (int c = 0; c < datum.channels(); ++c) {
  45 + int datum_index = (c * datum.height() + h) * datum.width() + w;
  46 + ptr[img_index++] = (unsigned char)datum.data()[datum_index];
  47 + }
  48 + }
  49 + }
  50 + }
  51 +
  52 + // We acquired the image data, now decode filename from db key
  53 + QString baseKey = cursor->key().c_str();
  54 +
  55 + int idx = baseKey.indexOf("_");
  56 + if (idx != -1)
  57 + baseKey = baseKey.right(baseKey.size() - idx - 1);
  58 +
  59 + TemplateList output;
  60 + output.append(Template(img));
  61 + output.last().file.name = baseKey;
  62 + output.last().file.set("Label", datum.label());
  63 +
  64 + cursor->Next();
  65 +
  66 + if (!cursor->valid())
  67 + *done = true;
  68 +
  69 + return output;
  70 + }
  71 +
  72 + bool initialized;
  73 + QSharedPointer<caffe::db::DB> db;
  74 + QSharedPointer<caffe::db::Cursor> cursor;
  75 +
  76 + QFutureSynchronizer<void> aThread;
  77 + QMutex dataLock;
  78 + QWaitCondition dataWait;
  79 +
  80 + bool should_end;
  81 + TemplateList data;
  82 +
  83 + QHash<QString, int> observedLabels;
  84 +
  85 + static void commitLoop(lmdbGallery * base)
  86 + {
  87 + QSharedPointer<caffe::db::Transaction> txn;
  88 +
  89 + int total_count = 0;
  90 +
  91 + // Acquire the lock
  92 + QMutexLocker lock(&base->dataLock);
  93 +
  94 + while (true) {
  95 + // wait for data, or end signal
  96 + while(base->data.isEmpty() && !base->should_end)
  97 + base->dataWait.wait(&base->dataLock);
  98 +
  99 + // If should_end, but there is still data, we need another commit
  100 + // round
  101 + if (base->should_end && base->data.isEmpty())
  102 + break;
  103 +
  104 + txn = QSharedPointer<caffe::db::Transaction>(base->db->NewTransaction());
  105 +
  106 + TemplateList working = base->data;
  107 + base->data.clear();
  108 +
  109 + // no longer blocking dataLock
  110 + lock.unlock();
  111 +
  112 + foreach(const Template &t, working) {
  113 + // add current image to transaction
  114 + caffe::Datum datum;
  115 + caffe::CVMatToDatum(t.m(), &datum);
  116 +
  117 + QVariant base_label = t.file.value("Label");
  118 + QString label_str = base_label.toString();
  119 +
  120 +
  121 + if (!base->observedLabels.contains(label_str) )
  122 + base->observedLabels[label_str] = base->observedLabels.size();
  123 +
  124 + datum.set_label(base->observedLabels[label_str]);
  125 +
  126 + std::string out;
  127 + datum.SerializeToString(&out);
  128 +
  129 + char key_cstr[256];
  130 + int len = snprintf(key_cstr, 256, "%08d_%s", total_count, qPrintable(t.file.name));
  131 + txn->Put(std::string(key_cstr, len), out);
  132 +
  133 + total_count++;
  134 + }
  135 +
  136 + txn->Commit();
  137 + lock.relock();
  138 + }
  139 + }
  140 +
  141 + void write(const Template &t)
  142 + {
  143 + if (!initialized) {
  144 + db = QSharedPointer<caffe::db::DB> (caffe::db::GetDB("lmdb"));
  145 + db->Open(file.name.toStdString(), caffe::db::NEW);
  146 + observedLabels.clear();
  147 + initialized = true;
  148 + should_end = false;
  149 + // fire thread
  150 + aThread.clearFutures();
  151 + aThread.addFuture(QtConcurrent::run(lmdbGallery::commitLoop, this));
  152 + }
  153 +
  154 + QMutexLocker lock(&dataLock);
  155 + data.append(t);
  156 + dataWait.wakeAll();
  157 + }
  158 +
  159 + ~lmdbGallery()
  160 + {
  161 + if (initialized) {
  162 + QMutexLocker lock(&dataLock);
  163 + should_end = true;
  164 + dataWait.wakeAll();
  165 + lock.unlock();
  166 +
  167 + aThread.waitForFinished();
  168 + }
  169 + }
  170 +
  171 +
  172 + void init()
  173 + {
  174 + initialized = false;
  175 + should_end = false;
  176 + }
  177 +};
  178 +
  179 +BR_REGISTER(Gallery, lmdbGallery)
  180 +
  181 +
  182 +#include "gallery/lmdb.moc"
  183 +
openbr/plugins/imgproc/pad.cpp 0 → 100644
  1 +#include <openbr/plugins/openbr_internal.h>
  2 +
  3 +using namespace cv;
  4 +
  5 +namespace br
  6 +{
  7 +
  8 +class PadTransform : public UntrainableTransform
  9 +{
  10 + Q_OBJECT
  11 +
  12 + Q_PROPERTY(int padSize READ get_padSize WRITE set_padSize RESET reset_padSize STORED false)
  13 + Q_PROPERTY(int padValue READ get_padValue WRITE set_padValue RESET reset_padValue STORED false)
  14 + BR_PROPERTY(int, padSize, 0)
  15 + BR_PROPERTY(int, padValue, 0)
  16 +
  17 + void project(const Template &src, Template &dst) const
  18 + {
  19 + dst.file = src.file;
  20 +
  21 + foreach (const Mat &m, src) {
  22 + Mat padded = padValue * Mat::ones(m.rows + 2*padSize, m.cols + 2*padSize, m.type());
  23 + padded(Rect(padSize, padSize, padded.cols - padSize, padded.rows - padSize)) = m;
  24 + dst += padded;
  25 + }
  26 + }
  27 +};
  28 +
  29 +BR_REGISTER(Transform, PadTransform)
  30 +
  31 +} // namespace br
  32 +
  33 +#include "imgproc/pad.moc"
openbr/plugins/imgproc/roi.cpp
@@ -26,12 +26,16 @@ namespace br @@ -26,12 +26,16 @@ namespace br
26 * \ingroup transforms 26 * \ingroup transforms
27 * \brief Crops the rectangular regions of interest. 27 * \brief Crops the rectangular regions of interest.
28 * \author Josh Klontz \cite jklontz 28 * \author Josh Klontz \cite jklontz
  29 + * \br_property QString propName Optional property name for a rectangle in metadata. If no propName is given the transform will use rects stored in the file.rects field or build a rectangle using "X", "Y", "Width", and "Height" fields if they exist.
  30 + * \br_property bool copyOnCrop If true make a clone of each crop before appending the crop to dst. This guarantees that the crops will be continuous in memory, which is an occasionally useful property. Default is false.
29 */ 31 */
30 class ROITransform : public UntrainableTransform 32 class ROITransform : public UntrainableTransform
31 { 33 {
32 Q_OBJECT 34 Q_OBJECT
33 Q_PROPERTY(QString propName READ get_propName WRITE set_propName RESET reset_propName STORED false) 35 Q_PROPERTY(QString propName READ get_propName WRITE set_propName RESET reset_propName STORED false)
  36 + Q_PROPERTY(bool copyOnCrop READ get_copyOnCrop WRITE set_copyOnCrop RESET reset_copyOnCrop STORED false)
34 BR_PROPERTY(QString, propName, "") 37 BR_PROPERTY(QString, propName, "")
  38 + BR_PROPERTY(bool, copyOnCrop, false)
35 39
36 void project(const Template &src, Template &dst) const 40 void project(const Template &src, Template &dst) const
37 { 41 {
@@ -52,6 +56,10 @@ class ROITransform : public UntrainableTransform @@ -52,6 +56,10 @@ class ROITransform : public UntrainableTransform
52 qWarning("No rects present in file."); 56 qWarning("No rects present in file.");
53 } 57 }
54 dst.file.clearRects(); 58 dst.file.clearRects();
  59 +
  60 + if (copyOnCrop)
  61 + for (int i = 0; i < dst.size(); i++)
  62 + dst.replace(i, dst[i].clone());
55 } 63 }
56 }; 64 };
57 65