diff --git a/openbr/plugins/output/knn.cpp b/openbr/plugins/output/knn.cpp index 0160ecb..bedd989 100644 --- a/openbr/plugins/output/knn.cpp +++ b/openbr/plugins/output/knn.cpp @@ -6,44 +6,80 @@ namespace br { +typedef QPair Pair; + /*! * \ingroup outputs * \brief Outputs the k-Nearest Neighbors from the gallery for each probe. * \author Ben Klein \cite bhklein */ -class knnOutput : public MatrixOutput +class knnOutput : public Output { Q_OBJECT + int rowBlock, columnBlock; + size_t headerSize, k; + cv::Mat blockScores; + ~knnOutput() { - size_t num_probes = (size_t)queryFiles.size(); - if (targetFiles.isEmpty() || queryFiles.isEmpty()) return; - size_t k = file.get("k", 20); + writeBlock(); + } - if ((size_t)targetFiles.size() < k) - qFatal("Gallery size %s is smaller than k = %s.", qPrintable(QString::number(targetFiles.size())), qPrintable(QString::number(k))); + void setBlock(int rowBlock, int columnBlock) + { + if ((rowBlock == 0) && (columnBlock == 0)) { + k = file.get("k", 20); + QFile f(file); + if (!f.open(QFile::WriteOnly)) + qFatal("Unable to open %s for writing.", qPrintable(file)); + size_t querySize = (size_t)queryFiles.size(); + f.write((const char*) &querySize, sizeof(size_t)); + f.write((const char*) &k, sizeof(size_t)); + headerSize = 2 * sizeof(size_t); + } else { + writeBlock(); + } - QFile f(file); - if (!f.open(QFile::WriteOnly)) - qFatal("Unable to open %s for writing.", qPrintable(file)); - f.write((const char*) &num_probes, sizeof(size_t)); - f.write((const char*) &k, sizeof(size_t)); + this->rowBlock = rowBlock; + this->columnBlock = columnBlock; + + int matrixRows = std::min(queryFiles.size()-rowBlock*this->blockRows, blockRows); + int matrixCols = std::min(targetFiles.size()-columnBlock*this->blockCols, blockCols); + + blockScores = cv::Mat(matrixRows, matrixCols, CV_32FC1); + } - QVector neighbors; neighbors.reserve(num_probes*k); + void setRelative(float value, int i, int j) + { + blockScores.at(i,j) = value; + } + + void set(float value, int i, int j) + { + (void) value; (void) i; (void) j; + qFatal("Logic error."); + } + + void writeBlock() + { + QFile f(file); + if (!f.open(QFile::ReadWrite)) + qFatal("Unable to open %s for modifying.", qPrintable(file)); + QVector neighbors; neighbors.reserve(k * blockScores.rows); - for (size_t i=0; i Pair; + for (int i=0; i(data.row(i)), true)) { - if (QString(targetFiles[pair.second]) != QString(queryFiles[i])) { + foreach (const Pair &pair, Common::Sort(OpenCVUtils::matrixToVector(blockScores.row(i)), true)) { + if (QString(targetFiles[pair.second]) != QString(queryFiles[rowBlock*this->blockRows+i])) { Candidate candidate((size_t)pair.second, pair.first); neighbors.push_back(candidate); if (++rank >= k) break; } } } - f.write((const char*) neighbors.data(), num_probes * k * sizeof(Candidate)); + f.seek(headerSize + sizeof(Candidate)*quint64(rowBlock*this->blockRows)*k); + f.write((const char*) neighbors.data(), blockScores.rows * k * sizeof(Candidate)); f.close(); } };