Commit 45c70cabec018bdc90020b24b36667ceee888184

Authored by Scott Klum
1 parent a09c27a1

Removed layer parameter

Showing 1 changed file with 11 additions and 7 deletions
openbr/plugins/nn.cpp
... ... @@ -59,20 +59,18 @@ class MLPTransform : public MetaTransform
59 59 BR_PROPERTY(QStringList, inputVariables, QStringList())
60 60 Q_PROPERTY(QStringList outputVariables READ get_outputVariables WRITE set_outputVariables RESET reset_outputVariables STORED false)
61 61 BR_PROPERTY(QStringList, outputVariables, QStringList())
62   - Q_PROPERTY(int hiddenLayers READ get_hiddenLayers WRITE set_hiddenLayers RESET reset_hiddenLayers STORED false)
63   - BR_PROPERTY(int, hiddenLayers, 1)
64 62 Q_PROPERTY(QList<int> neuronsPerLayer READ get_neuronsPerLayer WRITE set_neuronsPerLayer RESET reset_neuronsPerLayer STORED false)
65   - BR_PROPERTY(QList<int>, neuronsPerLayer, QList<int>() << 6)
  63 + BR_PROPERTY(QList<int>, neuronsPerLayer, QList<int>() << 1 << 1)
66 64  
67 65 CvANN_MLP mlp;
68 66  
69 67 void init()
70 68 {
71   - Mat layers = Mat(hiddenLayers, 1, CV_32SC1);
72   - for (int i=0; i<hiddenLayers; i++) {
  69 + Mat layers = Mat(neuronsPerLayer.size(), 1, CV_32SC1);
  70 + for (int i=0; i<neuronsPerLayer.size(); i++) {
73 71 layers.row(i) = Scalar(neuronsPerLayer.at(i));
74 72 }
75   - mlp.create(layers,CvANN_MLP::SIGMOID_SYM, 1, 1);
  73 + mlp.create(layers,CvANN_MLP::SIGMOID_SYM, .8, .6);
76 74 }
77 75  
78 76 void train(const TemplateList &data)
... ... @@ -88,15 +86,21 @@ class MLPTransform : public MetaTransform
88 86 labels.col(i) += OpenCVUtils::toMat(File::get<float>(data, inputVariables.at(i)));
89 87  
90 88 mlp.train(_data,labels,Mat());
  89 +
  90 + if (Globals->verbose)
  91 + for (int i=0; i<neuronsPerLayer.size(); i++) qDebug() << *mlp.get_weights(i);
91 92 }
92 93  
93 94 void project(const Template &src, Template &dst) const
94 95 {
  96 + dst = src;
  97 +
95 98 // See above for response dimensionality
96 99 Mat response(outputVariables.size(), 1, CV_32FC1);
97 100 mlp.predict(src.m().reshape(1,1),response);
98 101  
99   - for (int i=0; i<outputVariables.size(); i++) dst.file.set(outputVariables.at(i),response.at<float>(i,0));
  102 + // Apparently mlp.predict reshapes the response matrix?
  103 + for (int i=0; i<outputVariables.size(); i++) dst.file.set(outputVariables.at(i),response.at<float>(0,i));
100 104 }
101 105  
102 106 void load(QDataStream &stream)
... ...