Commit ebc479299ac66c4ccdad97b284c61b4e9084db15
Merge pull request #405 from biometrics/clustering
Meanshift clustering
Showing
1 changed file
with
124 additions
and
0 deletions
openbr/plugins/cluster/meanshift.cpp
0 → 100644
| 1 | +#include <openbr/plugins/openbr_internal.h> | ||
| 2 | +#include <openbr/core/opencvutils.h> | ||
| 3 | + | ||
| 4 | +using namespace cv; | ||
| 5 | + | ||
| 6 | +namespace br | ||
| 7 | +{ | ||
| 8 | + | ||
| 9 | +/*! | ||
| 10 | + * \brief A transform implementing the mean shift clustering algorithm. | ||
| 11 | + * \author Jordan Cheney \cite JordanCheney | ||
| 12 | + * \br_property br::Distance* distance The distance used to compute the distance between templates | ||
| 13 | + * \br_property int kernelBandwidth The size of the kernel used to converge points to a cluster center | ||
| 14 | + * \br_property float shiftThreshold The cutoff threshold distance for a shifted point. A value lower then this threshold indicates a point has finished shifting to a cluster center. | ||
| 15 | + * \br_property float distanceThreshold The distance threshold for a point to join a cluster. A point must be at least this close to another point to be included in the same cluster as that point. | ||
| 16 | + * \br_link http://spin.atomicobject.com/2015/05/26/mean-shift-clustering/ | ||
| 17 | + */ | ||
| 18 | +class MeanShiftClusteringTransform : public TimeVaryingTransform | ||
| 19 | +{ | ||
| 20 | + Q_OBJECT | ||
| 21 | + | ||
| 22 | + Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false) | ||
| 23 | + Q_PROPERTY(int kernelBandwidth READ get_kernelBandwidth WRITE set_kernelBandwidth RESET reset_kernelBandwidth STORED false) | ||
| 24 | + Q_PROPERTY(float shiftThreshold READ get_shiftThreshold WRITE set_shiftThreshold RESET reset_shiftThreshold STORED false) | ||
| 25 | + Q_PROPERTY(float distanceThreshold READ get_distanceThreshold WRITE set_distanceThreshold RESET reset_distanceThreshold STORED false) | ||
| 26 | + BR_PROPERTY(br::Distance*, distance, Distance::make(".Dist(L2, false)", NULL)) | ||
| 27 | + BR_PROPERTY(int, kernelBandwidth, 3) | ||
| 28 | + BR_PROPERTY(float, shiftThreshold, 1e-3) | ||
| 29 | + BR_PROPERTY(float, distanceThreshold, 1e-1) | ||
| 30 | + | ||
| 31 | +public: | ||
| 32 | + MeanShiftClusteringTransform() : TimeVaryingTransform(false, false) {} | ||
| 33 | + | ||
| 34 | +private: | ||
| 35 | + void projectUpdate(const TemplateList &src, TemplateList &) | ||
| 36 | + { | ||
| 37 | + templates.append(src); | ||
| 38 | + } | ||
| 39 | + | ||
| 40 | + void finalize(TemplateList &output) | ||
| 41 | + { | ||
| 42 | + output.clear(); | ||
| 43 | + | ||
| 44 | + QList<Mat> original_points, shifted_points; | ||
| 45 | + original_points = shifted_points = templates.data(); | ||
| 46 | + | ||
| 47 | + Mat shift_mask = Mat::zeros(1, shifted_points.size(), CV_32S); | ||
| 48 | + while (countNonZero(shift_mask) != shifted_points.size()) { | ||
| 49 | + for (int i = 0; i < shifted_points.size(); i++) { | ||
| 50 | + if (shift_mask.at<int>(0, i)) | ||
| 51 | + continue; | ||
| 52 | + | ||
| 53 | + Mat point = shifted_points[i]; | ||
| 54 | + Mat shifted_point = point.clone(); | ||
| 55 | + meanshift(shifted_point, original_points); | ||
| 56 | + | ||
| 57 | + float dist = distance->compare(point, shifted_point); | ||
| 58 | + if (dist < shiftThreshold) | ||
| 59 | + shift_mask.at<int>(0, i) = 1; | ||
| 60 | + | ||
| 61 | + shifted_points[i] = shifted_point; | ||
| 62 | + } | ||
| 63 | + } | ||
| 64 | + | ||
| 65 | + QList<int> clusters = assignClusterID(shifted_points); | ||
| 66 | + for (int i = 0; i < templates.size(); i++) | ||
| 67 | + templates[i].file.set("Cluster", clusters[i]); | ||
| 68 | + output.append(templates); | ||
| 69 | + } | ||
| 70 | + | ||
| 71 | + void meanshift(Mat &point, const QList<Mat> &original_points) | ||
| 72 | + { | ||
| 73 | + Mat distances(1, original_points.size(), CV_32FC1); | ||
| 74 | + for (int i = 0; i < original_points.size(); i++) | ||
| 75 | + distances.at<float>(0, i) = distance->compare(point, original_points[i]); | ||
| 76 | + | ||
| 77 | + Mat weights = gaussianKernel(distances, kernelBandwidth); | ||
| 78 | + point = (weights * OpenCVUtils::toMat(original_points)) / sum(weights)[0]; | ||
| 79 | + } | ||
| 80 | + | ||
| 81 | + inline Mat gaussianKernel(const Mat &distance, const float bandwidth) | ||
| 82 | + { | ||
| 83 | + Mat p, e; | ||
| 84 | + pow(distance / bandwidth, 2, p); | ||
| 85 | + exp(-0.5 * p, e); | ||
| 86 | + | ||
| 87 | + return (1.0 / (bandwidth * sqrt(2 * M_PI))) * e; | ||
| 88 | + } | ||
| 89 | + | ||
| 90 | + QList<int> assignClusterID(const QList<Mat> &points) | ||
| 91 | + { | ||
| 92 | + QList<int> groups; | ||
| 93 | + int newGroupIdx = 0; | ||
| 94 | + foreach (const Mat &point, points) { | ||
| 95 | + int group = nearestGroup(point, points, groups); | ||
| 96 | + if (group < 0) | ||
| 97 | + group = newGroupIdx++; | ||
| 98 | + groups.append(group); | ||
| 99 | + } | ||
| 100 | + | ||
| 101 | + if (Globals->verbose) | ||
| 102 | + qDebug("created %d clusters from %d templates", newGroupIdx, points.size()); | ||
| 103 | + | ||
| 104 | + return groups; | ||
| 105 | + } | ||
| 106 | + | ||
| 107 | + int nearestGroup(const Mat &point, const QList<Mat> &points, const QList<int> groups) | ||
| 108 | + { | ||
| 109 | + for (int i = 0; i < groups.size(); i++) { | ||
| 110 | + float dist = distance->compare(point, points[i]); | ||
| 111 | + if (dist < distanceThreshold) | ||
| 112 | + return groups[i]; | ||
| 113 | + } | ||
| 114 | + return -1; | ||
| 115 | + } | ||
| 116 | + | ||
| 117 | + TemplateList templates; | ||
| 118 | +}; | ||
| 119 | + | ||
| 120 | +BR_REGISTER(Transform, MeanShiftClusteringTransform) | ||
| 121 | + | ||
| 122 | +} // namespace br | ||
| 123 | + | ||
| 124 | +#include "cluster/meanshift.moc" |