// File prelude: licence, includes, the neighbor typedefs and compareNeighbors().
//   Neighbor     - a pair type; compareNeighbors() reads its .first and .second members.
//   Neighbors    - the list type that holds Neighbor entries.
//   Neighborhood - a vector type; later code indexes it as neighborhood[a][b].
// compareNeighbors(a, b) returns true when a must be ordered before b: the entry
// with the larger .second comes first, and when both .second values are equal the
// entry with the smaller .first comes first.
// NOTE(review): the #include targets and the typedef template arguments are not
// visible in this copy of the file (presumably every <...> span was dropped during
// extraction) - restore them from the upstream source rather than guessing.
/* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * Copyright 2012 The MITRE Corporation * * * * Licensed under the Apache License, Version 2.0 (the "License"); * * you may not use this file except in compliance with the License. * * You may obtain a copy of the License at * * * * http://www.apache.org/licenses/LICENSE-2.0 * * * * Unless required by applicable law or agreed to in writing, software * * distributed under the License is distributed on an "AS IS" BASIS, * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * * See the License for the specific language governing permissions and * * limitations under the License. * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ #include #include #include #include #include #include #include #include "core/bee.h" #include "core/cluster.h" typedef QPair Neighbor; // QPair typedef QList Neighbors; typedef QVector Neighborhood; // Compare function used to order neighbors from highest to lowest similarity static bool compareNeighbors(const Neighbor &a, const Neighbor &b) { if (a.second == b.second) return a.first < b.first; return a.second > b.second; } // Zhu et al. "A Rank-Order Distance based Clustering Algorithm for Face Tagging", CVPR 2011 // Ob(x) in eq. 1, modified to consider 0/1 as ground truth imposter/genuine. 
// NOTE(review): several definitions are run together on this line and parts of the
// text are missing in this copy: the body of indexOf() breaks off after "j", and
// the head of the function that returns the symmetric distance is not visible.
// What the visible code establishes:
//  * Symmetric distance (tail only): returns 0 when either neighborhood[b][indexA]
//    or neighborhood[a][indexB] has .second == 1; returns a numeric_limits max()
//    value when either has .second == 0; otherwise returns
//    (asymmetricalROD(neighborhood, a, b) + asymmetricalROD(neighborhood, b, a))
//    divided by min(indexA+1, indexB+1), with the 1.f factor forcing the division
//    to be done in floating point.
//  * getNeighborhood(simmats) computes numGalleries as the integer square root of
//    simmats.size() and calls qFatal() unless numGalleries*numGalleries equals that
//    size. While collecting neighbors it skips entries where i==j and k==l
//    (self-similarity), and updates globalMin/globalMax only from values that are
//    neither -infinity nor +infinity. Each neighbor is stored as
//    Neighbor(l+columnOffset, val), and columnOffset grows by m.cols per matrix.
//  * Final pass: a score equal to +infinity becomes 1, a score matching the
//    preceding (partly missing, presumably -infinity) test becomes 0, and every
//    other score becomes (score - globalMin) / (globalMax - globalMin).
//  * NOTE(review): if all non-infinite scores are equal, globalMax - globalMin is 0
//    and the normalisation divides by zero - confirm whether callers can hit this.
static int indexOf(const Neighbors &neighbors, int i) { for (int j=0; j::max(); if ((neighborhood[b][indexA].second == 1) || (neighborhood[a][indexB].second == 1)) return 0; if ((neighborhood[b][indexA].second == 0) || (neighborhood[a][indexB].second == 0)) return std::numeric_limits::max(); int distanceA = asymmetricalROD(neighborhood, a, b); int distanceB = asymmetricalROD(neighborhood, b, a); return 1.f * (distanceA + distanceB) / std::min(indexA+1, indexB+1); } Neighborhood getNeighborhood(const QStringList &simmats) { Neighborhood neighborhood; float globalMax = -std::numeric_limits::max(); float globalMin = std::numeric_limits::max(); int numGalleries = (int)sqrt((float)simmats.size()); if (numGalleries*numGalleries != simmats.size()) qFatal("cluser.cpp readGalleries incorrect number of similarity matrices."); // Process each simmat for (int i=0; i allNeighbors; int currentRows = -1; int columnOffset = 0; for (int j=0; j(k,l); if ((i==j) && (k==l)) continue; // Skips self-similarity scores if ((val != -std::numeric_limits::infinity()) && (val != std::numeric_limits::infinity())) { globalMax = std::max(globalMax, val); globalMin = std::min(globalMin, val); } neighbors.append(Neighbor(l+columnOffset, val)); } } columnOffset += m.cols; } // Keep the top matches for (int j=0; j::infinity()) neighbor.second = 0; else if (neighbor.second == std::numeric_limits::infinity()) neighbor.second = 1; else neighbor.second = (neighbor.second - globalMin) / (globalMax - globalMin); } } return neighborhood; } // Zhu et al. 
// NOTE(review): several definitions share this line and parts of their bodies are
// missing in this copy, so only the following is stated from the visible text:
//  * br::ClusterGallery(simmats, aggressiveness, csv) logs the number of simmats,
//    builds the neighborhood with getNeighborhood(simmats), and derives
//    cutoff = neighborhood.first().size() and
//    threshold = 3*cutoff/4 * aggressiveness/5. Because cutoff is an int, 3*cutoff/4
//    is evaluated with integer division before the float multiplication.
//  * neighborhood.first() is called without an emptiness check - presumably the
//    input always yields at least one template; confirm.
//  * The clustering step replaces clusters/neighborhood with
//    newClusters/newNeighborhood (the loop structure around it is only partly
//    visible). Afterwards the clusters are written with WriteClusters(clusters, csv)
//    only when csv is non-empty, and the clusters are returned.
//  * wallaceMetric(clusters, indices) starts from matches = 0 and total = 0 and
//    iterates over the clusters; jaccardIndex(indicesA, indicesB) starts from a
//    zeroed 2x2 int table `a`. Their return expressions are not visible here.
//  * The last definition reads labels via
//    TemplateList::fromInput(input).files().labels() and continues past the end of
//    this line, so it is not documented further.
"A Rank-Order Distance based Clustering Algorithm for Face Tagging", CVPR 2011 br::Clusters br::ClusterGallery(const QStringList &simmats, float aggressiveness, const QString &csv) { qDebug("Clustering %d simmat(s)", simmats.size()); // Read in gallery parts, keeping top neighbors of each template Neighborhood neighborhood = getNeighborhood(simmats); const int cutoff = neighborhood.first().size(); const float threshold = 3*cutoff/4 * aggressiveness/5; // Initialize clusters Clusters clusters(neighborhood.size()); for (int i=0; i nextClusterIDs(neighborhood.size()); for (int i=0; i clusterIDLUT; QList allClusterIDs = QSet::fromList(nextClusterIDs.toList()).values(); for (int i=0; i= clusters.size()); clusters = newClusters; neighborhood = newNeighborhood; } // Save clusters if (!csv.isEmpty()) WriteClusters(clusters, csv); return clusters; } // Santo Fortunato "Community detection in graphs", Physics Reports 486 (2010) // wI or wII metric (page 148) float wallaceMetric(const br::Clusters &clusters, const QVector &indices) { int matches = 0; int total = 0; foreach (const QList &cluster, clusters) { for (int i=0; i &indicesA, const QVector &indicesB) { int a[2][2] = {{0,0},{0,0}}; for (int i=0; i labels = TemplateList::fromInput(input).files().labels(); QHash labelToIndex; int nClusters = 0; for (int i=0; i()); QVector truthIndices(labels.size()); for (int i=0; i testIndices(labels.size()); for (int i=0; i