Commit 32f3dd6379dd4875791994d63c44bed96097cbb0

Authored by JordanCheney
2 parents f17baf81 b5c16f94

Merge pull request #389 from biometrics/caffe

Caffe
openbr/plugins/classification/caffe.cpp 0 → 100644
  1 +#include <openbr/plugins/openbr_internal.h>
  2 +#include <openbr/core/opencvutils.h>
  3 +#include <openbr/core/qtutils.h>
  4 +
  5 +#include <opencv2/imgproc/imgproc.hpp>
  6 +#include <caffe/caffe.hpp>
  7 +
  8 +using caffe::Caffe;
  9 +using caffe::Net;
  10 +using caffe::MemoryDataLayer;
  11 +using caffe::Blob;
  12 +using caffe::shared_ptr;
  13 +
  14 +using namespace cv;
  15 +
  16 +namespace br
  17 +{
  18 +
  19 +// Net doesn't expose a default constructor which is expected by the default resource allocator.
  20 +// To get around that we make this custom stub class which has a default constructor that passes
  21 +// empty values to the Net constructor.
  22 +class CaffeNet : public Net<float>
  23 +{
  24 +public:
  25 + CaffeNet() : Net<float>("", caffe::TEST) {}
  26 + CaffeNet(const QString &model, caffe::Phase phase) : Net<float>(model.toStdString(), phase) {}
  27 +};
  28 +
  29 +class CaffeResourceMaker : public ResourceMaker<CaffeNet>
  30 +{
  31 + QString model;
  32 + QString weights;
  33 + int gpuDevice;
  34 +
  35 +public:
  36 + CaffeResourceMaker(const QString &model, const QString &weights, int gpuDevice) : model(model), weights(weights), gpuDevice(gpuDevice) {}
  37 +
  38 +private:
  39 + CaffeNet *make() const
  40 + {
  41 + if (gpuDevice >= 0) {
  42 + Caffe::SetDevice(gpuDevice);
  43 + Caffe::set_mode(Caffe::GPU);
  44 + } else {
  45 + Caffe::set_mode(Caffe::CPU);
  46 + }
  47 +
  48 + CaffeNet *net = new CaffeNet(model, caffe::TEST);
  49 + net->CopyTrainedLayersFrom(weights.toStdString());
  50 + return net;
  51 + }
  52 +};
  53 +
  54 +/*!
  55 + * \brief A transform that wraps the Caffe deep learning library. This transform expects the input to a given Caffe model to be a MemoryDataLayer.
  56 + * The output of the Caffe network is treated as a feature vector and is stored in dst. Batch processing is possible. For a given batch size set in
  57 + * the memory data layer, src is expected to have an equal number of mats. Dst will always have the same size (number of mats) as src and the ordering
  58 + * will be preserved, so dst[1] is the output of src[1] after it passes through the neural net.
  59 + * \author Jordan Cheney \cite jcheney
  60 + * \br_property QString model path to prototxt model file
  61 + * \br_property QString weights path to caffemodel file
  62 + * \br_property int gpuDevice ID of GPU to use. gpuDevice < 0 runs on the CPU only.
  63 + * \br_link Caffe Integration Tutorial ../tutorials.md#caffe
  64 + * \br_link Caffe website http://caffe.berkeleyvision.org
  65 + */
  66 +class CaffeFVTransform : public UntrainableTransform
  67 +{
  68 + Q_OBJECT
  69 +
  70 + Q_PROPERTY(QString model READ get_model WRITE set_model RESET reset_model STORED false)
  71 + Q_PROPERTY(QString weights READ get_weights WRITE set_weights RESET reset_weights STORED false)
  72 + Q_PROPERTY(int gpuDevice READ get_gpuDevice WRITE set_gpuDevice RESET reset_gpuDevice STORED false)
  73 + BR_PROPERTY(QString, model, "")
  74 + BR_PROPERTY(QString, weights, "")
  75 + BR_PROPERTY(int, gpuDevice, -1)
  76 +
  77 + Resource<CaffeNet> caffeResource;
  78 +
  79 + void init()
  80 + {
  81 + caffeResource.setResourceMaker(new CaffeResourceMaker(model, weights, gpuDevice));
  82 + }
  83 +
  84 + bool timeVarying() const
  85 + {
  86 + return gpuDevice < 0 ? false : true;
  87 + }
  88 +
  89 + void project(const Template &src, Template &dst) const
  90 + {
  91 + CaffeNet *net = caffeResource.acquire();
  92 +
  93 + MemoryDataLayer<float> *data_layer = static_cast<MemoryDataLayer<float> *>(net->layers()[0].get());
  94 +
  95 + if (src.size() != data_layer->batch_size())
  96 + qFatal("src should have %d (batch size) mats. It has %d mats.", data_layer->batch_size(), src.size());
  97 +
  98 + dst.file = src.file;
  99 +
  100 + data_layer->AddMatVector(src.toVector().toStdVector(), std::vector<int>(src.size(), 0));
  101 +
  102 + Blob<float> *output = net->ForwardPrefilled()[1]; // index 0 is the labels from the data layer (in this case the 0 array we passed in above).
  103 + // index 1 is the ouput of the final layer, which is what we want
  104 + int dim_features = output->count() / data_layer->batch_size();
  105 + for (int n = 0; n < data_layer->batch_size(); n++)
  106 + dst += Mat(1, dim_features, CV_32FC1, output->mutable_cpu_data() + output->offset(n));
  107 +
  108 + caffeResource.release(net);
  109 + }
  110 +};
  111 +
  112 +BR_REGISTER(Transform, CaffeFVTransform)
  113 +
  114 +} // namespace br
  115 +
  116 +#include "classification/caffe.moc"
... ...
openbr/plugins/cmake/caffe.cmake 0 → 100644
  1 +set(BR_WITH_CAFFE OFF CACHE BOOL "Build with Caffe")
  2 +
  3 +if(${BR_WITH_CAFFE})
  4 + find_package(Caffe)
  5 + include_directories(${Caffe_INCLUDE_DIRS})
  6 + add_definitions(${Caffe_DEFINITIONS})
  7 + set(BR_THIRDPARTY_LIBS ${BR_THIRDPARTY_LIBS} ${Caffe_LIBRARIES})
  8 +else()
  9 + set(BR_EXCLUDED_PLUGINS ${BR_EXCLUDED_PLUGINS} plugins/classification/caffe.cpp)
  10 + set(BR_EXCLUDED_PLUGINS ${BR_EXCLUDED_PLUGINS} plugins/gallery/lmdb.cpp)
  11 +endif()
... ...
openbr/plugins/gallery/lmdb.cpp 0 → 100644
  1 +#include <openbr/openbr_plugin.h>
  2 +#include "openbr/plugins/openbr_internal.h"
  3 +#include <openbr/core/qtutils.h>
  4 +
  5 +#include <QFutureSynchronizer>
  6 +#include <QtConcurrentRun>
  7 +#include <QMutexLocker>
  8 +#include <QWaitCondition>
  9 +
  10 +#include "caffe/util/db.hpp"
  11 +#include "caffe/util/io.hpp"
  12 +
  13 +using namespace br;
  14 +
  15 +class lmdbGallery : public Gallery
  16 +{
  17 + Q_OBJECT
  18 +
  19 + TemplateList readBlock(bool *done)
  20 + {
  21 + *done = false;
  22 + if (!initialized) {
  23 + db = QSharedPointer<caffe::db::DB>(caffe::db::GetDB("lmdb"));
  24 + db->Open(file.name.toStdString(),caffe::db::READ);
  25 + cursor = QSharedPointer<caffe::db::Cursor>(db->NewCursor());
  26 + initialized = true;
  27 + }
  28 +
  29 + caffe::Datum datum;
  30 + datum.ParseFromString(cursor->value());
  31 +
  32 + cv::Mat img;
  33 + if (datum.encoded()) {
  34 + img = caffe::DecodeDatumToCVMatNative(datum);
  35 + }
  36 + else {
  37 + // create output image of appropriate size
  38 + img.create(datum.height(), datum.width(), CV_MAKETYPE(CV_8U, datum.channels()));
  39 + // copy matrix data from datum.
  40 + for (int h = 0; h < datum.height(); ++h) {
  41 + uchar* ptr = img.ptr<uchar>(h);
  42 + int img_index = 0;
  43 + for (int w = 0; w < datum.width(); ++w) {
  44 + for (int c = 0; c < datum.channels(); ++c) {
  45 + int datum_index = (c * datum.height() + h) * datum.width() + w;
  46 + ptr[img_index++] = (unsigned char)datum.data()[datum_index];
  47 + }
  48 + }
  49 + }
  50 + }
  51 +
  52 + // We acquired the image data, now decode filename from db key
  53 + QString baseKey = cursor->key().c_str();
  54 +
  55 + int idx = baseKey.indexOf("_");
  56 + if (idx != -1)
  57 + baseKey = baseKey.right(baseKey.size() - idx - 1);
  58 +
  59 + TemplateList output;
  60 + output.append(Template(img));
  61 + output.last().file.name = baseKey;
  62 + output.last().file.set("Label", datum.label());
  63 +
  64 + cursor->Next();
  65 +
  66 + if (!cursor->valid())
  67 + *done = true;
  68 +
  69 + return output;
  70 + }
  71 +
  72 + bool initialized;
  73 + QSharedPointer<caffe::db::DB> db;
  74 + QSharedPointer<caffe::db::Cursor> cursor;
  75 +
  76 + QFutureSynchronizer<void> aThread;
  77 + QMutex dataLock;
  78 + QWaitCondition dataWait;
  79 +
  80 + bool should_end;
  81 + TemplateList data;
  82 +
  83 + QHash<QString, int> observedLabels;
  84 +
  85 + static void commitLoop(lmdbGallery * base)
  86 + {
  87 + QSharedPointer<caffe::db::Transaction> txn;
  88 +
  89 + int total_count = 0;
  90 +
  91 + // Acquire the lock
  92 + QMutexLocker lock(&base->dataLock);
  93 +
  94 + while (true) {
  95 + // wait for data, or end signal
  96 + while(base->data.isEmpty() && !base->should_end)
  97 + base->dataWait.wait(&base->dataLock);
  98 +
  99 + // If should_end, but there is still data, we need another commit
  100 + // round
  101 + if (base->should_end && base->data.isEmpty())
  102 + break;
  103 +
  104 + txn = QSharedPointer<caffe::db::Transaction>(base->db->NewTransaction());
  105 +
  106 + TemplateList working = base->data;
  107 + base->data.clear();
  108 +
  109 + // no longer blocking dataLock
  110 + lock.unlock();
  111 +
  112 + foreach(const Template &t, working) {
  113 + // add current image to transaction
  114 + caffe::Datum datum;
  115 + caffe::CVMatToDatum(t.m(), &datum);
  116 +
  117 + QVariant base_label = t.file.value("Label");
  118 + QString label_str = base_label.toString();
  119 +
  120 +
  121 + if (!base->observedLabels.contains(label_str) )
  122 + base->observedLabels[label_str] = base->observedLabels.size();
  123 +
  124 + datum.set_label(base->observedLabels[label_str]);
  125 +
  126 + std::string out;
  127 + datum.SerializeToString(&out);
  128 +
  129 + char key_cstr[256];
  130 + int len = snprintf(key_cstr, 256, "%08d_%s", total_count, qPrintable(t.file.name));
  131 + txn->Put(std::string(key_cstr, len), out);
  132 +
  133 + total_count++;
  134 + }
  135 +
  136 + txn->Commit();
  137 + lock.relock();
  138 + }
  139 + }
  140 +
  141 + void write(const Template &t)
  142 + {
  143 + if (!initialized) {
  144 + db = QSharedPointer<caffe::db::DB> (caffe::db::GetDB("lmdb"));
  145 + db->Open(file.name.toStdString(), caffe::db::NEW);
  146 + observedLabels.clear();
  147 + initialized = true;
  148 + should_end = false;
  149 + // fire thread
  150 + aThread.clearFutures();
  151 + aThread.addFuture(QtConcurrent::run(lmdbGallery::commitLoop, this));
  152 + }
  153 +
  154 + QMutexLocker lock(&dataLock);
  155 + data.append(t);
  156 + dataWait.wakeAll();
  157 + }
  158 +
  159 + ~lmdbGallery()
  160 + {
  161 + if (initialized) {
  162 + QMutexLocker lock(&dataLock);
  163 + should_end = true;
  164 + dataWait.wakeAll();
  165 + lock.unlock();
  166 +
  167 + aThread.waitForFinished();
  168 + }
  169 + }
  170 +
  171 +
  172 + void init()
  173 + {
  174 + initialized = false;
  175 + should_end = false;
  176 + }
  177 +};
  178 +
  179 +BR_REGISTER(Gallery, lmdbGallery)
  180 +
  181 +
  182 +#include "gallery/lmdb.moc"
  183 +
... ...
openbr/plugins/imgproc/pad.cpp 0 → 100644
  1 +#include <openbr/plugins/openbr_internal.h>
  2 +
  3 +using namespace cv;
  4 +
  5 +namespace br
  6 +{
  7 +
  8 +class PadTransform : public UntrainableTransform
  9 +{
  10 + Q_OBJECT
  11 +
  12 + Q_PROPERTY(int padSize READ get_padSize WRITE set_padSize RESET reset_padSize STORED false)
  13 + Q_PROPERTY(int padValue READ get_padValue WRITE set_padValue RESET reset_padValue STORED false)
  14 + BR_PROPERTY(int, padSize, 0)
  15 + BR_PROPERTY(int, padValue, 0)
  16 +
  17 + void project(const Template &src, Template &dst) const
  18 + {
  19 + dst.file = src.file;
  20 +
  21 + foreach (const Mat &m, src) {
  22 + Mat padded = padValue * Mat::ones(m.rows + 2*padSize, m.cols + 2*padSize, m.type());
  23 + padded(Rect(padSize, padSize, padded.cols - padSize, padded.rows - padSize)) = m;
  24 + dst += padded;
  25 + }
  26 + }
  27 +};
  28 +
  29 +BR_REGISTER(Transform, PadTransform)
  30 +
  31 +} // namespace br
  32 +
  33 +#include "imgproc/pad.moc"
... ...
openbr/plugins/imgproc/roi.cpp
... ... @@ -26,12 +26,16 @@ namespace br
26 26 * \ingroup transforms
27 27 * \brief Crops the rectangular regions of interest.
28 28 * \author Josh Klontz \cite jklontz
  29 + * \br_property QString propName Optional property name for a rectangle in metadata. If no propName is given the transform will use rects stored in the file.rects field or build a rectangle using "X", "Y", "Width", and "Height" fields if they exist.
  30 + * \br_property bool copyOnCrop If true make a clone of each crop before appending the crop to dst. This guarantees that the crops will be continuous in memory, which is an occasionally useful property. Default is false.
29 31 */
30 32 class ROITransform : public UntrainableTransform
31 33 {
32 34 Q_OBJECT
33 35 Q_PROPERTY(QString propName READ get_propName WRITE set_propName RESET reset_propName STORED false)
  36 + Q_PROPERTY(bool copyOnCrop READ get_copyOnCrop WRITE set_copyOnCrop RESET reset_copyOnCrop STORED false)
34 37 BR_PROPERTY(QString, propName, "")
  38 + BR_PROPERTY(bool, copyOnCrop, false)
35 39  
36 40 void project(const Template &src, Template &dst) const
37 41 {
... ... @@ -52,6 +56,10 @@ class ROITransform : public UntrainableTransform
52 56 qWarning("No rects present in file.");
53 57 }
54 58 dst.file.clearRects();
  59 +
  60 + if (copyOnCrop)
  61 + for (int i = 0; i < dst.size(); i++)
  62 + dst.replace(i, dst[i].clone());
55 63 }
56 64 };
57 65  
... ...