Commit a09c27a19b51113a48525d0478447d8c2fbb6c45
1 parent
dc84c45c
Added neural networks
Showing
1 changed file
with
117 additions
and
0 deletions
openbr/plugins/nn.cpp
0 → 100644
| 1 | +#include <opencv2/ml/ml.hpp> | |
| 2 | + | |
| 3 | +#include "openbr_internal.h" | |
| 4 | +#include "openbr/core/qtutils.h" | |
| 5 | +#include "openbr/core/opencvutils.h" | |
| 6 | +#include "openbr/core/eigenutils.h" | |
| 7 | +#include <QString> | |
| 8 | +#include <QTemporaryFile> | |
| 9 | + | |
| 10 | +using namespace std; | |
| 11 | +using namespace cv; | |
| 12 | + | |
| 13 | +namespace br | |
| 14 | +{ | |
| 15 | + | |
| 16 | +static void storeMLP(const CvANN_MLP &mlp, QDataStream &stream) | |
| 17 | +{ | |
| 18 | + // Create local file | |
| 19 | + QTemporaryFile tempFile; | |
| 20 | + tempFile.open(); | |
| 21 | + tempFile.close(); | |
| 22 | + | |
| 23 | + // Save SVM to local file | |
| 24 | + mlp.save(qPrintable(tempFile.fileName())); | |
| 25 | + | |
| 26 | + // Copy local file contents to stream | |
| 27 | + tempFile.open(); | |
| 28 | + QByteArray data = tempFile.readAll(); | |
| 29 | + tempFile.close(); | |
| 30 | + stream << data; | |
| 31 | +} | |
| 32 | + | |
| 33 | +static void loadMLP(CvANN_MLP &mlp, QDataStream &stream) | |
| 34 | +{ | |
| 35 | + // Copy local file contents from stream | |
| 36 | + QByteArray data; | |
| 37 | + stream >> data; | |
| 38 | + | |
| 39 | + // Create local file | |
| 40 | + QTemporaryFile tempFile(QDir::tempPath()+"/MLP"); | |
| 41 | + tempFile.open(); | |
| 42 | + tempFile.write(data); | |
| 43 | + tempFile.close(); | |
| 44 | + | |
| 45 | + // Load SVM from local file | |
| 46 | + mlp.load(qPrintable(tempFile.fileName())); | |
| 47 | +} | |
| 48 | + | |
| 49 | +/*! | |
| 50 | + * \ingroup transforms | |
| 51 | + * \brief Wraps OpenCV's multi-layer perceptron framework | |
| 52 | + * \author Scott Klum \cite sklum | |
| 53 | + */ | |
| 54 | +class MLPTransform : public MetaTransform | |
| 55 | +{ | |
| 56 | + Q_OBJECT | |
| 57 | + | |
| 58 | + Q_PROPERTY(QStringList inputVariables READ get_inputVariables WRITE set_inputVariables RESET reset_inputVariables STORED false) | |
| 59 | + BR_PROPERTY(QStringList, inputVariables, QStringList()) | |
| 60 | + Q_PROPERTY(QStringList outputVariables READ get_outputVariables WRITE set_outputVariables RESET reset_outputVariables STORED false) | |
| 61 | + BR_PROPERTY(QStringList, outputVariables, QStringList()) | |
| 62 | + Q_PROPERTY(int hiddenLayers READ get_hiddenLayers WRITE set_hiddenLayers RESET reset_hiddenLayers STORED false) | |
| 63 | + BR_PROPERTY(int, hiddenLayers, 1) | |
| 64 | + Q_PROPERTY(QList<int> neuronsPerLayer READ get_neuronsPerLayer WRITE set_neuronsPerLayer RESET reset_neuronsPerLayer STORED false) | |
| 65 | + BR_PROPERTY(QList<int>, neuronsPerLayer, QList<int>() << 6) | |
| 66 | + | |
| 67 | + CvANN_MLP mlp; | |
| 68 | + | |
| 69 | + void init() | |
| 70 | + { | |
| 71 | + Mat layers = Mat(hiddenLayers, 1, CV_32SC1); | |
| 72 | + for (int i=0; i<hiddenLayers; i++) { | |
| 73 | + layers.row(i) = Scalar(neuronsPerLayer.at(i)); | |
| 74 | + } | |
| 75 | + mlp.create(layers,CvANN_MLP::SIGMOID_SYM, 1, 1); | |
| 76 | + } | |
| 77 | + | |
| 78 | + void train(const TemplateList &data) | |
| 79 | + { | |
| 80 | + Mat _data = OpenCVUtils::toMat(data.data()); | |
| 81 | + | |
| 82 | + // Assuming data has n templates | |
| 83 | + // _data needs to be n x size of input layer | |
| 84 | + // Labels needs to be a n x outputs matrix | |
| 85 | + // For the time being we're going to assume a single output | |
| 86 | + Mat labels = Mat::zeros(data.size(),inputVariables.size(),CV_32F); | |
| 87 | + for (int i=0; i<inputVariables.size(); i++) | |
| 88 | + labels.col(i) += OpenCVUtils::toMat(File::get<float>(data, inputVariables.at(i))); | |
| 89 | + | |
| 90 | + mlp.train(_data,labels,Mat()); | |
| 91 | + } | |
| 92 | + | |
| 93 | + void project(const Template &src, Template &dst) const | |
| 94 | + { | |
| 95 | + // See above for response dimensionality | |
| 96 | + Mat response(outputVariables.size(), 1, CV_32FC1); | |
| 97 | + mlp.predict(src.m().reshape(1,1),response); | |
| 98 | + | |
| 99 | + for (int i=0; i<outputVariables.size(); i++) dst.file.set(outputVariables.at(i),response.at<float>(i,0)); | |
| 100 | + } | |
| 101 | + | |
| 102 | + void load(QDataStream &stream) | |
| 103 | + { | |
| 104 | + loadMLP(mlp,stream); | |
| 105 | + } | |
| 106 | + | |
| 107 | + void store(QDataStream &stream) const | |
| 108 | + { | |
| 109 | + storeMLP(mlp,stream); | |
| 110 | + } | |
| 111 | +}; | |
| 112 | + | |
| 113 | +BR_REGISTER(Transform, MLPTransform) | |
| 114 | + | |
| 115 | +} // namespace br | |
| 116 | + | |
| 117 | +#include "nn.moc" | ... | ... |