From 194180ecf8f933f06d3c15a43d0579fcadeb118f Mon Sep 17 00:00:00 2001 From: Jordan Cheney Date: Fri, 26 Jun 2015 16:45:24 -0400 Subject: [PATCH] Caffe integration --- openbr/plugins/classification/caffe.cpp | 85 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ openbr/plugins/cmake/caffe.cmake | 10 ++++++++++ 2 files changed, 95 insertions(+), 0 deletions(-) create mode 100644 openbr/plugins/classification/caffe.cpp create mode 100644 openbr/plugins/cmake/caffe.cmake diff --git a/openbr/plugins/classification/caffe.cpp b/openbr/plugins/classification/caffe.cpp new file mode 100644 index 0000000..e280abb --- /dev/null +++ b/openbr/plugins/classification/caffe.cpp @@ -0,0 +1,85 @@ +#include + +#include + +using caffe::Caffe; +using caffe::Solver; +using caffe::Net; +using caffe::Blob; +using caffe::shared_ptr; +using caffe::vector; + +namespace br +{ + +/*! + * \brief A transform that wraps the Caffe Deep learning library + * \author Jordan Cheney \cite JordanCheney + * \ + */ +class CaffeTransform : public Transform +{ + Q_OBJECT + + Q_PROPERTY(QString modelFile READ get_modelFile WRITE set_modelFile RESET reset_modelFile STORED false) + Q_PROPERTY(QString solverFile READ get_solverFile WRITE set_solverFile RESET reset_solverFile STORED false) + Q_PROPERTY(QString weightsFile READ get_weightsFile WRITE set_weightsFile RESET reset_weightsFile STORED false) + Q_PROPERTY(int gpuDevice READ get_gpuDevice WRITE set_gpuDevice RESET reset_gpuDevice STORED false) + BR_PROPERTY(QString, modelFile, "") + BR_PROPERTY(QString, solverFile, "") + BR_PROPERTY(QString, weightsFile, "") + BR_PROPERTY(int, gpuDevice, -1) + + void init() + { + if (gpuDevice >= 0) { + Caffe::SetDevice(gpuDevice); + Caffe::set_mode(Caffe::GPU); + } else { + Caffe::set_mode(Caffe::CPU); + } + } + + void train(const TemplateList &data) + { + (void) data; + + caffe::SolverParameter solver_param; + caffe::ReadProtoFromTextFileOrDie(solverFile.toStdString(), &solver_param); + + shared_ptr > solver(caffe::GetSolver(solver_param)); + solver->Solve(); + } + + void project(const Template &src, Template &dst) const + { + (void)src; (void)dst; + Net net(modelFile.toStdString(), caffe::TEST); + net.CopyTrainedLayersFrom(weightsFile.toStdString()); + + vector *> bottom_vec; // perhaps src data should go here? + vector test_score_output_id; + vector test_score; + + float loss; + const vector *> &result = net.Forward(bottom_vec, &loss); + + int idx = 0; + for (int i = 0; i < (int)result.size(); i++) { + const float *result_data = result[i]->cpu_data(); + for (int j = 0; j < result[i]->count(); j++, idx++) { + test_score.push_back(result_data[j]); + test_score_output_id.push_back(i); + + if (Globals->verbose) + qDebug("%s = %f", net.blob_names()[net.output_blob_indices()[i]].c_str(), result_data[j]); + } + } + } +}; + +BR_REGISTER(Transform, CaffeTransform) + +} // namespace br + +#include "classification/caffe.moc" diff --git a/openbr/plugins/cmake/caffe.cmake b/openbr/plugins/cmake/caffe.cmake new file mode 100644 index 0000000..dd1d43b --- /dev/null +++ b/openbr/plugins/cmake/caffe.cmake @@ -0,0 +1,10 @@ +set(BR_WITH_CAFFE OFF CACHE BOOL "Build with Caffe") + +if(${BR_WITH_CAFFE}) + find_package(Caffe) + include_directories(${Caffe_INCLUDE_DIRS}) + add_definitions(${Caffe_DEFINITIONS}) + set(BR_THIRDPARTY_LIBS ${BR_THIRDPARTY_LIBS} ${Caffe_LIBRARIES}) +else() + set(BR_EXCLUDED_PLUGINS ${BR_EXCLUDED_PLUGINS} plugins/classification/caffe.cpp) +endif() -- libgit2 0.21.4