Commit 139bcbb2cc624d60ed1d34bbbb08c4db5a65c4e5
Merge pull request #343 from biometrics/cluster_update
Decompose k-NN graph generation and clustering
Showing
3 changed files
with
201 additions
and
27 deletions
openbr/core/cluster.cpp
| ... | ... | @@ -82,7 +82,7 @@ float normalizedROD(const Neighborhood &neighborhood, int a, int b) |
| 82 | 82 | return 1.f * (distanceA + distanceB) / std::min(indexA+1, indexB+1); |
| 83 | 83 | } |
| 84 | 84 | |
| 85 | -Neighborhood getNeighborhood(const QList<cv::Mat> &simmats) | |
| 85 | +Neighborhood br::knnFromSimmat(const QList<cv::Mat> &simmats, int k) | |
| 86 | 86 | { |
| 87 | 87 | Neighborhood neighborhood; |
| 88 | 88 | |
| ... | ... | @@ -130,36 +130,164 @@ Neighborhood getNeighborhood(const QList<cv::Mat> &simmats) |
| 130 | 130 | // Keep the top matches |
| 131 | 131 | for (int j=0; j<allNeighbors.size(); j++) { |
| 132 | 132 | Neighbors &val = allNeighbors[j]; |
| 133 | - const int cutoff = 20; // Somewhat arbitrary number of neighbors to keep | |
| 133 | + const int cutoff = k; // Number of neighbors to keep | |
| 134 | 134 | int keep = std::min(cutoff, val.size()); |
| 135 | 135 | std::partial_sort(val.begin(), val.begin()+keep, val.end(), compareNeighbors); |
| 136 | 136 | neighborhood.append((Neighbors)val.mid(0, keep)); |
| 137 | 137 | } |
| 138 | 138 | } |
| 139 | 139 | |
| 140 | - // Normalize scores | |
| 141 | - for (int i=0; i<neighborhood.size(); i++) { | |
| 142 | - Neighbors &neighbors = neighborhood[i]; | |
| 143 | - for (int j=0; j<neighbors.size(); j++) { | |
| 144 | - Neighbor &neighbor = neighbors[j]; | |
| 145 | - if (neighbor.second == -std::numeric_limits<float>::infinity()) | |
| 146 | - neighbor.second = 0; | |
| 147 | - else if (neighbor.second == std::numeric_limits<float>::infinity()) | |
| 148 | - neighbor.second = 1; | |
| 149 | - else | |
| 150 | - neighbor.second = (neighbor.second - globalMin) / (globalMax - globalMin); | |
| 140 | + return neighborhood; | |
| 141 | +} | |
| 142 | + | |
| 143 | +// Generate a k-NN graph from pre-computed similarity matrices given by name: every entry of simmats is read through a br::Format and the results are handed, together with k, to the QList<cv::Mat> overload. | |
| 144 | +Neighborhood br::knnFromSimmat(const QStringList &simmats, int k) | |
| 145 | +{ | |
| 146 | + QList<cv::Mat> mats; | |
| 147 | + foreach (const QString &simmat, simmats) { | |
| 148 | + QScopedPointer<br::Format> format(br::Factory<br::Format>::make(simmat)); /* the scoped pointer releases the Format when the iteration ends */ | |
| 149 | + br::Template t = format->read(); | |
| 150 | + mats.append(t); /* mats holds cv::Mat, so this relies on a Template-to-cv::Mat conversion declared outside this diff */ | |
| 151 | + } | |
| 152 | + return knnFromSimmat(mats, k); | |
| 153 | +} | |
| 154 | +// Shared k-NN driver. Builds a comparison transform from Globals->algorithm against galleryName, pipes it into CollectNN (keep=k) and streams the gallery through it. With inMemory the stream ends in Discard and the output list of projectUpdate() is returned; otherwise the end point logs to outFile via LogNN. | |
| 155 | +TemplateList knnFromGallery(const QString & galleryName, bool inMemory, const QString & outFile, int k) | |
| 156 | +{ | |
| 157 | + QSharedPointer<Transform> comparison = Transform::fromComparison(Globals->algorithm); | |
| 158 | + // The gallery is opened here only to read its total size, which feeds the progress counter below. | |
| 159 | + Gallery *tempG = Gallery::make(galleryName); | |
| 160 | + qint64 total = tempG->totalSize(); | |
| 161 | + delete tempG; | |
| 162 | + comparison->setPropertyRecursive("galleryName", galleryName+"[dropMetadata=true]"); | |
| 163 | + | |
| 164 | + bool multiProcess = Globals->file.getBool("multiProcess", false); | |
| 165 | + if (multiProcess) | |
| 166 | + comparison = QSharedPointer<Transform> (br::wrapTransform(comparison.data(), "ProcessWrapper")); | |
| 167 | + | |
| 168 | + QScopedPointer<Transform> collect(Transform::make("CollectNN+ProgressCounter+Discard", NULL)); | |
| 169 | + collect->setPropertyRecursive("totalProgress", total); | |
| 170 | + collect->setPropertyRecursive("keep", k); | |
| 171 | + | |
| 172 | + QList<Transform *> tforms; | |
| 173 | + tforms.append(comparison.data()); | |
| 174 | + tforms.append(collect.data()); | |
| 175 | + // NOTE(review): raw .data() pointers are handed to pipeTransforms/wrapTransform while the smart pointers keep ownership - confirm those helpers do not take ownership. | |
| 176 | + QScopedPointer<Transform> compareCollect(br::pipeTransforms(tforms)); | |
| 177 | + // Both Stream descriptions need balanced parentheses; the in-memory one previously lacked its closing ')'. | |
| 178 | + QSharedPointer <Transform> projector; | |
| 179 | + if (inMemory) | |
| 180 | + projector = QSharedPointer<Transform> (br::wrapTransform(compareCollect.data(), "Stream(readMode=StreamGallery, endPoint=Discard)")); | |
| 181 | + else | |
| 182 | + projector = QSharedPointer<Transform> (br::wrapTransform(compareCollect.data(), "Stream(readMode=StreamGallery, endPoint=LogNN("+outFile+")+DiscardTemplates)")); | |
| 183 | + | |
| 184 | + TemplateList input; | |
| 185 | + input.append(Template(galleryName)); | |
| 186 | + TemplateList output; | |
| 187 | + | |
| 188 | + projector->init(); | |
| 189 | + projector->projectUpdate(input, output); | |
| 190 | + | |
| 191 | + return output; | |
| 192 | +} | |
| 193 | + | |
| 194 | +// Generate a k-NN graph from a gallery, using the current algorithm for comparison. | |
| 195 | +// The graph is serialized directly to outFile and is not retained in memory. | |
| 196 | +void br::knnFromGallery(const QString &galleryName, const QString &outFile, int k) | |
| 197 | +{ | |
| 198 | + knnFromGallery(galleryName, false, outFile, k); /* inMemory=false; the TemplateList it returns is ignored */ | |
| 199 | +} | |
| 200 | + | |
| 201 | +// In-memory graph construction: runs the shared gallery driver with inMemory=true and gathers the "neighbors" metadata of every returned template into a Neighborhood (one entry per template, in the order returned). | |
| 202 | +Neighborhood br::knnFromGallery(const QString &gallery, int k) | |
| 203 | +{ | |
| 204 | + // Nearest neighbor data is currently stored as template metadata, so retrieve it | |
| 205 | + TemplateList res = knnFromGallery(gallery, true, "", k); | |
| 206 | + | |
| 207 | + Neighborhood neighborhood; | |
| 208 | + foreach (const Template &t, res) { | |
| 209 | + // Append to the result; the list used to be appended to itself, which left neighborhood empty. | |
| 210 | + neighborhood.append(t.file.get<Neighbors>("neighbors")); | |
| 211 | + } | |
| 212 | + | |
| 213 | + return neighborhood; | |
| 214 | +} | |
| 215 | + | |
| 216 | +Neighborhood br::loadkNN(const QString &infile) | |
| 217 | +{ | |
| 218 | + // Read a k-NN graph from a text file: one line per sample, each a comma-separated list of index:score items. | |
| 219 | + // A blank line becomes an empty neighbor list, so entry N of the result still corresponds to line N. | |
| 220 | + // Calls qFatal if the file cannot be opened or an item cannot be parsed. | |
| 221 | + Neighborhood neighborhood; | |
| 222 | + QFile file(infile); | |
| 223 | + if (!file.open(QFile::ReadOnly)) | |
| 224 | + qFatal("Failed to open %s for reading.", qPrintable(infile)); | |
| 225 | + const QStringList lines = QString(file.readAll()).split("\n"); | |
| 226 | + file.close(); | |
| 227 | + // NOTE(review): a file ending in "\n" (as savekNN writes) also yields a final empty entry here - confirm callers expect it. | |
| 228 | + | |
| 229 | + foreach (const QString &line, lines) { | |
| 230 | + Neighbors neighbors; | |
| 231 | + if (line.trimmed().isEmpty()) { | |
| 232 | + neighborhood.append(neighbors); | |
| 233 | + continue; | |
| 234 | + } | |
| 235 | + | |
| 236 | + const QStringList list = line.trimmed().split(",", QString::SkipEmptyParts); | |
| 237 | + foreach (const QString &item, list) { | |
| 238 | + const QStringList parts = item.trimmed().split(":", QString::SkipEmptyParts); | |
| 239 | + // parts[1] is read below, so an item without both fields must be rejected before indexing. | |
| 240 | + if (parts.size() < 2) | |
| 241 | + qFatal("Failed to parse word: %s", qPrintable(item)); | |
| 242 | + | |
| 243 | + bool intOK = true; | |
| 244 | + bool floatOK = true; | |
| 245 | + const int idx = parts[0].toInt(&intOK); | |
| 246 | + const float score = parts[1].toFloat(&floatOK); | |
| 247 | + | |
| 248 | + // Either conversion failing is fatal. The previous test (!intOK && floatOK) ran only | |
| 249 | + // after the values had been used and let a malformed score through. | |
| 250 | + if (!intOK || !floatOK) | |
| 251 | + qFatal("Failed to parse word: %s", qPrintable(item)); | |
| 252 | + | |
| 253 | + // Neighbors whose index is not below the number of lines read are dropped. | |
| 254 | + if (idx >= lines.size()) | |
| 255 | + continue; | |
| 256 | + | |
| 257 | + neighbors.append(qMakePair(idx, score)); | |
| 151 | 258 | } |
| 259 | + neighborhood.append(neighbors); | |
| 152 | 260 | } |
| 153 | 261 | return neighborhood; |
| 154 | 262 | } |
| 155 | 263 | |
| 156 | -// Zhu et al. "A Rank-Order Distance based Clustering Algorithm for Face Tagging", CVPR 2011 | |
| 157 | -br::Clusters br::ClusterGallery(const QList<cv::Mat> &simmats, float aggressiveness) | |
| 264 | +bool br::savekNN(const Neighborhood &neighborhood, const QString &outfile) | |
| 265 | +{ /* Writes neighborhood to outfile as text: one line per neighbor list, items formatted index:score and joined by ','. Returns true whenever it returns; a failed open is reported through qFatal. */ | |
| 266 | + QFile file(outfile); | |
| 267 | + bool success = file.open(QFile::WriteOnly); | |
| 268 | + if (!success) qFatal("Failed to open %s for writing.", qPrintable(outfile)); | |
| 269 | + | |
| 270 | + foreach (Neighbors neighbors, neighborhood) { | |
| 271 | + QString aLine; | |
| 272 | + if (!neighbors.empty()) | |
| 273 | + { | |
| 274 | + aLine.append(QString::number(neighbors[0].first)+":"+QString::number(neighbors[0].second)); /* first item carries no leading comma */ | |
| 275 | + for (int i=1; i < neighbors.size();i++) { | |
| 276 | + aLine.append(","+QString::number(neighbors[i].first)+":"+QString::number(neighbors[i].second)); | |
| 277 | + } | |
| 278 | + } | |
| 279 | + aLine += "\n"; /* an empty neighbor list still emits a blank line */ | |
| 280 | + file.write(qPrintable(aLine)); /* NOTE(review): the result of write() is not checked */ | |
| 281 | + } | |
| 282 | + file.close(); | |
| 283 | + return true; | |
| 284 | +} | |
| 285 | + | |
| 286 | + | |
| 287 | +// Rank-order clustering on a pre-computed k-NN graph | |
| 288 | +Clusters br::ClusterGraph(Neighborhood neighborhood, float aggressiveness, const QString &csv) | |
| 158 | 289 | { |
| 159 | - qDebug("Clustering %d simmat(s), aggressiveness %f", simmats.size(), aggressiveness); | |
| 160 | 290 | |
| 161 | - // Read in gallery parts, keeping top neighbors of each template | |
| 162 | - Neighborhood neighborhood = getNeighborhood(simmats); | |
| 163 | 291 | const int cutoff = neighborhood.first().size(); |
| 164 | 292 | const float threshold = 3*cutoff/4 * aggressiveness/5; |
| 165 | 293 | |
| ... | ... | @@ -235,10 +363,31 @@ br::Clusters br::ClusterGallery(const QList<cv::Mat> &simmats, float aggressiven |
| 235 | 363 | clusters = newClusters; |
| 236 | 364 | neighborhood = newNeighborhood; |
| 237 | 365 | } |
| 366 | + | |
| 367 | + if (!csv.isEmpty()) | |
| 368 | + WriteClusters(clusters, csv); | |
| 369 | + | |
| 238 | 370 | return clusters; |
| 239 | 371 | } |
| 240 | 372 | |
| 241 | -br::Clusters br::ClusterGallery(const QStringList &simmats, float aggressiveness, const QString &csv) | |
| 373 | +Clusters br::ClusterGraph(const QString & knnName, float aggressiveness, const QString &csv) | |
| 374 | +{ /* Convenience overload: loads the k-NN graph stored in the file knnName and clusters it with the Neighborhood overload, forwarding aggressiveness and csv unchanged. */ | |
| 375 | + Neighborhood neighbors = loadkNN(knnName); | |
| 376 | + return ClusterGraph(neighbors, aggressiveness, csv); | |
| 377 | +} | |
| 378 | + | |
| 379 | +// Zhu et al. "A Rank-Order Distance based Clustering Algorithm for Face Tagging", CVPR 2011 | |
| 380 | +br::Clusters br::ClusterSimmat(const QList<cv::Mat> &simmats, float aggressiveness, const QString &csv) | |
| 381 | +{ /* Builds a k-NN graph from the in-memory similarity matrices, then clusters it with ClusterGraph, forwarding aggressiveness and csv. */ | |
| 382 | + qDebug("Clustering %d simmat(s), aggressiveness %f", simmats.size(), aggressiveness); | |
| 383 | + | |
| 384 | + // Read in gallery parts, keeping top neighbors of each template | |
| 385 | + Neighborhood neighborhood = knnFromSimmat(simmats); /* k is not passed, so the declaration's default argument applies */ | |
| 386 | + | |
| 387 | + return ClusterGraph(neighborhood, aggressiveness, csv); | |
| 388 | +} | |
| 389 | + | |
| 390 | +br::Clusters br::ClusterSimmat(const QStringList &simmats, float aggressiveness, const QString &csv) | |
| 242 | 391 | { |
| 243 | 392 | QList<cv::Mat> mats; |
| 244 | 393 | foreach (const QString &simmat, simmats) { |
| ... | ... | @@ -247,11 +396,7 @@ br::Clusters br::ClusterGallery(const QStringList &simmats, float aggressiveness |
| 247 | 396 | mats.append(t); |
| 248 | 397 | } |
| 249 | 398 | |
| 250 | - Clusters clusters = ClusterGallery(mats, aggressiveness); | |
| 251 | - | |
| 252 | - // Save clusters | |
| 253 | - if (!csv.isEmpty()) | |
| 254 | - WriteClusters(clusters, csv); | |
| 399 | + Clusters clusters = ClusterSimmat(mats, aggressiveness, csv); | |
| 255 | 400 | return clusters; |
| 256 | 401 | } |
| 257 | 402 | ... | ... |
openbr/core/cluster.h
| ... | ... | @@ -22,16 +22,45 @@ |
| 22 | 22 | #include <QStringList> |
| 23 | 23 | #include <QVector> |
| 24 | 24 | #include <openbr/openbr_plugin.h> |
| 25 | +#include <openbr/plugins/openbr_internal.h> | |
| 25 | 26 | |
| 26 | 27 | namespace br |
| 27 | 28 | { |
| 28 | 29 | typedef QList<int> Cluster; // List of indices into galleries |
| 29 | 30 | typedef QVector<Cluster> Clusters; |
| 30 | 31 | |
| 31 | - Clusters ClusterGallery(const QList<cv::Mat> &simmats, float aggressiveness); | |
| 32 | - Clusters ClusterGallery(const QStringList &simmats, float aggressiveness, const QString &csv); | |
| 32 | + // Generate a k-NN graph from pre-computed similarity matrices; k is the number of neighbors kept per sample. | |
| 33 | + Neighborhood knnFromSimmat(const QStringList &simmats, int k = 20); | |
| 34 | + Neighborhood knnFromSimmat(const QList<cv::Mat> &simmats, int k = 20); | |
| 35 | + | |
| 36 | + // Generate a k-NN graph from a gallery, using the current algorithm for comparison. | |
| 37 | + // This overload serializes the graph directly to outFile. | |
| 38 | + void knnFromGallery(const QString &galleryName, const QString & outFile, int k = 20); | |
| 39 | + // This overload computes the graph in memory and returns it. | |
| 40 | + Neighborhood knnFromGallery(const QString &gallery, int k = 20); | |
| 41 | + | |
| 42 | + // Load a k-NN graph from a file with the following ASCII format: | |
| 43 | + // One line per sample, each line lists the top k neighbors for the sample as follows: | |
| 44 | + // index1:score1,index2:score2,...,indexk:scorek | |
| 45 | + Neighborhood loadkNN(const QString &fname); | |
| 46 | + | |
| 47 | + // Save a k-NN graph to a file in the format read by loadkNN | |
| 48 | + bool savekNN(const Neighborhood &neighborhood, const QString &outfile); | |
| 49 | + | |
| 50 | + // Rank-order clustering on a pre-computed k-NN graph; clusters are also written to csv when it is non-empty. | |
| 51 | + Clusters ClusterGraph(Neighborhood neighbors, float aggressiveness, const QString &csv = ""); | |
| 52 | + Clusters ClusterGraph(const QString & knnName, float aggressiveness, const QString &csv = ""); | |
| 53 | + | |
| 54 | + // Given a similarity matrix, compute the k-NN graph, then perform rank-order clustering. | |
| 55 | + Clusters ClusterSimmat(const QList<cv::Mat> &simmats, float aggressiveness, const QString &csv = ""); | |
| 56 | + Clusters ClusterSimmat(const QStringList &simmats, float aggressiveness, const QString &csv = ""); | |
| 57 | + | |
| 58 | + // Evaluate clustering results in csv, reading ground truth data from gallery input, using truth_property | |
| 59 | + // as the key for ground truth labels. | |
| 33 | 60 | void EvalClustering(const QString &csv, const QString &input, QString truth_property); |
| 34 | 61 | |
| 62 | + // Read/write clusters from a text format, 1 line = 1 cluster, each line contains comma separated list | |
| 63 | + // of assigned indices. | |
| 35 | 64 | Clusters ReadClusters(const QString &csv); |
| 36 | 65 | void WriteClusters(const Clusters &clusters, const QString &csv); |
| 37 | 66 | } | ... | ... |
openbr/openbr.cpp
| ... | ... | @@ -59,7 +59,7 @@ void br_cat(int num_input_galleries, const char *input_galleries[], const char * |
| 59 | 59 | |
| 60 | 60 | void br_cluster(int num_simmats, const char *simmats[], float aggressiveness, const char *csv) |
| 61 | 61 | { |
| 62 | - ClusterGallery(QtUtils::toStringList(num_simmats, simmats), aggressiveness, csv); | |
| 62 | + ClusterSimmat(QtUtils::toStringList(num_simmats, simmats), aggressiveness, csv); /* C API entry point: converts the simmat names to a string list and clusters them; the returned clusters are discarded */ | |
| 63 | 63 | } |
| 64 | 64 | |
| 65 | 65 | void br_combine_masks(int num_input_masks, const char *input_masks[], const char *output_mask, const char *method) | ... | ... |