Commit 9a608d692ffe7b5e7743f1287f262f41c9856a8e
Merge pull request #262 from biometrics/neural_networks
Neural networks
Showing
6 changed files
with
295 additions
and
18 deletions
openbr/core/eval.cpp
| @@ -1044,12 +1044,14 @@ void EvalRegression(const QString &predictedGallery, const QString &truthGallery | @@ -1044,12 +1044,14 @@ void EvalRegression(const QString &predictedGallery, const QString &truthGallery | ||
| 1044 | if (predicted[i].file.name != truth[i].file.name) | 1044 | if (predicted[i].file.name != truth[i].file.name) |
| 1045 | qFatal("Input order mismatch."); | 1045 | qFatal("Input order mismatch."); |
| 1046 | 1046 | ||
| 1047 | - float difference = predicted[i].file.get<float>(predictedProperty) - truth[i].file.get<float>(truthProperty); | 1047 | + if (predicted[i].file.contains(predictedProperty) && truth[i].file.contains(truthProperty)) { |
| 1048 | + float difference = predicted[i].file.get<float>(predictedProperty) - truth[i].file.get<float>(truthProperty); | ||
| 1048 | 1049 | ||
| 1049 | - rmsError += pow(difference, 2.f); | ||
| 1050 | - maeError += fabsf(difference); | ||
| 1051 | - truthValues.append(QString::number(truth[i].file.get<float>(truthProperty))); | ||
| 1052 | - predictedValues.append(QString::number(predicted[i].file.get<float>(predictedProperty))); | 1050 | + rmsError += pow(difference, 2.f); |
| 1051 | + maeError += fabsf(difference); | ||
| 1052 | + truthValues.append(QString::number(truth[i].file.get<float>(truthProperty))); | ||
| 1053 | + predictedValues.append(QString::number(predicted[i].file.get<float>(predictedProperty))); | ||
| 1054 | + } | ||
| 1053 | } | 1055 | } |
| 1054 | 1056 | ||
| 1055 | QStringList rSource; | 1057 | QStringList rSource; |
openbr/core/opencvutils.cpp
| @@ -121,6 +121,17 @@ Mat OpenCVUtils::toMat(const QList<float> &src, int rows) | @@ -121,6 +121,17 @@ Mat OpenCVUtils::toMat(const QList<float> &src, int rows) | ||
| 121 | return dst; | 121 | return dst; |
| 122 | } | 122 | } |
| 123 | 123 | ||
| 124 | +Mat OpenCVUtils::pointsToMatrix(const QList<QPointF> &qPoints) | ||
| 125 | +{ | ||
| 126 | + QList<float> points; | ||
| 127 | + foreach(const QPointF &point, qPoints) { | ||
| 128 | + points.append(point.x()); | ||
| 129 | + points.append(point.y()); | ||
| 130 | + } | ||
| 131 | + | ||
| 132 | + return toMat(points); | ||
| 133 | +} | ||
| 134 | + | ||
| 124 | Mat OpenCVUtils::toMat(const QList<QList<float> > &srcs, int rows) | 135 | Mat OpenCVUtils::toMat(const QList<QList<float> > &srcs, int rows) |
| 125 | { | 136 | { |
| 126 | QList<float> flat; | 137 | QList<float> flat; |
openbr/core/opencvutils.h
| @@ -84,6 +84,7 @@ namespace OpenCVUtils | @@ -84,6 +84,7 @@ namespace OpenCVUtils | ||
| 84 | QPointF fromPoint(const cv::Point2f &cvPoint); | 84 | QPointF fromPoint(const cv::Point2f &cvPoint); |
| 85 | QList<cv::Point2f> toPoints(const QList<QPointF> &qPoints); | 85 | QList<cv::Point2f> toPoints(const QList<QPointF> &qPoints); |
| 86 | QList<QPointF> fromPoints(const QList<cv::Point2f> &cvPoints); | 86 | QList<QPointF> fromPoints(const QList<cv::Point2f> &cvPoints); |
| 87 | + cv::Mat pointsToMatrix(const QList<QPointF> &qPoints); | ||
| 87 | cv::Rect toRect(const QRectF &qRect); | 88 | cv::Rect toRect(const QRectF &qRect); |
| 88 | QRectF fromRect(const cv::Rect &cvRect); | 89 | QRectF fromRect(const cv::Rect &cvRect); |
| 89 | QList<cv::Rect> toRects(const QList<QRectF> &qRects); | 90 | QList<cv::Rect> toRects(const QList<QRectF> &qRects); |
openbr/plugins/landmarks.cpp
| @@ -369,12 +369,14 @@ BR_REGISTER(Transform, ReadLandmarksTransform) | @@ -369,12 +369,14 @@ BR_REGISTER(Transform, ReadLandmarksTransform) | ||
| 369 | 369 | ||
| 370 | /*! | 370 | /*! |
| 371 | * \ingroup transforms | 371 | * \ingroup transforms |
| 372 | - * \brief Name a point | 372 | + * \brief Name a point/rect |
| 373 | * \author Scott Klum \cite sklum | 373 | * \author Scott Klum \cite sklum |
| 374 | */ | 374 | */ |
| 375 | -class NamePointsTransform : public UntrainableMetadataTransform | 375 | +class NameLandmarksTransform : public UntrainableMetadataTransform |
| 376 | { | 376 | { |
| 377 | Q_OBJECT | 377 | Q_OBJECT |
| 378 | + Q_PROPERTY(bool point READ get_point WRITE set_point RESET reset_point STORED false) | ||
| 379 | + BR_PROPERTY(bool, point, true) | ||
| 378 | Q_PROPERTY(QList<int> indices READ get_indices WRITE set_indices RESET reset_indices STORED false) | 380 | Q_PROPERTY(QList<int> indices READ get_indices WRITE set_indices RESET reset_indices STORED false) |
| 379 | Q_PROPERTY(QStringList names READ get_names WRITE set_names RESET reset_names STORED false) | 381 | Q_PROPERTY(QStringList names READ get_names WRITE set_names RESET reset_names STORED false) |
| 380 | BR_PROPERTY(QList<int>, indices, QList<int>()) | 382 | BR_PROPERTY(QList<int>, indices, QList<int>()) |
| @@ -382,27 +384,36 @@ class NamePointsTransform : public UntrainableMetadataTransform | @@ -382,27 +384,36 @@ class NamePointsTransform : public UntrainableMetadataTransform | ||
| 382 | 384 | ||
| 383 | void projectMetadata(const File &src, File &dst) const | 385 | void projectMetadata(const File &src, File &dst) const |
| 384 | { | 386 | { |
| 385 | - if (indices.size() != names.size()) qFatal("Point/name size mismatch"); | 387 | + if (indices.size() != names.size()) qFatal("Index/name size mismatch"); |
| 386 | 388 | ||
| 387 | dst = src; | 389 | dst = src; |
| 388 | 390 | ||
| 389 | - QList<QPointF> points = src.points(); | 391 | + if (point) { |
| 392 | + QList<QPointF> points = src.points(); | ||
| 390 | 393 | ||
| 391 | - for (int i=0; i<indices.size(); i++) { | ||
| 392 | - if (indices[i] < points.size()) dst.set(names[i], points[indices[i]]); | ||
| 393 | - else qFatal("Index out of range."); | 394 | + for (int i=0; i<indices.size(); i++) { |
| 395 | + if (indices[i] < points.size()) dst.set(names[i], points[indices[i]]); | ||
| 396 | + else qFatal("Index out of range."); | ||
| 397 | + } | ||
| 398 | + } else { | ||
| 399 | + QList<QRectF> rects = src.rects(); | ||
| 400 | + | ||
| 401 | + for (int i=0; i<indices.size(); i++) { | ||
| 402 | + if (indices[i] < rects.size()) dst.set(names[i], rects[indices[i]]); | ||
| 403 | + else qFatal("Index out of range."); | ||
| 404 | + } | ||
| 394 | } | 405 | } |
| 395 | } | 406 | } |
| 396 | }; | 407 | }; |
| 397 | 408 | ||
| 398 | -BR_REGISTER(Transform, NamePointsTransform) | 409 | +BR_REGISTER(Transform, NameLandmarksTransform) |
| 399 | 410 | ||
| 400 | /*! | 411 | /*! |
| 401 | * \ingroup transforms | 412 | * \ingroup transforms |
| 402 | - * \brief Remove a name from a point | 413 | + * \brief Remove a name from a point/rect |
| 403 | * \author Scott Klum \cite sklum | 414 | * \author Scott Klum \cite sklum |
| 404 | */ | 415 | */ |
| 405 | -class AnonymizePointsTransform : public UntrainableMetadataTransform | 416 | +class AnonymizeLandmarksTransform : public UntrainableMetadataTransform |
| 406 | { | 417 | { |
| 407 | Q_OBJECT | 418 | Q_OBJECT |
| 408 | Q_PROPERTY(QStringList names READ get_names WRITE set_names RESET reset_names STORED false) | 419 | Q_PROPERTY(QStringList names READ get_names WRITE set_names RESET reset_names STORED false) |
| @@ -412,12 +423,112 @@ class AnonymizePointsTransform : public UntrainableMetadataTransform | @@ -412,12 +423,112 @@ class AnonymizePointsTransform : public UntrainableMetadataTransform | ||
| 412 | { | 423 | { |
| 413 | dst = src; | 424 | dst = src; |
| 414 | 425 | ||
| 415 | - foreach (const QString &name, names) | ||
| 416 | - if (src.contains(name)) dst.appendPoint(src.get<QPointF>(name)); | 426 | + foreach (const QString &name, names) { |
| 427 | + if (src.contains(name)) { | ||
| 428 | + QVariant variant = src.value(name); | ||
| 429 | + if (variant.canConvert(QMetaType::QPointF)) { | ||
| 430 | + dst.appendPoint(variant.toPointF()); | ||
| 431 | + } else if (variant.canConvert(QMetaType::QRectF)) { | ||
| 432 | + dst.appendRect(variant.toRectF()); | ||
| 433 | + } else { | ||
| 434 | + qFatal("Cannot convert landmark to point or rect."); | ||
| 435 | + } | ||
| 436 | + } | ||
| 437 | + } | ||
| 438 | + } | ||
| 439 | +}; | ||
| 440 | + | ||
| 441 | +BR_REGISTER(Transform, AnonymizeLandmarksTransform) | ||
| 442 | + | ||
| 443 | +/*! | ||
| 444 | + * \ingroup transforms | ||
| 445 | + * \brief Converts either the file::points() list or a QList<QPointF> metadata item to be the template's matrix | ||
| 446 | + * \author Scott Klum \cite sklum | ||
| 447 | + */ | ||
| 448 | +class PointsToMatrixTransform : public UntrainableTransform | ||
| 449 | +{ | ||
| 450 | + Q_OBJECT | ||
| 451 | + | ||
| 452 | + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | ||
| 453 | + BR_PROPERTY(QString, inputVariable, QString()) | ||
| 454 | + | ||
| 455 | + void project(const Template &src, Template &dst) const | ||
| 456 | + { | ||
| 457 | + dst = src; | ||
| 458 | + | ||
| 459 | + if (inputVariable.isEmpty()) { | ||
| 460 | + dst.m() = OpenCVUtils::pointsToMatrix(dst.file.points()); | ||
| 461 | + } else { | ||
| 462 | + if (src.file.contains(inputVariable)) | ||
| 463 | + dst.m() = OpenCVUtils::pointsToMatrix(dst.file.get<QList<QPointF> >(inputVariable)); | ||
| 464 | + } | ||
| 465 | + } | ||
| 466 | +}; | ||
| 467 | + | ||
| 468 | +BR_REGISTER(Transform, PointsToMatrixTransform) | ||
| 469 | + | ||
| 470 | +/*! | ||
| 471 | + * \ingroup transforms | ||
| 472 | + * \brief Normalize points to be relative to a single point | ||
| 473 | + * \author Scott Klum \cite sklum | ||
| 474 | + */ | ||
| 475 | +class NormalizePointsTransform : public UntrainableTransform | ||
| 476 | +{ | ||
| 477 | + Q_OBJECT | ||
| 478 | + | ||
| 479 | + Q_PROPERTY(int index READ get_index WRITE set_index RESET reset_index STORED false) | ||
| 480 | + BR_PROPERTY(int, index, 0) | ||
| 481 | + | ||
| 482 | + void project(const Template &src, Template &dst) const | ||
| 483 | + { | ||
| 484 | + dst = src; | ||
| 485 | + | ||
| 486 | + QList<QPointF> points = dst.file.points(); | ||
| 487 | + QPointF normPoint = points.at(index); | ||
| 488 | + | ||
| 489 | + QList<QPointF> normalizedPoints; | ||
| 490 | + | ||
| 491 | + for (int i=0; i<points.size(); i++) | ||
| 492 | + if (i!=index) | ||
| 493 | + normalizedPoints.append(normPoint-points[i]); | ||
| 494 | + | ||
| 495 | + dst.file.setPoints(normalizedPoints); | ||
| 496 | + } | ||
| 497 | +}; | ||
| 498 | + | ||
| 499 | +BR_REGISTER(Transform, NormalizePointsTransform) | ||
| 500 | + | ||
| 501 | +/*! | ||
| 502 | + * \ingroup transforms | ||
| 503 | + * \brief Normalize points to be relative to a single point | ||
| 504 | + * \author Scott Klum \cite sklum | ||
| 505 | + */ | ||
| 506 | +class PointDisplacementTransform : public UntrainableTransform | ||
| 507 | +{ | ||
| 508 | + Q_OBJECT | ||
| 509 | + | ||
| 510 | + void project(const Template &src, Template &dst) const | ||
| 511 | + { | ||
| 512 | + dst = src; | ||
| 513 | + | ||
| 514 | + QList<QPointF> points = dst.file.points(); | ||
| 515 | + QList<QPointF> normalizedPoints; | ||
| 516 | + | ||
| 517 | + for (int i=0; i<points.size(); i++) | ||
| 518 | + for (int j=0; j<points.size(); j++) | ||
| 519 | + // There is redundant information here | ||
| 520 | + if (j!=i) { | ||
| 521 | + QPointF normalizedPoint = points[i]-points[j]; | ||
| 522 | + normalizedPoint.setX(pow(normalizedPoint.x(),2)); | ||
| 523 | + normalizedPoint.setY(pow(normalizedPoint.y(),2)); | ||
| 524 | + normalizedPoints.append(normalizedPoint); | ||
| 525 | + } | ||
| 526 | + | ||
| 527 | + dst.file.setPoints(normalizedPoints); | ||
| 417 | } | 528 | } |
| 418 | }; | 529 | }; |
| 419 | 530 | ||
| 420 | -BR_REGISTER(Transform, AnonymizePointsTransform) | 531 | +BR_REGISTER(Transform, PointDisplacementTransform) |
| 421 | 532 | ||
| 422 | } // namespace br | 533 | } // namespace br |
| 423 | 534 |
openbr/plugins/misc.cpp
| @@ -937,6 +937,18 @@ class FileExclusionTransform : public UntrainableMetaTransform | @@ -937,6 +937,18 @@ class FileExclusionTransform : public UntrainableMetaTransform | ||
| 937 | 937 | ||
| 938 | BR_REGISTER(Transform, FileExclusionTransform) | 938 | BR_REGISTER(Transform, FileExclusionTransform) |
| 939 | 939 | ||
| 940 | +class TransposeTransform : public UntrainableTransform | ||
| 941 | +{ | ||
| 942 | + Q_OBJECT | ||
| 943 | + | ||
| 944 | + void project(const Template &src, Template &dst) const | ||
| 945 | + { | ||
| 946 | + dst.m() = src.m().t(); | ||
| 947 | + } | ||
| 948 | +}; | ||
| 949 | + | ||
| 950 | +BR_REGISTER(Transform, TransposeTransform) | ||
| 951 | + | ||
| 940 | } | 952 | } |
| 941 | 953 | ||
| 942 | #include "misc.moc" | 954 | #include "misc.moc" |
openbr/plugins/nn.cpp
0 โ 100644
| 1 | +#include <opencv2/ml/ml.hpp> | ||
| 2 | + | ||
| 3 | +#include "openbr_internal.h" | ||
| 4 | +#include "openbr/core/qtutils.h" | ||
| 5 | +#include "openbr/core/opencvutils.h" | ||
| 6 | +#include "openbr/core/eigenutils.h" | ||
| 7 | +#include <QString> | ||
| 8 | +#include <QTemporaryFile> | ||
| 9 | + | ||
| 10 | +using namespace std; | ||
| 11 | +using namespace cv; | ||
| 12 | + | ||
| 13 | +namespace br | ||
| 14 | +{ | ||
| 15 | + | ||
| 16 | +static void storeMLP(const CvANN_MLP &mlp, QDataStream &stream) | ||
| 17 | +{ | ||
| 18 | + // Create local file | ||
| 19 | + QTemporaryFile tempFile; | ||
| 20 | + tempFile.open(); | ||
| 21 | + tempFile.close(); | ||
| 22 | + | ||
| 23 | + // Save MLP to local file | ||
| 24 | + mlp.save(qPrintable(tempFile.fileName())); | ||
| 25 | + | ||
| 26 | + // Copy local file contents to stream | ||
| 27 | + tempFile.open(); | ||
| 28 | + QByteArray data = tempFile.readAll(); | ||
| 29 | + tempFile.close(); | ||
| 30 | + stream << data; | ||
| 31 | +} | ||
| 32 | + | ||
| 33 | +static void loadMLP(CvANN_MLP &mlp, QDataStream &stream) | ||
| 34 | +{ | ||
| 35 | + // Copy local file contents from stream | ||
| 36 | + QByteArray data; | ||
| 37 | + stream >> data; | ||
| 38 | + | ||
| 39 | + // Create local file | ||
| 40 | + QTemporaryFile tempFile(QDir::tempPath()+"/MLP"); | ||
| 41 | + tempFile.open(); | ||
| 42 | + tempFile.write(data); | ||
| 43 | + tempFile.close(); | ||
| 44 | + | ||
| 45 | + // Load MLP from local file | ||
| 46 | + mlp.load(qPrintable(tempFile.fileName())); | ||
| 47 | +} | ||
| 48 | + | ||
| 49 | +/*! | ||
| 50 | + * \ingroup transforms | ||
| 51 | + * \brief Wraps OpenCV's multi-layer perceptron framework | ||
| 52 | + * \author Scott Klum \cite sklum | ||
| 53 | + * \brief http://docs.opencv.org/modules/ml/doc/neural_networks.html | ||
| 54 | + */ | ||
| 55 | +class MLPTransform : public MetaTransform | ||
| 56 | +{ | ||
| 57 | + Q_OBJECT | ||
| 58 | + | ||
| 59 | + Q_ENUMS(Kernel) | ||
| 60 | + Q_PROPERTY(Kernel kernel READ get_kernel WRITE set_kernel RESET reset_kernel STORED false) | ||
| 61 | + Q_PROPERTY(float alpha READ get_alpha WRITE set_alpha RESET reset_alpha STORED false) | ||
| 62 | + Q_PROPERTY(float beta READ get_beta WRITE set_beta RESET reset_beta STORED false) | ||
| 63 | + Q_PROPERTY(QStringList inputVariables READ get_inputVariables WRITE set_inputVariables RESET reset_inputVariables STORED false) | ||
| 64 | + Q_PROPERTY(QStringList outputVariables READ get_outputVariables WRITE set_outputVariables RESET reset_outputVariables STORED false) | ||
| 65 | + Q_PROPERTY(QList<int> neuronsPerLayer READ get_neuronsPerLayer WRITE set_neuronsPerLayer RESET reset_neuronsPerLayer STORED false) | ||
| 66 | + | ||
| 67 | +public: | ||
| 68 | + | ||
| 69 | + enum Kernel { Identity = CvANN_MLP::IDENTITY, | ||
| 70 | + Sigmoid = CvANN_MLP::SIGMOID_SYM, | ||
| 71 | + Gaussian = CvANN_MLP::GAUSSIAN}; | ||
| 72 | + | ||
| 73 | +private: | ||
| 74 | + BR_PROPERTY(Kernel, kernel, Sigmoid) | ||
| 75 | + BR_PROPERTY(float, alpha, 1) | ||
| 76 | + BR_PROPERTY(float, beta, 1) | ||
| 77 | + BR_PROPERTY(QStringList, inputVariables, QStringList()) | ||
| 78 | + BR_PROPERTY(QStringList, outputVariables, QStringList()) | ||
| 79 | + BR_PROPERTY(QList<int>, neuronsPerLayer, QList<int>() << 1 << 1) | ||
| 80 | + | ||
| 81 | + CvANN_MLP mlp; | ||
| 82 | + | ||
| 83 | + void init() | ||
| 84 | + { | ||
| 85 | + if (kernel == Gaussian) | ||
| 86 | + qWarning("The OpenCV documentation warns that the Gaussian kernel, \"is not completely supported at the moment\""); | ||
| 87 | + | ||
| 88 | + Mat layers = Mat(neuronsPerLayer.size(), 1, CV_32SC1); | ||
| 89 | + for (int i=0; i<neuronsPerLayer.size(); i++) | ||
| 90 | + layers.row(i) = Scalar(neuronsPerLayer.at(i)); | ||
| 91 | + | ||
| 92 | + mlp.create(layers,kernel, alpha, beta); | ||
| 93 | + } | ||
| 94 | + | ||
| 95 | + void train(const TemplateList &data) | ||
| 96 | + { | ||
| 97 | + Mat _data = OpenCVUtils::toMat(data.data()); | ||
| 98 | + | ||
| 99 | + // Assuming data has n templates | ||
| 100 | + // _data needs to be n x size of input layer | ||
| 101 | + // Labels needs to be a n x outputs matrix | ||
| 102 | + // For the time being we're going to assume a single output | ||
| 103 | + Mat labels = Mat::zeros(data.size(),inputVariables.size(),CV_32F); | ||
| 104 | + for (int i=0; i<inputVariables.size(); i++) | ||
| 105 | + labels.col(i) += OpenCVUtils::toMat(File::get<float>(data, inputVariables.at(i))); | ||
| 106 | + | ||
| 107 | + mlp.train(_data,labels,Mat()); | ||
| 108 | + | ||
| 109 | + if (Globals->verbose) | ||
| 110 | + for (int i=0; i<neuronsPerLayer.size(); i++) qDebug() << *mlp.get_weights(i); | ||
| 111 | + } | ||
| 112 | + | ||
| 113 | + void project(const Template &src, Template &dst) const | ||
| 114 | + { | ||
| 115 | + dst = src; | ||
| 116 | + | ||
| 117 | + // See above for response dimensionality | ||
| 118 | + Mat response(outputVariables.size(), 1, CV_32FC1); | ||
| 119 | + mlp.predict(src.m().reshape(1,1),response); | ||
| 120 | + | ||
| 121 | + // Apparently mlp.predict reshapes the response matrix? | ||
| 122 | + for (int i=0; i<outputVariables.size(); i++) dst.file.set(outputVariables.at(i),response.at<float>(0,i)); | ||
| 123 | + } | ||
| 124 | + | ||
| 125 | + void load(QDataStream &stream) | ||
| 126 | + { | ||
| 127 | + loadMLP(mlp,stream); | ||
| 128 | + } | ||
| 129 | + | ||
| 130 | + void store(QDataStream &stream) const | ||
| 131 | + { | ||
| 132 | + storeMLP(mlp,stream); | ||
| 133 | + } | ||
| 134 | +}; | ||
| 135 | + | ||
| 136 | +BR_REGISTER(Transform, MLPTransform) | ||
| 137 | + | ||
| 138 | +} // namespace br | ||
| 139 | + | ||
| 140 | +#include "nn.moc" |