Commit ebc479299ac66c4ccdad97b284c61b4e9084db15

Authored by JordanCheney
2 parents d465545b 1d6d0144

Merge pull request #405 from biometrics/clustering

Meanshift clustering
openbr/plugins/cluster/meanshift.cpp 0 → 100644
  1 +#include <openbr/plugins/openbr_internal.h>
  2 +#include <openbr/core/opencvutils.h>
  3 +
  4 +using namespace cv;
  5 +
  6 +namespace br
  7 +{
  8 +
  9 +/*!
  10 + * \brief A transform implementing the mean shift clustering algorithm.
  11 + * \author Jordan Cheney \cite JordanCheney
  12 + * \br_property br::Distance* distance The distance used to compute the distance between templates
  13 + * \br_property int kernelBandwidth The size of the kernel used to converge points to a cluster center
  14 + * \br_property float shiftThreshold The cutoff threshold distance for a shifted point. A value lower then this threshold indicates a point has finished shifting to a cluster center.
  15 + * \br_property float distanceThreshold The distance threshold for a point to join a cluster. A point must be at least this close to another point to be included in the same cluster as that point.
  16 + * \br_link http://spin.atomicobject.com/2015/05/26/mean-shift-clustering/
  17 + */
  18 +class MeanShiftClusteringTransform : public TimeVaryingTransform
  19 +{
  20 + Q_OBJECT
  21 +
  22 + Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false)
  23 + Q_PROPERTY(int kernelBandwidth READ get_kernelBandwidth WRITE set_kernelBandwidth RESET reset_kernelBandwidth STORED false)
  24 + Q_PROPERTY(float shiftThreshold READ get_shiftThreshold WRITE set_shiftThreshold RESET reset_shiftThreshold STORED false)
  25 + Q_PROPERTY(float distanceThreshold READ get_distanceThreshold WRITE set_distanceThreshold RESET reset_distanceThreshold STORED false)
  26 + BR_PROPERTY(br::Distance*, distance, Distance::make(".Dist(L2, false)", NULL))
  27 + BR_PROPERTY(int, kernelBandwidth, 3)
  28 + BR_PROPERTY(float, shiftThreshold, 1e-3)
  29 + BR_PROPERTY(float, distanceThreshold, 1e-1)
  30 +
  31 +public:
  32 + MeanShiftClusteringTransform() : TimeVaryingTransform(false, false) {}
  33 +
  34 +private:
  35 + void projectUpdate(const TemplateList &src, TemplateList &)
  36 + {
  37 + templates.append(src);
  38 + }
  39 +
  40 + void finalize(TemplateList &output)
  41 + {
  42 + output.clear();
  43 +
  44 + QList<Mat> original_points, shifted_points;
  45 + original_points = shifted_points = templates.data();
  46 +
  47 + Mat shift_mask = Mat::zeros(1, shifted_points.size(), CV_32S);
  48 + while (countNonZero(shift_mask) != shifted_points.size()) {
  49 + for (int i = 0; i < shifted_points.size(); i++) {
  50 + if (shift_mask.at<int>(0, i))
  51 + continue;
  52 +
  53 + Mat point = shifted_points[i];
  54 + Mat shifted_point = point.clone();
  55 + meanshift(shifted_point, original_points);
  56 +
  57 + float dist = distance->compare(point, shifted_point);
  58 + if (dist < shiftThreshold)
  59 + shift_mask.at<int>(0, i) = 1;
  60 +
  61 + shifted_points[i] = shifted_point;
  62 + }
  63 + }
  64 +
  65 + QList<int> clusters = assignClusterID(shifted_points);
  66 + for (int i = 0; i < templates.size(); i++)
  67 + templates[i].file.set("Cluster", clusters[i]);
  68 + output.append(templates);
  69 + }
  70 +
  71 + void meanshift(Mat &point, const QList<Mat> &original_points)
  72 + {
  73 + Mat distances(1, original_points.size(), CV_32FC1);
  74 + for (int i = 0; i < original_points.size(); i++)
  75 + distances.at<float>(0, i) = distance->compare(point, original_points[i]);
  76 +
  77 + Mat weights = gaussianKernel(distances, kernelBandwidth);
  78 + point = (weights * OpenCVUtils::toMat(original_points)) / sum(weights)[0];
  79 + }
  80 +
  81 + inline Mat gaussianKernel(const Mat &distance, const float bandwidth)
  82 + {
  83 + Mat p, e;
  84 + pow(distance / bandwidth, 2, p);
  85 + exp(-0.5 * p, e);
  86 +
  87 + return (1.0 / (bandwidth * sqrt(2 * M_PI))) * e;
  88 + }
  89 +
  90 + QList<int> assignClusterID(const QList<Mat> &points)
  91 + {
  92 + QList<int> groups;
  93 + int newGroupIdx = 0;
  94 + foreach (const Mat &point, points) {
  95 + int group = nearestGroup(point, points, groups);
  96 + if (group < 0)
  97 + group = newGroupIdx++;
  98 + groups.append(group);
  99 + }
  100 +
  101 + if (Globals->verbose)
  102 + qDebug("created %d clusters from %d templates", newGroupIdx, points.size());
  103 +
  104 + return groups;
  105 + }
  106 +
  107 + int nearestGroup(const Mat &point, const QList<Mat> &points, const QList<int> groups)
  108 + {
  109 + for (int i = 0; i < groups.size(); i++) {
  110 + float dist = distance->compare(point, points[i]);
  111 + if (dist < distanceThreshold)
  112 + return groups[i];
  113 + }
  114 + return -1;
  115 + }
  116 +
  117 + TemplateList templates;
  118 +};
  119 +
  120 +BR_REGISTER(Transform, MeanShiftClusteringTransform)
  121 +
  122 +} // namespace br
  123 +
  124 +#include "cluster/meanshift.moc"
... ...