Commit 0d0f5bd4b676ac45fb260431a8ca2b21fd24d8e2

Authored by Josh Klontz
1 parent 94c3bcaa

generalized product quantization code

Showing 1 changed file with 29 additions and 5 deletions
openbr/plugins/quantize.cpp
... ... @@ -219,19 +219,41 @@ private:
219 219 // Common::KernelDensityEstimation(impostorScores, lut->at<float>(0,j*256+k), hImpostor)));
220 220 }
221 221  
  222 + int getStep(int cols) const
  223 + {
  224 + if (n > 0) return n;
  225 + if (n == 0) return cols;
  226 + return ceil(float(cols)/abs(n));
  227 + }
  228 +
  229 + int getOffset(int cols) const
  230 + {
  231 + if (n >= 0) return 0;
  232 + const int step = getStep(cols);
  233 + return (step - cols%step) % step;
  234 + }
  235 +
  236 + int getDims(int cols) const
  237 + {
  238 + const int step = getStep(cols);
  239 + if (n >= 0) return cols/step;
  240 + return ceil(float(cols)/step);
  241 + }
  242 +
222 243 void train(const TemplateList &src)
223 244 {
224 245 Mat data = OpenCVUtils::toMat(src.data());
225   - if (data.cols % n != 0) qFatal("Expected dimensionality to be divisible by n.");
  246 + const int step = getStep(data.cols);
226 247 const QList<int> labels = src.labels<int>();
227 248  
228 249 Mat &lut = ProductQuantizationLUTs[index];
229   - lut = Mat(data.cols/n, 256*256, CV_32FC1);
  250 + lut = Mat(getDims(data.cols), 256*256, CV_32FC1);
230 251  
231 252 QList<Mat> subdata, subluts;
  253 + const int offset = getOffset(data.cols);
232 254 for (int i=0; i<lut.rows; i++) {
233 255 centers.append(Mat());
234   - subdata.append(data.colRange(i*n,(i+1)*n));
  256 + subdata.append(data.colRange(max(0, i*step-offset), (i+1)*step-offset));
235 257 subluts.append(lut.row(i));
236 258 }
237 259  
... ... @@ -260,9 +282,11 @@ private:
260 282 void project(const Template &src, Template &dst) const
261 283 {
262 284 Mat m = src.m().reshape(1, 1);
263   - dst = Mat(1, m.cols/n, CV_8UC1);
  285 + const int step = getStep(m.cols);
  286 + const int offset = getOffset(m.cols);
  287 + dst = Mat(1, getDims(m.cols), CV_8UC1);
264 288 for (int i=0; i<dst.m().cols; i++)
265   - dst.m().at<uchar>(0,i) = getIndex(m.colRange(i*n, (i+1)*n), centers[i]);
  289 + dst.m().at<uchar>(0,i) = getIndex(m.colRange(max(0, i*step-offset), (i+1)*step-offset), centers[i]);
266 290 }
267 291  
268 292 void store(QDataStream &stream) const
... ...