Commit 0d0f5bd4b676ac45fb260431a8ca2b21fd24d8e2
1 parent
94c3bcaa
generalized product quantization code
Showing
1 changed file
with
29 additions
and
5 deletions
openbr/plugins/quantize.cpp
| ... | ... | @@ -219,19 +219,41 @@ private: |
| 219 | 219 | // Common::KernelDensityEstimation(impostorScores, lut->at<float>(0,j*256+k), hImpostor))); |
| 220 | 220 | } |
| 221 | 221 | |
| 222 | + int getStep(int cols) const | |
| 223 | + { | |
| 224 | + if (n > 0) return n; | |
| 225 | + if (n == 0) return cols; | |
| 226 | + return ceil(float(cols)/abs(n)); | |
| 227 | + } | |
| 228 | + | |
| 229 | + int getOffset(int cols) const | |
| 230 | + { | |
| 231 | + if (n >= 0) return 0; | |
| 232 | + const int step = getStep(cols); | |
| 233 | + return (step - cols%step) % step; | |
| 234 | + } | |
| 235 | + | |
| 236 | + int getDims(int cols) const | |
| 237 | + { | |
| 238 | + const int step = getStep(cols); | |
| 239 | + if (n >= 0) return cols/step; | |
| 240 | + return ceil(float(cols)/step); | |
| 241 | + } | |
| 242 | + | |
| 222 | 243 | void train(const TemplateList &src) |
| 223 | 244 | { |
| 224 | 245 | Mat data = OpenCVUtils::toMat(src.data()); |
| 225 | - if (data.cols % n != 0) qFatal("Expected dimensionality to be divisible by n."); | |
| 246 | + const int step = getStep(data.cols); | |
| 226 | 247 | const QList<int> labels = src.labels<int>(); |
| 227 | 248 | |
| 228 | 249 | Mat &lut = ProductQuantizationLUTs[index]; |
| 229 | - lut = Mat(data.cols/n, 256*256, CV_32FC1); | |
| 250 | + lut = Mat(getDims(data.cols), 256*256, CV_32FC1); | |
| 230 | 251 | |
| 231 | 252 | QList<Mat> subdata, subluts; |
| 253 | + const int offset = getOffset(data.cols); | |
| 232 | 254 | for (int i=0; i<lut.rows; i++) { |
| 233 | 255 | centers.append(Mat()); |
| 234 | - subdata.append(data.colRange(i*n,(i+1)*n)); | |
| 256 | + subdata.append(data.colRange(max(0, i*step-offset), (i+1)*step-offset)); | |
| 235 | 257 | subluts.append(lut.row(i)); |
| 236 | 258 | } |
| 237 | 259 | |
| ... | ... | @@ -260,9 +282,11 @@ private: |
| 260 | 282 | void project(const Template &src, Template &dst) const |
| 261 | 283 | { |
| 262 | 284 | Mat m = src.m().reshape(1, 1); |
| 263 | - dst = Mat(1, m.cols/n, CV_8UC1); | |
| 285 | + const int step = getStep(m.cols); | |
| 286 | + const int offset = getOffset(m.cols); | |
| 287 | + dst = Mat(1, getDims(m.cols), CV_8UC1); | |
| 264 | 288 | for (int i=0; i<dst.m().cols; i++) |
| 265 | - dst.m().at<uchar>(0,i) = getIndex(m.colRange(i*n, (i+1)*n), centers[i]); | |
| 289 | + dst.m().at<uchar>(0,i) = getIndex(m.colRange(max(0, i*step-offset), (i+1)*step-offset), centers[i]); | |
| 266 | 290 | } |
| 267 | 291 | |
| 268 | 292 | void store(QDataStream &stream) const | ... | ... |