Commit 139bcbb2cc624d60ed1d34bbbb08c4db5a65c4e5
Merge pull request #343 from biometrics/cluster_update
Decompose k-NN graph generation and clustering
Showing
3 changed files
with
201 additions
and
27 deletions
openbr/core/cluster.cpp
| @@ -82,7 +82,7 @@ float normalizedROD(const Neighborhood &neighborhood, int a, int b) | @@ -82,7 +82,7 @@ float normalizedROD(const Neighborhood &neighborhood, int a, int b) | ||
| 82 | return 1.f * (distanceA + distanceB) / std::min(indexA+1, indexB+1); | 82 | return 1.f * (distanceA + distanceB) / std::min(indexA+1, indexB+1); |
| 83 | } | 83 | } |
| 84 | 84 | ||
| 85 | -Neighborhood getNeighborhood(const QList<cv::Mat> &simmats) | 85 | +Neighborhood br::knnFromSimmat(const QList<cv::Mat> &simmats, int k) |
| 86 | { | 86 | { |
| 87 | Neighborhood neighborhood; | 87 | Neighborhood neighborhood; |
| 88 | 88 | ||
| @@ -130,36 +130,164 @@ Neighborhood getNeighborhood(const QList<cv::Mat> &simmats) | @@ -130,36 +130,164 @@ Neighborhood getNeighborhood(const QList<cv::Mat> &simmats) | ||
| 130 | // Keep the top matches | 130 | // Keep the top matches |
| 131 | for (int j=0; j<allNeighbors.size(); j++) { | 131 | for (int j=0; j<allNeighbors.size(); j++) { |
| 132 | Neighbors &val = allNeighbors[j]; | 132 | Neighbors &val = allNeighbors[j]; |
| 133 | - const int cutoff = 20; // Somewhat arbitrary number of neighbors to keep | 133 | + const int cutoff = k; // Number of neighbors to keep |
| 134 | int keep = std::min(cutoff, val.size()); | 134 | int keep = std::min(cutoff, val.size()); |
| 135 | std::partial_sort(val.begin(), val.begin()+keep, val.end(), compareNeighbors); | 135 | std::partial_sort(val.begin(), val.begin()+keep, val.end(), compareNeighbors); |
| 136 | neighborhood.append((Neighbors)val.mid(0, keep)); | 136 | neighborhood.append((Neighbors)val.mid(0, keep)); |
| 137 | } | 137 | } |
| 138 | } | 138 | } |
| 139 | 139 | ||
| 140 | - // Normalize scores | ||
| 141 | - for (int i=0; i<neighborhood.size(); i++) { | ||
| 142 | - Neighbors &neighbors = neighborhood[i]; | ||
| 143 | - for (int j=0; j<neighbors.size(); j++) { | ||
| 144 | - Neighbor &neighbor = neighbors[j]; | ||
| 145 | - if (neighbor.second == -std::numeric_limits<float>::infinity()) | ||
| 146 | - neighbor.second = 0; | ||
| 147 | - else if (neighbor.second == std::numeric_limits<float>::infinity()) | ||
| 148 | - neighbor.second = 1; | ||
| 149 | - else | ||
| 150 | - neighbor.second = (neighbor.second - globalMin) / (globalMax - globalMin); | 140 | + return neighborhood; |
| 141 | +} | ||
| 142 | + | ||
| 143 | +// generate k-NN graph from pre-computed similarity matrices | ||
| 144 | +Neighborhood br::knnFromSimmat(const QStringList &simmats, int k) | ||
| 145 | +{ | ||
| 146 | + QList<cv::Mat> mats; | ||
| 147 | + foreach (const QString &simmat, simmats) { | ||
| 148 | + QScopedPointer<br::Format> format(br::Factory<br::Format>::make(simmat)); | ||
| 149 | + br::Template t = format->read(); | ||
| 150 | + mats.append(t); | ||
| 151 | + } | ||
| 152 | + return knnFromSimmat(mats, k); | ||
| 153 | +} | ||
| 154 | + | ||
| 155 | +TemplateList knnFromGallery(const QString & galleryName, bool inMemory, const QString & outFile, int k) | ||
| 156 | +{ | ||
| 157 | + QSharedPointer<Transform> comparison = Transform::fromComparison(Globals->algorithm); | ||
| 158 | + | ||
| 159 | + Gallery *tempG = Gallery::make(galleryName); | ||
| 160 | + qint64 total = tempG->totalSize(); | ||
| 161 | + delete tempG; | ||
| 162 | + comparison->setPropertyRecursive("galleryName", galleryName+"[dropMetadata=true]"); | ||
| 163 | + | ||
| 164 | + bool multiProcess = Globals->file.getBool("multiProcess", false); | ||
| 165 | + if (multiProcess) | ||
| 166 | + comparison = QSharedPointer<Transform> (br::wrapTransform(comparison.data(), "ProcessWrapper")); | ||
| 167 | + | ||
| 168 | + QScopedPointer<Transform> collect(Transform::make("CollectNN+ProgressCounter+Discard", NULL)); | ||
| 169 | + collect->setPropertyRecursive("totalProgress", total); | ||
| 170 | + collect->setPropertyRecursive("keep", k); | ||
| 171 | + | ||
| 172 | + QList<Transform *> tforms; | ||
| 173 | + tforms.append(comparison.data()); | ||
| 174 | + tforms.append(collect.data()); | ||
| 175 | + | ||
| 176 | + QScopedPointer<Transform> compareCollect(br::pipeTransforms(tforms)); | ||
| 177 | + | ||
| 178 | + QSharedPointer <Transform> projector; | ||
| 179 | + if (inMemory) | ||
| 180 | + projector = QSharedPointer<Transform> (br::wrapTransform(compareCollect.data(), "Stream(readMode=StreamGallery, endPoint=Discard")); | ||
| 181 | + else | ||
| 182 | + projector = QSharedPointer<Transform> (br::wrapTransform(compareCollect.data(), "Stream(readMode=StreamGallery, endPoint=LogNN("+outFile+")+DiscardTemplates)")); | ||
| 183 | + | ||
| 184 | + TemplateList input; | ||
| 185 | + input.append(Template(galleryName)); | ||
| 186 | + TemplateList output; | ||
| 187 | + | ||
| 188 | + projector->init(); | ||
| 189 | + projector->projectUpdate(input, output); | ||
| 190 | + | ||
| 191 | + return output; | ||
| 192 | +} | ||
| 193 | + | ||
| 194 | +// Generate k-NN graph from a gallery, using the current algorithm for comparison. | ||
| 195 | +// Direct serialization to file system, k-NN graph is not retained in memory | ||
| 196 | +void br::knnFromGallery(const QString &galleryName, const QString &outFile, int k) | ||
| 197 | +{ | ||
| 198 | + knnFromGallery(galleryName, false, outFile, k); | ||
| 199 | +} | ||
| 200 | + | ||
| 201 | +// In-memory graph construction | ||
| 202 | +Neighborhood br::knnFromGallery(const QString &gallery, int k) | ||
| 203 | +{ | ||
| 204 | + // Nearest neighbor data current stored as template metadata, so retrieve it | ||
| 205 | + TemplateList res = knnFromGallery(gallery, true, "", k); | ||
| 206 | + | ||
| 207 | + Neighborhood neighborhood; | ||
| 208 | + foreach (const Template &t, res) { | ||
| 209 | + Neighbors neighbors = t.file.get<Neighbors>("neighbors"); | ||
| 210 | + neighbors.append(neighbors); | ||
| 211 | + } | ||
| 212 | + | ||
| 213 | + return neighborhood; | ||
| 214 | +} | ||
| 215 | + | ||
| 216 | +Neighborhood br::loadkNN(const QString &infile) | ||
| 217 | +{ | ||
| 218 | + Neighborhood neighborhood; | ||
| 219 | + QFile file(infile); | ||
| 220 | + bool success = file.open(QFile::ReadOnly); | ||
| 221 | + if (!success) qFatal("Failed to open %s for reading.", qPrintable(infile)); | ||
| 222 | + QStringList lines = QString(file.readAll()).split("\n"); | ||
| 223 | + file.close(); | ||
| 224 | + int min_idx = INT_MAX; | ||
| 225 | + int max_idx = -1; | ||
| 226 | + int count = 0; | ||
| 227 | + | ||
| 228 | + foreach (const QString &line, lines) { | ||
| 229 | + Neighbors neighbors; | ||
| 230 | + count++; | ||
| 231 | + if (line.trimmed().isEmpty()) { | ||
| 232 | + neighborhood.append(neighbors); | ||
| 233 | + continue; | ||
| 234 | + } | ||
| 235 | + bool off = false; | ||
| 236 | + QStringList list = line.trimmed().split(",", QString::SkipEmptyParts); | ||
| 237 | + foreach (const QString &item, list) { | ||
| 238 | + QStringList parts = item.trimmed().split(":", QString::SkipEmptyParts); | ||
| 239 | + bool intOK = true; | ||
| 240 | + bool floatOK = true; | ||
| 241 | + int idx = parts[0].toInt(&intOK); | ||
| 242 | + float score = parts[1].toFloat(&floatOK); | ||
| 243 | + | ||
| 244 | + if (idx > max_idx) | ||
| 245 | + max_idx = idx; | ||
| 246 | + if (idx <min_idx) | ||
| 247 | + min_idx = idx; | ||
| 248 | + | ||
| 249 | + if (idx >= lines.size()) { | ||
| 250 | + off = true; | ||
| 251 | + continue; | ||
| 252 | + } | ||
| 253 | + neighbors.append(qMakePair(idx, score)); | ||
| 254 | + | ||
| 255 | + | ||
| 256 | + if (!intOK && floatOK) | ||
| 257 | + qFatal("Failed to parse word: %s", qPrintable(item)); | ||
| 151 | } | 258 | } |
| 259 | + neighborhood.append(neighbors); | ||
| 152 | } | 260 | } |
| 153 | return neighborhood; | 261 | return neighborhood; |
| 154 | } | 262 | } |
| 155 | 263 | ||
| 156 | -// Zhu et al. "A Rank-Order Distance based Clustering Algorithm for Face Tagging", CVPR 2011 | ||
| 157 | -br::Clusters br::ClusterGallery(const QList<cv::Mat> &simmats, float aggressiveness) | 264 | +bool br::savekNN(const Neighborhood &neighborhood, const QString &outfile) |
| 265 | +{ | ||
| 266 | + QFile file(outfile); | ||
| 267 | + bool success = file.open(QFile::WriteOnly); | ||
| 268 | + if (!success) qFatal("Failed to open %s for writing.", qPrintable(outfile)); | ||
| 269 | + | ||
| 270 | + foreach (Neighbors neighbors, neighborhood) { | ||
| 271 | + QString aLine; | ||
| 272 | + if (!neighbors.empty()) | ||
| 273 | + { | ||
| 274 | + aLine.append(QString::number(neighbors[0].first)+":"+QString::number(neighbors[0].second)); | ||
| 275 | + for (int i=1; i < neighbors.size();i++) { | ||
| 276 | + aLine.append(","+QString::number(neighbors[i].first)+":"+QString::number(neighbors[i].second)); | ||
| 277 | + } | ||
| 278 | + } | ||
| 279 | + aLine += "\n"; | ||
| 280 | + file.write(qPrintable(aLine)); | ||
| 281 | + } | ||
| 282 | + file.close(); | ||
| 283 | + return true; | ||
| 284 | +} | ||
| 285 | + | ||
| 286 | + | ||
| 287 | +// Rank-order clustering on a pre-computed k-NN graph | ||
| 288 | +Clusters br::ClusterGraph(Neighborhood neighborhood, float aggressiveness, const QString &csv) | ||
| 158 | { | 289 | { |
| 159 | - qDebug("Clustering %d simmat(s), aggressiveness %f", simmats.size(), aggressiveness); | ||
| 160 | 290 | ||
| 161 | - // Read in gallery parts, keeping top neighbors of each template | ||
| 162 | - Neighborhood neighborhood = getNeighborhood(simmats); | ||
| 163 | const int cutoff = neighborhood.first().size(); | 291 | const int cutoff = neighborhood.first().size(); |
| 164 | const float threshold = 3*cutoff/4 * aggressiveness/5; | 292 | const float threshold = 3*cutoff/4 * aggressiveness/5; |
| 165 | 293 | ||
| @@ -235,10 +363,31 @@ br::Clusters br::ClusterGallery(const QList<cv::Mat> &simmats, float aggressiven | @@ -235,10 +363,31 @@ br::Clusters br::ClusterGallery(const QList<cv::Mat> &simmats, float aggressiven | ||
| 235 | clusters = newClusters; | 363 | clusters = newClusters; |
| 236 | neighborhood = newNeighborhood; | 364 | neighborhood = newNeighborhood; |
| 237 | } | 365 | } |
| 366 | + | ||
| 367 | + if (!csv.isEmpty()) | ||
| 368 | + WriteClusters(clusters, csv); | ||
| 369 | + | ||
| 238 | return clusters; | 370 | return clusters; |
| 239 | } | 371 | } |
| 240 | 372 | ||
| 241 | -br::Clusters br::ClusterGallery(const QStringList &simmats, float aggressiveness, const QString &csv) | 373 | +Clusters br::ClusterGraph(const QString & knnName, float aggressiveness, const QString &csv) |
| 374 | +{ | ||
| 375 | + Neighborhood neighbors = loadkNN(knnName); | ||
| 376 | + return ClusterGraph(neighbors, aggressiveness, csv); | ||
| 377 | +} | ||
| 378 | + | ||
| 379 | +// Zhu et al. "A Rank-Order Distance based Clustering Algorithm for Face Tagging", CVPR 2011 | ||
| 380 | +br::Clusters br::ClusterSimmat(const QList<cv::Mat> &simmats, float aggressiveness, const QString &csv) | ||
| 381 | +{ | ||
| 382 | + qDebug("Clustering %d simmat(s), aggressiveness %f", simmats.size(), aggressiveness); | ||
| 383 | + | ||
| 384 | + // Read in gallery parts, keeping top neighbors of each template | ||
| 385 | + Neighborhood neighborhood = knnFromSimmat(simmats); | ||
| 386 | + | ||
| 387 | + return ClusterGraph(neighborhood, aggressiveness, csv); | ||
| 388 | +} | ||
| 389 | + | ||
| 390 | +br::Clusters br::ClusterSimmat(const QStringList &simmats, float aggressiveness, const QString &csv) | ||
| 242 | { | 391 | { |
| 243 | QList<cv::Mat> mats; | 392 | QList<cv::Mat> mats; |
| 244 | foreach (const QString &simmat, simmats) { | 393 | foreach (const QString &simmat, simmats) { |
| @@ -247,11 +396,7 @@ br::Clusters br::ClusterGallery(const QStringList &simmats, float aggressiveness | @@ -247,11 +396,7 @@ br::Clusters br::ClusterGallery(const QStringList &simmats, float aggressiveness | ||
| 247 | mats.append(t); | 396 | mats.append(t); |
| 248 | } | 397 | } |
| 249 | 398 | ||
| 250 | - Clusters clusters = ClusterGallery(mats, aggressiveness); | ||
| 251 | - | ||
| 252 | - // Save clusters | ||
| 253 | - if (!csv.isEmpty()) | ||
| 254 | - WriteClusters(clusters, csv); | 399 | + Clusters clusters = ClusterSimmat(mats, aggressiveness, csv); |
| 255 | return clusters; | 400 | return clusters; |
| 256 | } | 401 | } |
| 257 | 402 |
openbr/core/cluster.h
| @@ -22,16 +22,45 @@ | @@ -22,16 +22,45 @@ | ||
| 22 | #include <QStringList> | 22 | #include <QStringList> |
| 23 | #include <QVector> | 23 | #include <QVector> |
| 24 | #include <openbr/openbr_plugin.h> | 24 | #include <openbr/openbr_plugin.h> |
| 25 | +#include <openbr/plugins/openbr_internal.h> | ||
| 25 | 26 | ||
| 26 | namespace br | 27 | namespace br |
| 27 | { | 28 | { |
| 28 | typedef QList<int> Cluster; // List of indices into galleries | 29 | typedef QList<int> Cluster; // List of indices into galleries |
| 29 | typedef QVector<Cluster> Clusters; | 30 | typedef QVector<Cluster> Clusters; |
| 30 | 31 | ||
| 31 | - Clusters ClusterGallery(const QList<cv::Mat> &simmats, float aggressiveness); | ||
| 32 | - Clusters ClusterGallery(const QStringList &simmats, float aggressiveness, const QString &csv); | 32 | + // generate k-NN graph from pre-computed similarity matrices |
| 33 | + Neighborhood knnFromSimmat(const QStringList &simmats, int k = 20); | ||
| 34 | + Neighborhood knnFromSimmat(const QList<cv::Mat> &simmats, int k = 20); | ||
| 35 | + | ||
| 36 | + // Generate k-NN graph from a gallery, using the current algorithm for comparison. | ||
| 37 | + // direct serialization to file system. | ||
| 38 | + void knnFromGallery(const QString &galleryName, const QString & outFile, int k = 20); | ||
| 39 | + // in memory graph computation | ||
| 40 | + Neighborhood knnFromGallery(const QString &gallery, int k = 20); | ||
| 41 | + | ||
| 42 | + // Load k-NN graph from a file with the following ascii format: | ||
| 43 | + // One line per sample, each line lists the top k neighbors for the sample as follows: | ||
| 44 | + // index1:score1,index2:score2,...,indexk:scorek | ||
| 45 | + Neighborhood loadkNN(const QString &fname); | ||
| 46 | + | ||
| 47 | + // Save k-NN graph to file | ||
| 48 | + bool savekNN(const Neighborhood &neighborhood, const QString &outfile); | ||
| 49 | + | ||
| 50 | + // Rank-order clustering on a pre-computed k-NN graph | ||
| 51 | + Clusters ClusterGraph(Neighborhood neighbors, float aggresssiveness, const QString &csv = ""); | ||
| 52 | + Clusters ClusterGraph(const QString & knnName, float aggressiveness, const QString &csv = ""); | ||
| 53 | + | ||
| 54 | + // Given a similarity matrix, compute the k-NN graph, then perform rank-order clustering. | ||
| 55 | + Clusters ClusterSimmat(const QList<cv::Mat> &simmats, float aggressiveness, const QString &csv = ""); | ||
| 56 | + Clusters ClusterSimmat(const QStringList &simmats, float aggressiveness, const QString &csv = ""); | ||
| 57 | + | ||
| 58 | + // evaluate clustering results in csv, reading ground truth data from gallery input, using truth_property | ||
| 59 | + // as the key for ground truth labels. | ||
| 33 | void EvalClustering(const QString &csv, const QString &input, QString truth_property); | 60 | void EvalClustering(const QString &csv, const QString &input, QString truth_property); |
| 34 | 61 | ||
| 62 | + // Read/write clusters from a text format, 1 line = 1 cluster, each line contains comma separated list | ||
| 63 | + // of assigned indices. | ||
| 35 | Clusters ReadClusters(const QString &csv); | 64 | Clusters ReadClusters(const QString &csv); |
| 36 | void WriteClusters(const Clusters &clusters, const QString &csv); | 65 | void WriteClusters(const Clusters &clusters, const QString &csv); |
| 37 | } | 66 | } |
openbr/openbr.cpp
| @@ -59,7 +59,7 @@ void br_cat(int num_input_galleries, const char *input_galleries[], const char * | @@ -59,7 +59,7 @@ void br_cat(int num_input_galleries, const char *input_galleries[], const char * | ||
| 59 | 59 | ||
| 60 | void br_cluster(int num_simmats, const char *simmats[], float aggressiveness, const char *csv) | 60 | void br_cluster(int num_simmats, const char *simmats[], float aggressiveness, const char *csv) |
| 61 | { | 61 | { |
| 62 | - ClusterGallery(QtUtils::toStringList(num_simmats, simmats), aggressiveness, csv); | 62 | + ClusterSimmat(QtUtils::toStringList(num_simmats, simmats), aggressiveness, csv); |
| 63 | } | 63 | } |
| 64 | 64 | ||
| 65 | void br_combine_masks(int num_input_masks, const char *input_masks[], const char *output_mask, const char *method) | 65 | void br_combine_masks(int num_input_masks, const char *input_masks[], const char *output_mask, const char *method) |