Commit 194180ecf8f933f06d3c15a43d0579fcadeb118f

Authored by Jordan Cheney
1 parent 200c6d2b

Caffe integration

openbr/plugins/classification/caffe.cpp 0 → 100644
  1 +#include <openbr/plugins/openbr_internal.h>
  2 +
  3 +#include <caffe/caffe.hpp>
  4 +
  5 +using caffe::Caffe;
  6 +using caffe::Solver;
  7 +using caffe::Net;
  8 +using caffe::Blob;
  9 +using caffe::shared_ptr;
  10 +using caffe::vector;
  11 +
  12 +namespace br
  13 +{
  14 +
  15 +/*!
  16 + * \brief A transform that wraps the Caffe Deep learning library
  17 + * \author Jordan Cheney \cite JordanCheney
  18 + * \
  19 + */
  20 +class CaffeTransform : public Transform
  21 +{
  22 + Q_OBJECT
  23 +
  24 + Q_PROPERTY(QString modelFile READ get_modelFile WRITE set_modelFile RESET reset_modelFile STORED false)
  25 + Q_PROPERTY(QString solverFile READ get_solverFile WRITE set_solverFile RESET reset_solverFile STORED false)
  26 + Q_PROPERTY(QString weightsFile READ get_weightsFile WRITE set_weightsFile RESET reset_weightsFile STORED false)
  27 + Q_PROPERTY(int gpuDevice READ get_gpuDevice WRITE set_gpuDevice RESET reset_gpuDevice STORED false)
  28 + BR_PROPERTY(QString, modelFile, "")
  29 + BR_PROPERTY(QString, solverFile, "")
  30 + BR_PROPERTY(QString, weightsFile, "")
  31 + BR_PROPERTY(int, gpuDevice, -1)
  32 +
  33 + void init()
  34 + {
  35 + if (gpuDevice >= 0) {
  36 + Caffe::SetDevice(gpuDevice);
  37 + Caffe::set_mode(Caffe::GPU);
  38 + } else {
  39 + Caffe::set_mode(Caffe::CPU);
  40 + }
  41 + }
  42 +
  43 + void train(const TemplateList &data)
  44 + {
  45 + (void) data;
  46 +
  47 + caffe::SolverParameter solver_param;
  48 + caffe::ReadProtoFromTextFileOrDie(solverFile.toStdString(), &solver_param);
  49 +
  50 + shared_ptr<Solver<float> > solver(caffe::GetSolver<float>(solver_param));
  51 + solver->Solve();
  52 + }
  53 +
  54 + void project(const Template &src, Template &dst) const
  55 + {
  56 + (void)src; (void)dst;
  57 + Net<float> net(modelFile.toStdString(), caffe::TEST);
  58 + net.CopyTrainedLayersFrom(weightsFile.toStdString());
  59 +
  60 + vector<Blob<float> *> bottom_vec; // perhaps src data should go here?
  61 + vector<int> test_score_output_id;
  62 + vector<float> test_score;
  63 +
  64 + float loss;
  65 + const vector<Blob<float> *> &result = net.Forward(bottom_vec, &loss);
  66 +
  67 + int idx = 0;
  68 + for (int i = 0; i < (int)result.size(); i++) {
  69 + const float *result_data = result[i]->cpu_data();
  70 + for (int j = 0; j < result[i]->count(); j++, idx++) {
  71 + test_score.push_back(result_data[j]);
  72 + test_score_output_id.push_back(i);
  73 +
  74 + if (Globals->verbose)
  75 + qDebug("%s = %f", net.blob_names()[net.output_blob_indices()[i]].c_str(), result_data[j]);
  76 + }
  77 + }
  78 + }
  79 +};
  80 +
  81 +BR_REGISTER(Transform, CaffeTransform)
  82 +
  83 +} // namespace br
  84 +
  85 +#include "classification/caffe.moc"
... ...
openbr/plugins/cmake/caffe.cmake 0 → 100644
  1 +set(BR_WITH_CAFFE OFF CACHE BOOL "Build with Caffe")
  2 +
  3 +if(${BR_WITH_CAFFE})
  4 + find_package(Caffe)
  5 + include_directories(${Caffe_INCLUDE_DIRS})
  6 + add_definitions(${Caffe_DEFINITIONS})
  7 + set(BR_THIRDPARTY_LIBS ${BR_THIRDPARTY_LIBS} ${Caffe_LIBRARIES})
  8 +else()
  9 + set(BR_EXCLUDED_PLUGINS ${BR_EXCLUDED_PLUGINS} plugins/classification/caffe.cpp)
  10 +endif()
... ...