diff --git a/openbr/plugins/distance.cpp b/openbr/plugins/distance.cpp index c606584..e082282 100644 --- a/openbr/plugins/distance.cpp +++ b/openbr/plugins/distance.cpp @@ -303,8 +303,10 @@ class HeatMapDistance : public Distance Q_OBJECT Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false) BR_PROPERTY(br::Distance*, distance, make("Dist(L2)")) - Q_PROPERTY(int rowSize READ get_rowSize WRITE set_rowSize RESET reset_rowSize STORED false) - BR_PROPERTY(int, rowSize, 1) + Q_PROPERTY(int rows READ get_rows WRITE set_rows RESET reset_rows STORED false) + BR_PROPERTY(int, rows, -1) + Q_PROPERTY(int cols READ get_cols WRITE set_cols RESET reset_cols STORED false) + BR_PROPERTY(int, cols, -1) void train(const TemplateList &src) { @@ -319,21 +321,17 @@ class HeatMapDistance : public Distance (void) a; (void) b; } - void compare(const TemplateList &target, const TemplateList &query, Output *output) const + void compare(const TemplateList &target, const TemplateList &query, Output *output) const { - int i = 0; - int j = 0; - for (int index = 0; index < target.size(); index++) { - float score = distance->compare(target[index],query[index]); - - if (j >= rowSize) { - i++; - j = 0; + if (rows*cols > target.size()) qFatal("Incompatible heatmap comparison dimensionality"); + + int index = 0; + for (int col = 0; col < cols; col++) { + for (int row = 0; row < rows; row++) { + float score = distance->compare(target[index],query[index]); + output->setRelative(score, row, col); + index++; } - - output->setRelative(score, i, j); - - j++; } } }; diff --git a/openbr/plugins/output.cpp b/openbr/plugins/output.cpp index bd426e8..eb4e564 100644 --- a/openbr/plugins/output.cpp +++ b/openbr/plugins/output.cpp @@ -80,7 +80,7 @@ class csvOutput : public MatrixOutput for (int i=0; i targetFiles.size()) qFatal("Incompatible heatmap output dimensionality"); + + QStringList lines; + for (int col = 0; col < cols; col++) { + QStringList words; + for (int row = 0; row < rows; row++) + words.append(toString(row,col)); + lines.append(words.join(",")); + } + QtUtils::writeFile(file, lines); + } + + void initialize(const FileList &targetFiles, const FileList &queryFiles) + { + if (rows == -1 || cols == -1) qFatal("heatOutput requires dimensionality"); + + Output::initialize(targetFiles, queryFiles); + data.create(rows, cols, CV_32FC1); + } +}; + +BR_REGISTER(Output, heatOutput) + +/*! + * \ingroup outputs * \brief One score per row. * \author Josh Klontz \cite jklontz */