Commit a09c27a19b51113a48525d0478447d8c2fbb6c45
1 parent
dc84c45c
Added neural networks
Showing
1 changed file
with
117 additions
and
0 deletions
openbr/plugins/nn.cpp
0 → 100644
| 1 | +#include <opencv2/ml/ml.hpp> | ||
| 2 | + | ||
| 3 | +#include "openbr_internal.h" | ||
| 4 | +#include "openbr/core/qtutils.h" | ||
| 5 | +#include "openbr/core/opencvutils.h" | ||
| 6 | +#include "openbr/core/eigenutils.h" | ||
| 7 | +#include <QString> | ||
| 8 | +#include <QTemporaryFile> | ||
| 9 | + | ||
| 10 | +using namespace std; | ||
| 11 | +using namespace cv; | ||
| 12 | + | ||
| 13 | +namespace br | ||
| 14 | +{ | ||
| 15 | + | ||
| 16 | +static void storeMLP(const CvANN_MLP &mlp, QDataStream &stream) | ||
| 17 | +{ | ||
| 18 | + // Create local file | ||
| 19 | + QTemporaryFile tempFile; | ||
| 20 | + tempFile.open(); | ||
| 21 | + tempFile.close(); | ||
| 22 | + | ||
| 23 | + // Save SVM to local file | ||
| 24 | + mlp.save(qPrintable(tempFile.fileName())); | ||
| 25 | + | ||
| 26 | + // Copy local file contents to stream | ||
| 27 | + tempFile.open(); | ||
| 28 | + QByteArray data = tempFile.readAll(); | ||
| 29 | + tempFile.close(); | ||
| 30 | + stream << data; | ||
| 31 | +} | ||
| 32 | + | ||
| 33 | +static void loadMLP(CvANN_MLP &mlp, QDataStream &stream) | ||
| 34 | +{ | ||
| 35 | + // Copy local file contents from stream | ||
| 36 | + QByteArray data; | ||
| 37 | + stream >> data; | ||
| 38 | + | ||
| 39 | + // Create local file | ||
| 40 | + QTemporaryFile tempFile(QDir::tempPath()+"/MLP"); | ||
| 41 | + tempFile.open(); | ||
| 42 | + tempFile.write(data); | ||
| 43 | + tempFile.close(); | ||
| 44 | + | ||
| 45 | + // Load SVM from local file | ||
| 46 | + mlp.load(qPrintable(tempFile.fileName())); | ||
| 47 | +} | ||
| 48 | + | ||
| 49 | +/*! | ||
| 50 | + * \ingroup transforms | ||
| 51 | + * \brief Wraps OpenCV's multi-layer perceptron framework | ||
| 52 | + * \author Scott Klum \cite sklum | ||
| 53 | + */ | ||
| 54 | +class MLPTransform : public MetaTransform | ||
| 55 | +{ | ||
| 56 | + Q_OBJECT | ||
| 57 | + | ||
| 58 | + Q_PROPERTY(QStringList inputVariables READ get_inputVariables WRITE set_inputVariables RESET reset_inputVariables STORED false) | ||
| 59 | + BR_PROPERTY(QStringList, inputVariables, QStringList()) | ||
| 60 | + Q_PROPERTY(QStringList outputVariables READ get_outputVariables WRITE set_outputVariables RESET reset_outputVariables STORED false) | ||
| 61 | + BR_PROPERTY(QStringList, outputVariables, QStringList()) | ||
| 62 | + Q_PROPERTY(int hiddenLayers READ get_hiddenLayers WRITE set_hiddenLayers RESET reset_hiddenLayers STORED false) | ||
| 63 | + BR_PROPERTY(int, hiddenLayers, 1) | ||
| 64 | + Q_PROPERTY(QList<int> neuronsPerLayer READ get_neuronsPerLayer WRITE set_neuronsPerLayer RESET reset_neuronsPerLayer STORED false) | ||
| 65 | + BR_PROPERTY(QList<int>, neuronsPerLayer, QList<int>() << 6) | ||
| 66 | + | ||
| 67 | + CvANN_MLP mlp; | ||
| 68 | + | ||
| 69 | + void init() | ||
| 70 | + { | ||
| 71 | + Mat layers = Mat(hiddenLayers, 1, CV_32SC1); | ||
| 72 | + for (int i=0; i<hiddenLayers; i++) { | ||
| 73 | + layers.row(i) = Scalar(neuronsPerLayer.at(i)); | ||
| 74 | + } | ||
| 75 | + mlp.create(layers,CvANN_MLP::SIGMOID_SYM, 1, 1); | ||
| 76 | + } | ||
| 77 | + | ||
| 78 | + void train(const TemplateList &data) | ||
| 79 | + { | ||
| 80 | + Mat _data = OpenCVUtils::toMat(data.data()); | ||
| 81 | + | ||
| 82 | + // Assuming data has n templates | ||
| 83 | + // _data needs to be n x size of input layer | ||
| 84 | + // Labels needs to be a n x outputs matrix | ||
| 85 | + // For the time being we're going to assume a single output | ||
| 86 | + Mat labels = Mat::zeros(data.size(),inputVariables.size(),CV_32F); | ||
| 87 | + for (int i=0; i<inputVariables.size(); i++) | ||
| 88 | + labels.col(i) += OpenCVUtils::toMat(File::get<float>(data, inputVariables.at(i))); | ||
| 89 | + | ||
| 90 | + mlp.train(_data,labels,Mat()); | ||
| 91 | + } | ||
| 92 | + | ||
| 93 | + void project(const Template &src, Template &dst) const | ||
| 94 | + { | ||
| 95 | + // See above for response dimensionality | ||
| 96 | + Mat response(outputVariables.size(), 1, CV_32FC1); | ||
| 97 | + mlp.predict(src.m().reshape(1,1),response); | ||
| 98 | + | ||
| 99 | + for (int i=0; i<outputVariables.size(); i++) dst.file.set(outputVariables.at(i),response.at<float>(i,0)); | ||
| 100 | + } | ||
| 101 | + | ||
| 102 | + void load(QDataStream &stream) | ||
| 103 | + { | ||
| 104 | + loadMLP(mlp,stream); | ||
| 105 | + } | ||
| 106 | + | ||
| 107 | + void store(QDataStream &stream) const | ||
| 108 | + { | ||
| 109 | + storeMLP(mlp,stream); | ||
| 110 | + } | ||
| 111 | +}; | ||
| 112 | + | ||
| 113 | +BR_REGISTER(Transform, MLPTransform) | ||
| 114 | + | ||
| 115 | +} // namespace br | ||
| 116 | + | ||
| 117 | +#include "nn.moc" |