Commit 32f3dd6379dd4875791994d63c44bed96097cbb0
Merge pull request #389 from biometrics/caffe
Caffe
Showing
5 changed files
with
351 additions
and
0 deletions
openbr/plugins/classification/caffe.cpp
0 → 100644
| 1 | +#include <openbr/plugins/openbr_internal.h> | ||
| 2 | +#include <openbr/core/opencvutils.h> | ||
| 3 | +#include <openbr/core/qtutils.h> | ||
| 4 | + | ||
| 5 | +#include <opencv2/imgproc/imgproc.hpp> | ||
| 6 | +#include <caffe/caffe.hpp> | ||
| 7 | + | ||
| 8 | +using caffe::Caffe; | ||
| 9 | +using caffe::Net; | ||
| 10 | +using caffe::MemoryDataLayer; | ||
| 11 | +using caffe::Blob; | ||
| 12 | +using caffe::shared_ptr; | ||
| 13 | + | ||
| 14 | +using namespace cv; | ||
| 15 | + | ||
| 16 | +namespace br | ||
| 17 | +{ | ||
| 18 | + | ||
| 19 | +// Net doesn't expose a default constructor which is expected by the default resource allocator. | ||
| 20 | +// To get around that we make this custom stub class which has a default constructor that passes | ||
| 21 | +// empty values to the Net constructor. | ||
| 22 | +class CaffeNet : public Net<float> | ||
| 23 | +{ | ||
| 24 | +public: | ||
| 25 | + CaffeNet() : Net<float>("", caffe::TEST) {} | ||
| 26 | + CaffeNet(const QString &model, caffe::Phase phase) : Net<float>(model.toStdString(), phase) {} | ||
| 27 | +}; | ||
| 28 | + | ||
| 29 | +class CaffeResourceMaker : public ResourceMaker<CaffeNet> | ||
| 30 | +{ | ||
| 31 | + QString model; | ||
| 32 | + QString weights; | ||
| 33 | + int gpuDevice; | ||
| 34 | + | ||
| 35 | +public: | ||
| 36 | + CaffeResourceMaker(const QString &model, const QString &weights, int gpuDevice) : model(model), weights(weights), gpuDevice(gpuDevice) {} | ||
| 37 | + | ||
| 38 | +private: | ||
| 39 | + CaffeNet *make() const | ||
| 40 | + { | ||
| 41 | + if (gpuDevice >= 0) { | ||
| 42 | + Caffe::SetDevice(gpuDevice); | ||
| 43 | + Caffe::set_mode(Caffe::GPU); | ||
| 44 | + } else { | ||
| 45 | + Caffe::set_mode(Caffe::CPU); | ||
| 46 | + } | ||
| 47 | + | ||
| 48 | + CaffeNet *net = new CaffeNet(model, caffe::TEST); | ||
| 49 | + net->CopyTrainedLayersFrom(weights.toStdString()); | ||
| 50 | + return net; | ||
| 51 | + } | ||
| 52 | +}; | ||
| 53 | + | ||
| 54 | +/*! | ||
| 55 | + * \brief A transform that wraps the Caffe deep learning library. This transform expects the input to a given Caffe model to be a MemoryDataLayer. | ||
| 56 | + * The output of the Caffe network is treated as a feature vector and is stored in dst. Batch processing is possible. For a given batch size set in | ||
| 57 | + * the memory data layer, src is expected to have an equal number of mats. Dst will always have the same size (number of mats) as src and the ordering | ||
| 58 | + * will be preserved, so dst[1] is the output of src[1] after it passes through the neural net. | ||
| 59 | + * \author Jordan Cheney \cite jcheney | ||
| 60 | + * \br_property QString model path to prototxt model file | ||
| 61 | + * \br_property QString weights path to caffemodel file | ||
| 62 | + * \br_property int gpuDevice ID of GPU to use. gpuDevice < 0 runs on the CPU only. | ||
| 63 | + * \br_link Caffe Integration Tutorial ../tutorials.md#caffe | ||
| 64 | + * \br_link Caffe website http://caffe.berkeleyvision.org | ||
| 65 | + */ | ||
| 66 | +class CaffeFVTransform : public UntrainableTransform | ||
| 67 | +{ | ||
| 68 | + Q_OBJECT | ||
| 69 | + | ||
| 70 | + Q_PROPERTY(QString model READ get_model WRITE set_model RESET reset_model STORED false) | ||
| 71 | + Q_PROPERTY(QString weights READ get_weights WRITE set_weights RESET reset_weights STORED false) | ||
| 72 | + Q_PROPERTY(int gpuDevice READ get_gpuDevice WRITE set_gpuDevice RESET reset_gpuDevice STORED false) | ||
| 73 | + BR_PROPERTY(QString, model, "") | ||
| 74 | + BR_PROPERTY(QString, weights, "") | ||
| 75 | + BR_PROPERTY(int, gpuDevice, -1) | ||
| 76 | + | ||
| 77 | + Resource<CaffeNet> caffeResource; | ||
| 78 | + | ||
| 79 | + void init() | ||
| 80 | + { | ||
| 81 | + caffeResource.setResourceMaker(new CaffeResourceMaker(model, weights, gpuDevice)); | ||
| 82 | + } | ||
| 83 | + | ||
| 84 | + bool timeVarying() const | ||
| 85 | + { | ||
| 86 | + return gpuDevice < 0 ? false : true; | ||
| 87 | + } | ||
| 88 | + | ||
| 89 | + void project(const Template &src, Template &dst) const | ||
| 90 | + { | ||
| 91 | + CaffeNet *net = caffeResource.acquire(); | ||
| 92 | + | ||
| 93 | + MemoryDataLayer<float> *data_layer = static_cast<MemoryDataLayer<float> *>(net->layers()[0].get()); | ||
| 94 | + | ||
| 95 | + if (src.size() != data_layer->batch_size()) | ||
| 96 | + qFatal("src should have %d (batch size) mats. It has %d mats.", data_layer->batch_size(), src.size()); | ||
| 97 | + | ||
| 98 | + dst.file = src.file; | ||
| 99 | + | ||
| 100 | + data_layer->AddMatVector(src.toVector().toStdVector(), std::vector<int>(src.size(), 0)); | ||
| 101 | + | ||
| 102 | + Blob<float> *output = net->ForwardPrefilled()[1]; // index 0 is the labels from the data layer (in this case the 0 array we passed in above). | ||
| 103 | + // index 1 is the ouput of the final layer, which is what we want | ||
| 104 | + int dim_features = output->count() / data_layer->batch_size(); | ||
| 105 | + for (int n = 0; n < data_layer->batch_size(); n++) | ||
| 106 | + dst += Mat(1, dim_features, CV_32FC1, output->mutable_cpu_data() + output->offset(n)); | ||
| 107 | + | ||
| 108 | + caffeResource.release(net); | ||
| 109 | + } | ||
| 110 | +}; | ||
| 111 | + | ||
| 112 | +BR_REGISTER(Transform, CaffeFVTransform) | ||
| 113 | + | ||
| 114 | +} // namespace br | ||
| 115 | + | ||
| 116 | +#include "classification/caffe.moc" |
openbr/plugins/cmake/caffe.cmake
0 → 100644
| 1 | +set(BR_WITH_CAFFE OFF CACHE BOOL "Build with Caffe") | ||
| 2 | + | ||
| 3 | +if(${BR_WITH_CAFFE}) | ||
| 4 | + find_package(Caffe) | ||
| 5 | + include_directories(${Caffe_INCLUDE_DIRS}) | ||
| 6 | + add_definitions(${Caffe_DEFINITIONS}) | ||
| 7 | + set(BR_THIRDPARTY_LIBS ${BR_THIRDPARTY_LIBS} ${Caffe_LIBRARIES}) | ||
| 8 | +else() | ||
| 9 | + set(BR_EXCLUDED_PLUGINS ${BR_EXCLUDED_PLUGINS} plugins/classification/caffe.cpp) | ||
| 10 | + set(BR_EXCLUDED_PLUGINS ${BR_EXCLUDED_PLUGINS} plugins/gallery/lmdb.cpp) | ||
| 11 | +endif() |
openbr/plugins/gallery/lmdb.cpp
0 → 100644
| 1 | +#include <openbr/openbr_plugin.h> | ||
| 2 | +#include "openbr/plugins/openbr_internal.h" | ||
| 3 | +#include <openbr/core/qtutils.h> | ||
| 4 | + | ||
| 5 | +#include <QFutureSynchronizer> | ||
| 6 | +#include <QtConcurrentRun> | ||
| 7 | +#include <QMutexLocker> | ||
| 8 | +#include <QWaitCondition> | ||
| 9 | + | ||
| 10 | +#include "caffe/util/db.hpp" | ||
| 11 | +#include "caffe/util/io.hpp" | ||
| 12 | + | ||
| 13 | +using namespace br; | ||
| 14 | + | ||
| 15 | +class lmdbGallery : public Gallery | ||
| 16 | +{ | ||
| 17 | + Q_OBJECT | ||
| 18 | + | ||
| 19 | + TemplateList readBlock(bool *done) | ||
| 20 | + { | ||
| 21 | + *done = false; | ||
| 22 | + if (!initialized) { | ||
| 23 | + db = QSharedPointer<caffe::db::DB>(caffe::db::GetDB("lmdb")); | ||
| 24 | + db->Open(file.name.toStdString(),caffe::db::READ); | ||
| 25 | + cursor = QSharedPointer<caffe::db::Cursor>(db->NewCursor()); | ||
| 26 | + initialized = true; | ||
| 27 | + } | ||
| 28 | + | ||
| 29 | + caffe::Datum datum; | ||
| 30 | + datum.ParseFromString(cursor->value()); | ||
| 31 | + | ||
| 32 | + cv::Mat img; | ||
| 33 | + if (datum.encoded()) { | ||
| 34 | + img = caffe::DecodeDatumToCVMatNative(datum); | ||
| 35 | + } | ||
| 36 | + else { | ||
| 37 | + // create output image of appropriate size | ||
| 38 | + img.create(datum.height(), datum.width(), CV_MAKETYPE(CV_8U, datum.channels())); | ||
| 39 | + // copy matrix data from datum. | ||
| 40 | + for (int h = 0; h < datum.height(); ++h) { | ||
| 41 | + uchar* ptr = img.ptr<uchar>(h); | ||
| 42 | + int img_index = 0; | ||
| 43 | + for (int w = 0; w < datum.width(); ++w) { | ||
| 44 | + for (int c = 0; c < datum.channels(); ++c) { | ||
| 45 | + int datum_index = (c * datum.height() + h) * datum.width() + w; | ||
| 46 | + ptr[img_index++] = (unsigned char)datum.data()[datum_index]; | ||
| 47 | + } | ||
| 48 | + } | ||
| 49 | + } | ||
| 50 | + } | ||
| 51 | + | ||
| 52 | + // We acquired the image data, now decode filename from db key | ||
| 53 | + QString baseKey = cursor->key().c_str(); | ||
| 54 | + | ||
| 55 | + int idx = baseKey.indexOf("_"); | ||
| 56 | + if (idx != -1) | ||
| 57 | + baseKey = baseKey.right(baseKey.size() - idx - 1); | ||
| 58 | + | ||
| 59 | + TemplateList output; | ||
| 60 | + output.append(Template(img)); | ||
| 61 | + output.last().file.name = baseKey; | ||
| 62 | + output.last().file.set("Label", datum.label()); | ||
| 63 | + | ||
| 64 | + cursor->Next(); | ||
| 65 | + | ||
| 66 | + if (!cursor->valid()) | ||
| 67 | + *done = true; | ||
| 68 | + | ||
| 69 | + return output; | ||
| 70 | + } | ||
| 71 | + | ||
| 72 | + bool initialized; | ||
| 73 | + QSharedPointer<caffe::db::DB> db; | ||
| 74 | + QSharedPointer<caffe::db::Cursor> cursor; | ||
| 75 | + | ||
| 76 | + QFutureSynchronizer<void> aThread; | ||
| 77 | + QMutex dataLock; | ||
| 78 | + QWaitCondition dataWait; | ||
| 79 | + | ||
| 80 | + bool should_end; | ||
| 81 | + TemplateList data; | ||
| 82 | + | ||
| 83 | + QHash<QString, int> observedLabels; | ||
| 84 | + | ||
| 85 | + static void commitLoop(lmdbGallery * base) | ||
| 86 | + { | ||
| 87 | + QSharedPointer<caffe::db::Transaction> txn; | ||
| 88 | + | ||
| 89 | + int total_count = 0; | ||
| 90 | + | ||
| 91 | + // Acquire the lock | ||
| 92 | + QMutexLocker lock(&base->dataLock); | ||
| 93 | + | ||
| 94 | + while (true) { | ||
| 95 | + // wait for data, or end signal | ||
| 96 | + while(base->data.isEmpty() && !base->should_end) | ||
| 97 | + base->dataWait.wait(&base->dataLock); | ||
| 98 | + | ||
| 99 | + // If should_end, but there is still data, we need another commit | ||
| 100 | + // round | ||
| 101 | + if (base->should_end && base->data.isEmpty()) | ||
| 102 | + break; | ||
| 103 | + | ||
| 104 | + txn = QSharedPointer<caffe::db::Transaction>(base->db->NewTransaction()); | ||
| 105 | + | ||
| 106 | + TemplateList working = base->data; | ||
| 107 | + base->data.clear(); | ||
| 108 | + | ||
| 109 | + // no longer blocking dataLock | ||
| 110 | + lock.unlock(); | ||
| 111 | + | ||
| 112 | + foreach(const Template &t, working) { | ||
| 113 | + // add current image to transaction | ||
| 114 | + caffe::Datum datum; | ||
| 115 | + caffe::CVMatToDatum(t.m(), &datum); | ||
| 116 | + | ||
| 117 | + QVariant base_label = t.file.value("Label"); | ||
| 118 | + QString label_str = base_label.toString(); | ||
| 119 | + | ||
| 120 | + | ||
| 121 | + if (!base->observedLabels.contains(label_str) ) | ||
| 122 | + base->observedLabels[label_str] = base->observedLabels.size(); | ||
| 123 | + | ||
| 124 | + datum.set_label(base->observedLabels[label_str]); | ||
| 125 | + | ||
| 126 | + std::string out; | ||
| 127 | + datum.SerializeToString(&out); | ||
| 128 | + | ||
| 129 | + char key_cstr[256]; | ||
| 130 | + int len = snprintf(key_cstr, 256, "%08d_%s", total_count, qPrintable(t.file.name)); | ||
| 131 | + txn->Put(std::string(key_cstr, len), out); | ||
| 132 | + | ||
| 133 | + total_count++; | ||
| 134 | + } | ||
| 135 | + | ||
| 136 | + txn->Commit(); | ||
| 137 | + lock.relock(); | ||
| 138 | + } | ||
| 139 | + } | ||
| 140 | + | ||
| 141 | + void write(const Template &t) | ||
| 142 | + { | ||
| 143 | + if (!initialized) { | ||
| 144 | + db = QSharedPointer<caffe::db::DB> (caffe::db::GetDB("lmdb")); | ||
| 145 | + db->Open(file.name.toStdString(), caffe::db::NEW); | ||
| 146 | + observedLabels.clear(); | ||
| 147 | + initialized = true; | ||
| 148 | + should_end = false; | ||
| 149 | + // fire thread | ||
| 150 | + aThread.clearFutures(); | ||
| 151 | + aThread.addFuture(QtConcurrent::run(lmdbGallery::commitLoop, this)); | ||
| 152 | + } | ||
| 153 | + | ||
| 154 | + QMutexLocker lock(&dataLock); | ||
| 155 | + data.append(t); | ||
| 156 | + dataWait.wakeAll(); | ||
| 157 | + } | ||
| 158 | + | ||
| 159 | + ~lmdbGallery() | ||
| 160 | + { | ||
| 161 | + if (initialized) { | ||
| 162 | + QMutexLocker lock(&dataLock); | ||
| 163 | + should_end = true; | ||
| 164 | + dataWait.wakeAll(); | ||
| 165 | + lock.unlock(); | ||
| 166 | + | ||
| 167 | + aThread.waitForFinished(); | ||
| 168 | + } | ||
| 169 | + } | ||
| 170 | + | ||
| 171 | + | ||
| 172 | + void init() | ||
| 173 | + { | ||
| 174 | + initialized = false; | ||
| 175 | + should_end = false; | ||
| 176 | + } | ||
| 177 | +}; | ||
| 178 | + | ||
| 179 | +BR_REGISTER(Gallery, lmdbGallery) | ||
| 180 | + | ||
| 181 | + | ||
| 182 | +#include "gallery/lmdb.moc" | ||
| 183 | + |
openbr/plugins/imgproc/pad.cpp
0 → 100644
| 1 | +#include <openbr/plugins/openbr_internal.h> | ||
| 2 | + | ||
| 3 | +using namespace cv; | ||
| 4 | + | ||
| 5 | +namespace br | ||
| 6 | +{ | ||
| 7 | + | ||
| 8 | +class PadTransform : public UntrainableTransform | ||
| 9 | +{ | ||
| 10 | + Q_OBJECT | ||
| 11 | + | ||
| 12 | + Q_PROPERTY(int padSize READ get_padSize WRITE set_padSize RESET reset_padSize STORED false) | ||
| 13 | + Q_PROPERTY(int padValue READ get_padValue WRITE set_padValue RESET reset_padValue STORED false) | ||
| 14 | + BR_PROPERTY(int, padSize, 0) | ||
| 15 | + BR_PROPERTY(int, padValue, 0) | ||
| 16 | + | ||
| 17 | + void project(const Template &src, Template &dst) const | ||
| 18 | + { | ||
| 19 | + dst.file = src.file; | ||
| 20 | + | ||
| 21 | + foreach (const Mat &m, src) { | ||
| 22 | + Mat padded = padValue * Mat::ones(m.rows + 2*padSize, m.cols + 2*padSize, m.type()); | ||
| 23 | + padded(Rect(padSize, padSize, padded.cols - padSize, padded.rows - padSize)) = m; | ||
| 24 | + dst += padded; | ||
| 25 | + } | ||
| 26 | + } | ||
| 27 | +}; | ||
| 28 | + | ||
| 29 | +BR_REGISTER(Transform, PadTransform) | ||
| 30 | + | ||
| 31 | +} // namespace br | ||
| 32 | + | ||
| 33 | +#include "imgproc/pad.moc" |
openbr/plugins/imgproc/roi.cpp
| @@ -26,12 +26,16 @@ namespace br | @@ -26,12 +26,16 @@ namespace br | ||
| 26 | * \ingroup transforms | 26 | * \ingroup transforms |
| 27 | * \brief Crops the rectangular regions of interest. | 27 | * \brief Crops the rectangular regions of interest. |
| 28 | * \author Josh Klontz \cite jklontz | 28 | * \author Josh Klontz \cite jklontz |
| 29 | + * \br_property QString propName Optional property name for a rectangle in metadata. If no propName is given the transform will use rects stored in the file.rects field or build a rectangle using "X", "Y", "Width", and "Height" fields if they exist. | ||
| 30 | + * \br_property bool copyOnCrop If true make a clone of each crop before appending the crop to dst. This guarantees that the crops will be continuous in memory, which is an occasionally useful property. Default is false. | ||
| 29 | */ | 31 | */ |
| 30 | class ROITransform : public UntrainableTransform | 32 | class ROITransform : public UntrainableTransform |
| 31 | { | 33 | { |
| 32 | Q_OBJECT | 34 | Q_OBJECT |
| 33 | Q_PROPERTY(QString propName READ get_propName WRITE set_propName RESET reset_propName STORED false) | 35 | Q_PROPERTY(QString propName READ get_propName WRITE set_propName RESET reset_propName STORED false) |
| 36 | + Q_PROPERTY(bool copyOnCrop READ get_copyOnCrop WRITE set_copyOnCrop RESET reset_copyOnCrop STORED false) | ||
| 34 | BR_PROPERTY(QString, propName, "") | 37 | BR_PROPERTY(QString, propName, "") |
| 38 | + BR_PROPERTY(bool, copyOnCrop, false) | ||
| 35 | 39 | ||
| 36 | void project(const Template &src, Template &dst) const | 40 | void project(const Template &src, Template &dst) const |
| 37 | { | 41 | { |
| @@ -52,6 +56,10 @@ class ROITransform : public UntrainableTransform | @@ -52,6 +56,10 @@ class ROITransform : public UntrainableTransform | ||
| 52 | qWarning("No rects present in file."); | 56 | qWarning("No rects present in file."); |
| 53 | } | 57 | } |
| 54 | dst.file.clearRects(); | 58 | dst.file.clearRects(); |
| 59 | + | ||
| 60 | + if (copyOnCrop) | ||
| 61 | + for (int i = 0; i < dst.size(); i++) | ||
| 62 | + dst.replace(i, dst[i].clone()); | ||
| 55 | } | 63 | } |
| 56 | }; | 64 | }; |
| 57 | 65 |