Commit 9a608d692ffe7b5e7743f1287f262f41c9856a8e
Merge pull request #262 from biometrics/neural_networks
Neural networks
Showing
6 changed files
with
295 additions
and
18 deletions
openbr/core/eval.cpp
| ... | ... | @@ -1044,12 +1044,14 @@ void EvalRegression(const QString &predictedGallery, const QString &truthGallery |
| 1044 | 1044 | if (predicted[i].file.name != truth[i].file.name) |
| 1045 | 1045 | qFatal("Input order mismatch."); |
| 1046 | 1046 | |
| 1047 | - float difference = predicted[i].file.get<float>(predictedProperty) - truth[i].file.get<float>(truthProperty); | |
| 1047 | + if (predicted[i].file.contains(predictedProperty) && truth[i].file.contains(truthProperty)) { | |
| 1048 | + float difference = predicted[i].file.get<float>(predictedProperty) - truth[i].file.get<float>(truthProperty); | |
| 1048 | 1049 | |
| 1049 | - rmsError += pow(difference, 2.f); | |
| 1050 | - maeError += fabsf(difference); | |
| 1051 | - truthValues.append(QString::number(truth[i].file.get<float>(truthProperty))); | |
| 1052 | - predictedValues.append(QString::number(predicted[i].file.get<float>(predictedProperty))); | |
| 1050 | + rmsError += pow(difference, 2.f); | |
| 1051 | + maeError += fabsf(difference); | |
| 1052 | + truthValues.append(QString::number(truth[i].file.get<float>(truthProperty))); | |
| 1053 | + predictedValues.append(QString::number(predicted[i].file.get<float>(predictedProperty))); | |
| 1054 | + } | |
| 1053 | 1055 | } |
| 1054 | 1056 | |
| 1055 | 1057 | QStringList rSource; | ... | ... |
openbr/core/opencvutils.cpp
| ... | ... | @@ -121,6 +121,17 @@ Mat OpenCVUtils::toMat(const QList<float> &src, int rows) |
| 121 | 121 | return dst; |
| 122 | 122 | } |
| 123 | 123 | |
| 124 | +Mat OpenCVUtils::pointsToMatrix(const QList<QPointF> &qPoints) | |
| 125 | +{ | |
| 126 | + QList<float> points; | |
| 127 | + foreach(const QPointF &point, qPoints) { | |
| 128 | + points.append(point.x()); | |
| 129 | + points.append(point.y()); | |
| 130 | + } | |
| 131 | + | |
| 132 | + return toMat(points); | |
| 133 | +} | |
| 134 | + | |
| 124 | 135 | Mat OpenCVUtils::toMat(const QList<QList<float> > &srcs, int rows) |
| 125 | 136 | { |
| 126 | 137 | QList<float> flat; | ... | ... |
openbr/core/opencvutils.h
| ... | ... | @@ -84,6 +84,7 @@ namespace OpenCVUtils |
| 84 | 84 | QPointF fromPoint(const cv::Point2f &cvPoint); |
| 85 | 85 | QList<cv::Point2f> toPoints(const QList<QPointF> &qPoints); |
| 86 | 86 | QList<QPointF> fromPoints(const QList<cv::Point2f> &cvPoints); |
| 87 | + cv::Mat pointsToMatrix(const QList<QPointF> &qPoints); | |
| 87 | 88 | cv::Rect toRect(const QRectF &qRect); |
| 88 | 89 | QRectF fromRect(const cv::Rect &cvRect); |
| 89 | 90 | QList<cv::Rect> toRects(const QList<QRectF> &qRects); | ... | ... |
openbr/plugins/landmarks.cpp
| ... | ... | @@ -369,12 +369,14 @@ BR_REGISTER(Transform, ReadLandmarksTransform) |
| 369 | 369 | |
| 370 | 370 | /*! |
| 371 | 371 | * \ingroup transforms |
| 372 | - * \brief Name a point | |
| 372 | + * \brief Name a point/rect | |
| 373 | 373 | * \author Scott Klum \cite sklum |
| 374 | 374 | */ |
| 375 | -class NamePointsTransform : public UntrainableMetadataTransform | |
| 375 | +class NameLandmarksTransform : public UntrainableMetadataTransform | |
| 376 | 376 | { |
| 377 | 377 | Q_OBJECT |
| 378 | + Q_PROPERTY(bool point READ get_point WRITE set_point RESET reset_point STORED false) | |
| 379 | + BR_PROPERTY(bool, point, true) | |
| 378 | 380 | Q_PROPERTY(QList<int> indices READ get_indices WRITE set_indices RESET reset_indices STORED false) |
| 379 | 381 | Q_PROPERTY(QStringList names READ get_names WRITE set_names RESET reset_names STORED false) |
| 380 | 382 | BR_PROPERTY(QList<int>, indices, QList<int>()) |
| ... | ... | @@ -382,27 +384,36 @@ class NamePointsTransform : public UntrainableMetadataTransform |
| 382 | 384 | |
| 383 | 385 | void projectMetadata(const File &src, File &dst) const |
| 384 | 386 | { |
| 385 | - if (indices.size() != names.size()) qFatal("Point/name size mismatch"); | |
| 387 | + if (indices.size() != names.size()) qFatal("Index/name size mismatch"); | |
| 386 | 388 | |
| 387 | 389 | dst = src; |
| 388 | 390 | |
| 389 | - QList<QPointF> points = src.points(); | |
| 391 | + if (point) { | |
| 392 | + QList<QPointF> points = src.points(); | |
| 390 | 393 | |
| 391 | - for (int i=0; i<indices.size(); i++) { | |
| 392 | - if (indices[i] < points.size()) dst.set(names[i], points[indices[i]]); | |
| 393 | - else qFatal("Index out of range."); | |
| 394 | + for (int i=0; i<indices.size(); i++) { | |
| 395 | + if (indices[i] < points.size()) dst.set(names[i], points[indices[i]]); | |
| 396 | + else qFatal("Index out of range."); | |
| 397 | + } | |
| 398 | + } else { | |
| 399 | + QList<QRectF> rects = src.rects(); | |
| 400 | + | |
| 401 | + for (int i=0; i<indices.size(); i++) { | |
| 402 | + if (indices[i] < rects.size()) dst.set(names[i], rects[indices[i]]); | |
| 403 | + else qFatal("Index out of range."); | |
| 404 | + } | |
| 394 | 405 | } |
| 395 | 406 | } |
| 396 | 407 | }; |
| 397 | 408 | |
| 398 | -BR_REGISTER(Transform, NamePointsTransform) | |
| 409 | +BR_REGISTER(Transform, NameLandmarksTransform) | |
| 399 | 410 | |
| 400 | 411 | /*! |
| 401 | 412 | * \ingroup transforms |
| 402 | - * \brief Remove a name from a point | |
| 413 | + * \brief Remove a name from a point/rect | |
| 403 | 414 | * \author Scott Klum \cite sklum |
| 404 | 415 | */ |
| 405 | -class AnonymizePointsTransform : public UntrainableMetadataTransform | |
| 416 | +class AnonymizeLandmarksTransform : public UntrainableMetadataTransform | |
| 406 | 417 | { |
| 407 | 418 | Q_OBJECT |
| 408 | 419 | Q_PROPERTY(QStringList names READ get_names WRITE set_names RESET reset_names STORED false) |
| ... | ... | @@ -412,12 +423,112 @@ class AnonymizePointsTransform : public UntrainableMetadataTransform |
| 412 | 423 | { |
| 413 | 424 | dst = src; |
| 414 | 425 | |
| 415 | - foreach (const QString &name, names) | |
| 416 | - if (src.contains(name)) dst.appendPoint(src.get<QPointF>(name)); | |
| 426 | + foreach (const QString &name, names) { | |
| 427 | + if (src.contains(name)) { | |
| 428 | + QVariant variant = src.value(name); | |
| 429 | + if (variant.canConvert(QMetaType::QPointF)) { | |
| 430 | + dst.appendPoint(variant.toPointF()); | |
| 431 | + } else if (variant.canConvert(QMetaType::QRectF)) { | |
| 432 | + dst.appendRect(variant.toRectF()); | |
| 433 | + } else { | |
| 434 | + qFatal("Cannot convert landmark to point or rect."); | |
| 435 | + } | |
| 436 | + } | |
| 437 | + } | |
| 438 | + } | |
| 439 | +}; | |
| 440 | + | |
| 441 | +BR_REGISTER(Transform, AnonymizeLandmarksTransform) | |
| 442 | + | |
| 443 | +/*! | |
| 444 | + * \ingroup transforms | |
| 445 | + * \brief Converts either the file::points() list or a QList<QPointF> metadata item to be the template's matrix | |
| 446 | + * \author Scott Klum \cite sklum | |
| 447 | + */ | |
| 448 | +class PointsToMatrixTransform : public UntrainableTransform | |
| 449 | +{ | |
| 450 | + Q_OBJECT | |
| 451 | + | |
| 452 | + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | |
| 453 | + BR_PROPERTY(QString, inputVariable, QString()) | |
| 454 | + | |
| 455 | + void project(const Template &src, Template &dst) const | |
| 456 | + { | |
| 457 | + dst = src; | |
| 458 | + | |
| 459 | + if (inputVariable.isEmpty()) { | |
| 460 | + dst.m() = OpenCVUtils::pointsToMatrix(dst.file.points()); | |
| 461 | + } else { | |
| 462 | + if (src.file.contains(inputVariable)) | |
| 463 | + dst.m() = OpenCVUtils::pointsToMatrix(dst.file.get<QList<QPointF> >(inputVariable)); | |
| 464 | + } | |
| 465 | + } | |
| 466 | +}; | |
| 467 | + | |
| 468 | +BR_REGISTER(Transform, PointsToMatrixTransform) | |
| 469 | + | |
| 470 | +/*! | |
| 471 | + * \ingroup transforms | |
| 472 | + * \brief Normalize points to be relative to a single point | |
| 473 | + * \author Scott Klum \cite sklum | |
| 474 | + */ | |
| 475 | +class NormalizePointsTransform : public UntrainableTransform | |
| 476 | +{ | |
| 477 | + Q_OBJECT | |
| 478 | + | |
| 479 | + Q_PROPERTY(int index READ get_index WRITE set_index RESET reset_index STORED false) | |
| 480 | + BR_PROPERTY(int, index, 0) | |
| 481 | + | |
| 482 | + void project(const Template &src, Template &dst) const | |
| 483 | + { | |
| 484 | + dst = src; | |
| 485 | + | |
| 486 | + QList<QPointF> points = dst.file.points(); | |
| 487 | + QPointF normPoint = points.at(index); | |
| 488 | + | |
| 489 | + QList<QPointF> normalizedPoints; | |
| 490 | + | |
| 491 | + for (int i=0; i<points.size(); i++) | |
| 492 | + if (i!=index) | |
| 493 | + normalizedPoints.append(normPoint-points[i]); | |
| 494 | + | |
| 495 | + dst.file.setPoints(normalizedPoints); | |
| 496 | + } | |
| 497 | +}; | |
| 498 | + | |
| 499 | +BR_REGISTER(Transform, NormalizePointsTransform) | |
| 500 | + | |
| 501 | +/*! | |
| 502 | + * \ingroup transforms | |
| 503 | + * \brief Normalize points to be relative to a single point | |
| 504 | + * \author Scott Klum \cite sklum | |
| 505 | + */ | |
| 506 | +class PointDisplacementTransform : public UntrainableTransform | |
| 507 | +{ | |
| 508 | + Q_OBJECT | |
| 509 | + | |
| 510 | + void project(const Template &src, Template &dst) const | |
| 511 | + { | |
| 512 | + dst = src; | |
| 513 | + | |
| 514 | + QList<QPointF> points = dst.file.points(); | |
| 515 | + QList<QPointF> normalizedPoints; | |
| 516 | + | |
| 517 | + for (int i=0; i<points.size(); i++) | |
| 518 | + for (int j=0; j<points.size(); j++) | |
| 519 | + // There is redundant information here | |
| 520 | + if (j!=i) { | |
| 521 | + QPointF normalizedPoint = points[i]-points[j]; | |
| 522 | + normalizedPoint.setX(pow(normalizedPoint.x(),2)); | |
| 523 | + normalizedPoint.setY(pow(normalizedPoint.y(),2)); | |
| 524 | + normalizedPoints.append(normalizedPoint); | |
| 525 | + } | |
| 526 | + | |
| 527 | + dst.file.setPoints(normalizedPoints); | |
| 417 | 528 | } |
| 418 | 529 | }; |
| 419 | 530 | |
| 420 | -BR_REGISTER(Transform, AnonymizePointsTransform) | |
| 531 | +BR_REGISTER(Transform, PointDisplacementTransform) | |
| 421 | 532 | |
| 422 | 533 | } // namespace br |
| 423 | 534 | ... | ... |
openbr/plugins/misc.cpp
| ... | ... | @@ -937,6 +937,18 @@ class FileExclusionTransform : public UntrainableMetaTransform |
| 937 | 937 | |
| 938 | 938 | BR_REGISTER(Transform, FileExclusionTransform) |
| 939 | 939 | |
| 940 | +class TransposeTransform : public UntrainableTransform | |
| 941 | +{ | |
| 942 | + Q_OBJECT | |
| 943 | + | |
| 944 | + void project(const Template &src, Template &dst) const | |
| 945 | + { | |
| 946 | + dst.m() = src.m().t(); | |
| 947 | + } | |
| 948 | +}; | |
| 949 | + | |
| 950 | +BR_REGISTER(Transform, TransposeTransform) | |
| 951 | + | |
| 940 | 952 | } |
| 941 | 953 | |
| 942 | 954 | #include "misc.moc" | ... | ... |
openbr/plugins/nn.cpp
0 โ 100644
| 1 | +#include <opencv2/ml/ml.hpp> | |
| 2 | + | |
| 3 | +#include "openbr_internal.h" | |
| 4 | +#include "openbr/core/qtutils.h" | |
| 5 | +#include "openbr/core/opencvutils.h" | |
| 6 | +#include "openbr/core/eigenutils.h" | |
| 7 | +#include <QString> | |
| 8 | +#include <QTemporaryFile> | |
| 9 | + | |
| 10 | +using namespace std; | |
| 11 | +using namespace cv; | |
| 12 | + | |
| 13 | +namespace br | |
| 14 | +{ | |
| 15 | + | |
| 16 | +static void storeMLP(const CvANN_MLP &mlp, QDataStream &stream) | |
| 17 | +{ | |
| 18 | + // Create local file | |
| 19 | + QTemporaryFile tempFile; | |
| 20 | + tempFile.open(); | |
| 21 | + tempFile.close(); | |
| 22 | + | |
| 23 | + // Save MLP to local file | |
| 24 | + mlp.save(qPrintable(tempFile.fileName())); | |
| 25 | + | |
| 26 | + // Copy local file contents to stream | |
| 27 | + tempFile.open(); | |
| 28 | + QByteArray data = tempFile.readAll(); | |
| 29 | + tempFile.close(); | |
| 30 | + stream << data; | |
| 31 | +} | |
| 32 | + | |
| 33 | +static void loadMLP(CvANN_MLP &mlp, QDataStream &stream) | |
| 34 | +{ | |
| 35 | + // Copy local file contents from stream | |
| 36 | + QByteArray data; | |
| 37 | + stream >> data; | |
| 38 | + | |
| 39 | + // Create local file | |
| 40 | + QTemporaryFile tempFile(QDir::tempPath()+"/MLP"); | |
| 41 | + tempFile.open(); | |
| 42 | + tempFile.write(data); | |
| 43 | + tempFile.close(); | |
| 44 | + | |
| 45 | + // Load MLP from local file | |
| 46 | + mlp.load(qPrintable(tempFile.fileName())); | |
| 47 | +} | |
| 48 | + | |
| 49 | +/*! | |
| 50 | + * \ingroup transforms | |
| 51 | + * \brief Wraps OpenCV's multi-layer perceptron framework | |
| 52 | + * \author Scott Klum \cite sklum | |
| 53 | + * \brief http://docs.opencv.org/modules/ml/doc/neural_networks.html | |
| 54 | + */ | |
| 55 | +class MLPTransform : public MetaTransform | |
| 56 | +{ | |
| 57 | + Q_OBJECT | |
| 58 | + | |
| 59 | + Q_ENUMS(Kernel) | |
| 60 | + Q_PROPERTY(Kernel kernel READ get_kernel WRITE set_kernel RESET reset_kernel STORED false) | |
| 61 | + Q_PROPERTY(float alpha READ get_alpha WRITE set_alpha RESET reset_alpha STORED false) | |
| 62 | + Q_PROPERTY(float beta READ get_beta WRITE set_beta RESET reset_beta STORED false) | |
| 63 | + Q_PROPERTY(QStringList inputVariables READ get_inputVariables WRITE set_inputVariables RESET reset_inputVariables STORED false) | |
| 64 | + Q_PROPERTY(QStringList outputVariables READ get_outputVariables WRITE set_outputVariables RESET reset_outputVariables STORED false) | |
| 65 | + Q_PROPERTY(QList<int> neuronsPerLayer READ get_neuronsPerLayer WRITE set_neuronsPerLayer RESET reset_neuronsPerLayer STORED false) | |
| 66 | + | |
| 67 | +public: | |
| 68 | + | |
| 69 | + enum Kernel { Identity = CvANN_MLP::IDENTITY, | |
| 70 | + Sigmoid = CvANN_MLP::SIGMOID_SYM, | |
| 71 | + Gaussian = CvANN_MLP::GAUSSIAN}; | |
| 72 | + | |
| 73 | +private: | |
| 74 | + BR_PROPERTY(Kernel, kernel, Sigmoid) | |
| 75 | + BR_PROPERTY(float, alpha, 1) | |
| 76 | + BR_PROPERTY(float, beta, 1) | |
| 77 | + BR_PROPERTY(QStringList, inputVariables, QStringList()) | |
| 78 | + BR_PROPERTY(QStringList, outputVariables, QStringList()) | |
| 79 | + BR_PROPERTY(QList<int>, neuronsPerLayer, QList<int>() << 1 << 1) | |
| 80 | + | |
| 81 | + CvANN_MLP mlp; | |
| 82 | + | |
| 83 | + void init() | |
| 84 | + { | |
| 85 | + if (kernel == Gaussian) | |
| 86 | + qWarning("The OpenCV documentation warns that the Gaussian kernel, \"is not completely supported at the moment\""); | |
| 87 | + | |
| 88 | + Mat layers = Mat(neuronsPerLayer.size(), 1, CV_32SC1); | |
| 89 | + for (int i=0; i<neuronsPerLayer.size(); i++) | |
| 90 | + layers.row(i) = Scalar(neuronsPerLayer.at(i)); | |
| 91 | + | |
| 92 | + mlp.create(layers,kernel, alpha, beta); | |
| 93 | + } | |
| 94 | + | |
| 95 | + void train(const TemplateList &data) | |
| 96 | + { | |
| 97 | + Mat _data = OpenCVUtils::toMat(data.data()); | |
| 98 | + | |
| 99 | + // Assuming data has n templates | |
| 100 | + // _data needs to be n x size of input layer | |
| 101 | + // Labels needs to be a n x outputs matrix | |
| 102 | + // For the time being we're going to assume a single output | |
| 103 | + Mat labels = Mat::zeros(data.size(),inputVariables.size(),CV_32F); | |
| 104 | + for (int i=0; i<inputVariables.size(); i++) | |
| 105 | + labels.col(i) += OpenCVUtils::toMat(File::get<float>(data, inputVariables.at(i))); | |
| 106 | + | |
| 107 | + mlp.train(_data,labels,Mat()); | |
| 108 | + | |
| 109 | + if (Globals->verbose) | |
| 110 | + for (int i=0; i<neuronsPerLayer.size(); i++) qDebug() << *mlp.get_weights(i); | |
| 111 | + } | |
| 112 | + | |
| 113 | + void project(const Template &src, Template &dst) const | |
| 114 | + { | |
| 115 | + dst = src; | |
| 116 | + | |
| 117 | + // See above for response dimensionality | |
| 118 | + Mat response(outputVariables.size(), 1, CV_32FC1); | |
| 119 | + mlp.predict(src.m().reshape(1,1),response); | |
| 120 | + | |
| 121 | + // Apparently mlp.predict reshapes the response matrix? | |
| 122 | + for (int i=0; i<outputVariables.size(); i++) dst.file.set(outputVariables.at(i),response.at<float>(0,i)); | |
| 123 | + } | |
| 124 | + | |
| 125 | + void load(QDataStream &stream) | |
| 126 | + { | |
| 127 | + loadMLP(mlp,stream); | |
| 128 | + } | |
| 129 | + | |
| 130 | + void store(QDataStream &stream) const | |
| 131 | + { | |
| 132 | + storeMLP(mlp,stream); | |
| 133 | + } | |
| 134 | +}; | |
| 135 | + | |
| 136 | +BR_REGISTER(Transform, MLPTransform) | |
| 137 | + | |
| 138 | +} // namespace br | |
| 139 | + | |
| 140 | +#include "nn.moc" | ... | ... |