Commit 139bcbb2cc624d60ed1d34bbbb08c4db5a65c4e5

Authored by caotto
2 parents 21807b94 500ad5a2

Merge pull request #343 from biometrics/cluster_update

Decompose k-NN graph generation and clustering
openbr/core/cluster.cpp
... ... @@ -82,7 +82,7 @@ float normalizedROD(const Neighborhood &neighborhood, int a, int b)
82 82 return 1.f * (distanceA + distanceB) / std::min(indexA+1, indexB+1);
83 83 }
84 84  
85   -Neighborhood getNeighborhood(const QList<cv::Mat> &simmats)
  85 +Neighborhood br::knnFromSimmat(const QList<cv::Mat> &simmats, int k)
86 86 {
87 87 Neighborhood neighborhood;
88 88  
... ... @@ -130,36 +130,164 @@ Neighborhood getNeighborhood(const QList&lt;cv::Mat&gt; &amp;simmats)
130 130 // Keep the top matches
131 131 for (int j=0; j<allNeighbors.size(); j++) {
132 132 Neighbors &val = allNeighbors[j];
133   - const int cutoff = 20; // Somewhat arbitrary number of neighbors to keep
  133 + const int cutoff = k; // Number of neighbors to keep
134 134 int keep = std::min(cutoff, val.size());
135 135 std::partial_sort(val.begin(), val.begin()+keep, val.end(), compareNeighbors);
136 136 neighborhood.append((Neighbors)val.mid(0, keep));
137 137 }
138 138 }
139 139  
140   - // Normalize scores
141   - for (int i=0; i<neighborhood.size(); i++) {
142   - Neighbors &neighbors = neighborhood[i];
143   - for (int j=0; j<neighbors.size(); j++) {
144   - Neighbor &neighbor = neighbors[j];
145   - if (neighbor.second == -std::numeric_limits<float>::infinity())
146   - neighbor.second = 0;
147   - else if (neighbor.second == std::numeric_limits<float>::infinity())
148   - neighbor.second = 1;
149   - else
150   - neighbor.second = (neighbor.second - globalMin) / (globalMax - globalMin);
  140 + return neighborhood;
  141 +}
  142 +
  143 +// generate k-NN graph from pre-computed similarity matrices
  144 +Neighborhood br::knnFromSimmat(const QStringList &simmats, int k)
  145 +{
  146 + QList<cv::Mat> mats;
  147 + foreach (const QString &simmat, simmats) {
  148 + QScopedPointer<br::Format> format(br::Factory<br::Format>::make(simmat));
  149 + br::Template t = format->read();
  150 + mats.append(t);
  151 + }
  152 + return knnFromSimmat(mats, k);
  153 +}
  154 +
  155 +TemplateList knnFromGallery(const QString & galleryName, bool inMemory, const QString & outFile, int k)
  156 +{
  157 + QSharedPointer<Transform> comparison = Transform::fromComparison(Globals->algorithm);
  158 +
  159 + Gallery *tempG = Gallery::make(galleryName);
  160 + qint64 total = tempG->totalSize();
  161 + delete tempG;
  162 + comparison->setPropertyRecursive("galleryName", galleryName+"[dropMetadata=true]");
  163 +
  164 + bool multiProcess = Globals->file.getBool("multiProcess", false);
  165 + if (multiProcess)
  166 + comparison = QSharedPointer<Transform> (br::wrapTransform(comparison.data(), "ProcessWrapper"));
  167 +
  168 + QScopedPointer<Transform> collect(Transform::make("CollectNN+ProgressCounter+Discard", NULL));
  169 + collect->setPropertyRecursive("totalProgress", total);
  170 + collect->setPropertyRecursive("keep", k);
  171 +
  172 + QList<Transform *> tforms;
  173 + tforms.append(comparison.data());
  174 + tforms.append(collect.data());
  175 +
  176 + QScopedPointer<Transform> compareCollect(br::pipeTransforms(tforms));
  177 +
  178 + QSharedPointer <Transform> projector;
  179 + if (inMemory)
  180 + projector = QSharedPointer<Transform> (br::wrapTransform(compareCollect.data(), "Stream(readMode=StreamGallery, endPoint=Discard"));
  181 + else
  182 + projector = QSharedPointer<Transform> (br::wrapTransform(compareCollect.data(), "Stream(readMode=StreamGallery, endPoint=LogNN("+outFile+")+DiscardTemplates)"));
  183 +
  184 + TemplateList input;
  185 + input.append(Template(galleryName));
  186 + TemplateList output;
  187 +
  188 + projector->init();
  189 + projector->projectUpdate(input, output);
  190 +
  191 + return output;
  192 +}
  193 +
  194 +// Generate k-NN graph from a gallery, using the current algorithm for comparison.
  195 +// Direct serialization to file system, k-NN graph is not retained in memory
  196 +void br::knnFromGallery(const QString &galleryName, const QString &outFile, int k)
  197 +{
  198 + knnFromGallery(galleryName, false, outFile, k);
  199 +}
  200 +
  201 +// In-memory graph construction
  202 +Neighborhood br::knnFromGallery(const QString &gallery, int k)
  203 +{
  204 + // Nearest neighbor data current stored as template metadata, so retrieve it
  205 + TemplateList res = knnFromGallery(gallery, true, "", k);
  206 +
  207 + Neighborhood neighborhood;
  208 + foreach (const Template &t, res) {
  209 + Neighbors neighbors = t.file.get<Neighbors>("neighbors");
  210 + neighbors.append(neighbors);
  211 + }
  212 +
  213 + return neighborhood;
  214 +}
  215 +
  216 +Neighborhood br::loadkNN(const QString &infile)
  217 +{
  218 + Neighborhood neighborhood;
  219 + QFile file(infile);
  220 + bool success = file.open(QFile::ReadOnly);
  221 + if (!success) qFatal("Failed to open %s for reading.", qPrintable(infile));
  222 + QStringList lines = QString(file.readAll()).split("\n");
  223 + file.close();
  224 + int min_idx = INT_MAX;
  225 + int max_idx = -1;
  226 + int count = 0;
  227 +
  228 + foreach (const QString &line, lines) {
  229 + Neighbors neighbors;
  230 + count++;
  231 + if (line.trimmed().isEmpty()) {
  232 + neighborhood.append(neighbors);
  233 + continue;
  234 + }
  235 + bool off = false;
  236 + QStringList list = line.trimmed().split(",", QString::SkipEmptyParts);
  237 + foreach (const QString &item, list) {
  238 + QStringList parts = item.trimmed().split(":", QString::SkipEmptyParts);
  239 + bool intOK = true;
  240 + bool floatOK = true;
  241 + int idx = parts[0].toInt(&intOK);
  242 + float score = parts[1].toFloat(&floatOK);
  243 +
  244 + if (idx > max_idx)
  245 + max_idx = idx;
  246 + if (idx <min_idx)
  247 + min_idx = idx;
  248 +
  249 + if (idx >= lines.size()) {
  250 + off = true;
  251 + continue;
  252 + }
  253 + neighbors.append(qMakePair(idx, score));
  254 +
  255 +
  256 + if (!intOK && floatOK)
  257 + qFatal("Failed to parse word: %s", qPrintable(item));
151 258 }
  259 + neighborhood.append(neighbors);
152 260 }
153 261 return neighborhood;
154 262 }
155 263  
156   -// Zhu et al. "A Rank-Order Distance based Clustering Algorithm for Face Tagging", CVPR 2011
157   -br::Clusters br::ClusterGallery(const QList<cv::Mat> &simmats, float aggressiveness)
  264 +bool br::savekNN(const Neighborhood &neighborhood, const QString &outfile)
  265 +{
  266 + QFile file(outfile);
  267 + bool success = file.open(QFile::WriteOnly);
  268 + if (!success) qFatal("Failed to open %s for writing.", qPrintable(outfile));
  269 +
  270 + foreach (Neighbors neighbors, neighborhood) {
  271 + QString aLine;
  272 + if (!neighbors.empty())
  273 + {
  274 + aLine.append(QString::number(neighbors[0].first)+":"+QString::number(neighbors[0].second));
  275 + for (int i=1; i < neighbors.size();i++) {
  276 + aLine.append(","+QString::number(neighbors[i].first)+":"+QString::number(neighbors[i].second));
  277 + }
  278 + }
  279 + aLine += "\n";
  280 + file.write(qPrintable(aLine));
  281 + }
  282 + file.close();
  283 + return true;
  284 +}
  285 +
  286 +
  287 +// Rank-order clustering on a pre-computed k-NN graph
  288 +Clusters br::ClusterGraph(Neighborhood neighborhood, float aggressiveness, const QString &csv)
158 289 {
159   - qDebug("Clustering %d simmat(s), aggressiveness %f", simmats.size(), aggressiveness);
160 290  
161   - // Read in gallery parts, keeping top neighbors of each template
162   - Neighborhood neighborhood = getNeighborhood(simmats);
163 291 const int cutoff = neighborhood.first().size();
164 292 const float threshold = 3*cutoff/4 * aggressiveness/5;
165 293  
... ... @@ -235,10 +363,31 @@ br::Clusters br::ClusterGallery(const QList&lt;cv::Mat&gt; &amp;simmats, float aggressiven
235 363 clusters = newClusters;
236 364 neighborhood = newNeighborhood;
237 365 }
  366 +
  367 + if (!csv.isEmpty())
  368 + WriteClusters(clusters, csv);
  369 +
238 370 return clusters;
239 371 }
240 372  
241   -br::Clusters br::ClusterGallery(const QStringList &simmats, float aggressiveness, const QString &csv)
  373 +Clusters br::ClusterGraph(const QString & knnName, float aggressiveness, const QString &csv)
  374 +{
  375 + Neighborhood neighbors = loadkNN(knnName);
  376 + return ClusterGraph(neighbors, aggressiveness, csv);
  377 +}
  378 +
  379 +// Zhu et al. "A Rank-Order Distance based Clustering Algorithm for Face Tagging", CVPR 2011
  380 +br::Clusters br::ClusterSimmat(const QList<cv::Mat> &simmats, float aggressiveness, const QString &csv)
  381 +{
  382 + qDebug("Clustering %d simmat(s), aggressiveness %f", simmats.size(), aggressiveness);
  383 +
  384 + // Read in gallery parts, keeping top neighbors of each template
  385 + Neighborhood neighborhood = knnFromSimmat(simmats);
  386 +
  387 + return ClusterGraph(neighborhood, aggressiveness, csv);
  388 +}
  389 +
  390 +br::Clusters br::ClusterSimmat(const QStringList &simmats, float aggressiveness, const QString &csv)
242 391 {
243 392 QList<cv::Mat> mats;
244 393 foreach (const QString &simmat, simmats) {
... ... @@ -247,11 +396,7 @@ br::Clusters br::ClusterGallery(const QStringList &amp;simmats, float aggressiveness
247 396 mats.append(t);
248 397 }
249 398  
250   - Clusters clusters = ClusterGallery(mats, aggressiveness);
251   -
252   - // Save clusters
253   - if (!csv.isEmpty())
254   - WriteClusters(clusters, csv);
  399 + Clusters clusters = ClusterSimmat(mats, aggressiveness, csv);
255 400 return clusters;
256 401 }
257 402  
... ...
openbr/core/cluster.h
... ... @@ -22,16 +22,45 @@
22 22 #include <QStringList>
23 23 #include <QVector>
24 24 #include <openbr/openbr_plugin.h>
  25 +#include <openbr/plugins/openbr_internal.h>
25 26  
26 27 namespace br
27 28 {
28 29 typedef QList<int> Cluster; // List of indices into galleries
29 30 typedef QVector<Cluster> Clusters;
30 31  
31   - Clusters ClusterGallery(const QList<cv::Mat> &simmats, float aggressiveness);
32   - Clusters ClusterGallery(const QStringList &simmats, float aggressiveness, const QString &csv);
  32 + // generate k-NN graph from pre-computed similarity matrices
  33 + Neighborhood knnFromSimmat(const QStringList &simmats, int k = 20);
  34 + Neighborhood knnFromSimmat(const QList<cv::Mat> &simmats, int k = 20);
  35 +
  36 + // Generate k-NN graph from a gallery, using the current algorithm for comparison.
  37 + // direct serialization to file system.
  38 + void knnFromGallery(const QString &galleryName, const QString & outFile, int k = 20);
  39 + // in memory graph computation
  40 + Neighborhood knnFromGallery(const QString &gallery, int k = 20);
  41 +
  42 + // Load k-NN graph from a file with the following ascii format:
  43 + // One line per sample, each line lists the top k neighbors for the sample as follows:
  44 + // index1:score1,index2:score2,...,indexk:scorek
  45 + Neighborhood loadkNN(const QString &fname);
  46 +
  47 + // Save k-NN graph to file
  48 + bool savekNN(const Neighborhood &neighborhood, const QString &outfile);
  49 +
  50 + // Rank-order clustering on a pre-computed k-NN graph
  51 + Clusters ClusterGraph(Neighborhood neighbors, float aggresssiveness, const QString &csv = "");
  52 + Clusters ClusterGraph(const QString & knnName, float aggressiveness, const QString &csv = "");
  53 +
  54 + // Given a similarity matrix, compute the k-NN graph, then perform rank-order clustering.
  55 + Clusters ClusterSimmat(const QList<cv::Mat> &simmats, float aggressiveness, const QString &csv = "");
  56 + Clusters ClusterSimmat(const QStringList &simmats, float aggressiveness, const QString &csv = "");
  57 +
  58 + // evaluate clustering results in csv, reading ground truth data from gallery input, using truth_property
  59 + // as the key for ground truth labels.
33 60 void EvalClustering(const QString &csv, const QString &input, QString truth_property);
34 61  
  62 + // Read/write clusters from a text format, 1 line = 1 cluster, each line contains comma separated list
  63 + // of assigned indices.
35 64 Clusters ReadClusters(const QString &csv);
36 65 void WriteClusters(const Clusters &clusters, const QString &csv);
37 66 }
... ...
openbr/openbr.cpp
... ... @@ -59,7 +59,7 @@ void br_cat(int num_input_galleries, const char *input_galleries[], const char *
59 59  
60 60 void br_cluster(int num_simmats, const char *simmats[], float aggressiveness, const char *csv)
61 61 {
62   - ClusterGallery(QtUtils::toStringList(num_simmats, simmats), aggressiveness, csv);
  62 + ClusterSimmat(QtUtils::toStringList(num_simmats, simmats), aggressiveness, csv);
63 63 }
64 64  
65 65 void br_combine_masks(int num_input_masks, const char *input_masks[], const char *output_mask, const char *method)
... ...