Commit 0d0f5bd4b676ac45fb260431a8ca2b21fd24d8e2

Authored by Josh Klontz
1 parent 94c3bcaa

generalized product quantization code

Showing 1 changed file with 29 additions and 5 deletions
openbr/plugins/quantize.cpp
@@ -219,19 +219,41 @@ private: @@ -219,19 +219,41 @@ private:
219 // Common::KernelDensityEstimation(impostorScores, lut->at<float>(0,j*256+k), hImpostor))); 219 // Common::KernelDensityEstimation(impostorScores, lut->at<float>(0,j*256+k), hImpostor)));
220 } 220 }
221 221
  222 + int getStep(int cols) const
  223 + {
  224 + if (n > 0) return n;
  225 + if (n == 0) return cols;
  226 + return ceil(float(cols)/abs(n));
  227 + }
  228 +
  229 + int getOffset(int cols) const
  230 + {
  231 + if (n >= 0) return 0;
  232 + const int step = getStep(cols);
  233 + return (step - cols%step) % step;
  234 + }
  235 +
  236 + int getDims(int cols) const
  237 + {
  238 + const int step = getStep(cols);
  239 + if (n >= 0) return cols/step;
  240 + return ceil(float(cols)/step);
  241 + }
  242 +
222 void train(const TemplateList &src) 243 void train(const TemplateList &src)
223 { 244 {
224 Mat data = OpenCVUtils::toMat(src.data()); 245 Mat data = OpenCVUtils::toMat(src.data());
225 - if (data.cols % n != 0) qFatal("Expected dimensionality to be divisible by n."); 246 + const int step = getStep(data.cols);
226 const QList<int> labels = src.labels<int>(); 247 const QList<int> labels = src.labels<int>();
227 248
228 Mat &lut = ProductQuantizationLUTs[index]; 249 Mat &lut = ProductQuantizationLUTs[index];
229 - lut = Mat(data.cols/n, 256*256, CV_32FC1); 250 + lut = Mat(getDims(data.cols), 256*256, CV_32FC1);
230 251
231 QList<Mat> subdata, subluts; 252 QList<Mat> subdata, subluts;
  253 + const int offset = getOffset(data.cols);
232 for (int i=0; i<lut.rows; i++) { 254 for (int i=0; i<lut.rows; i++) {
233 centers.append(Mat()); 255 centers.append(Mat());
234 - subdata.append(data.colRange(i*n,(i+1)*n)); 256 + subdata.append(data.colRange(max(0, i*step-offset), (i+1)*step-offset));
235 subluts.append(lut.row(i)); 257 subluts.append(lut.row(i));
236 } 258 }
237 259
@@ -260,9 +282,11 @@ private: @@ -260,9 +282,11 @@ private:
260 void project(const Template &src, Template &dst) const 282 void project(const Template &src, Template &dst) const
261 { 283 {
262 Mat m = src.m().reshape(1, 1); 284 Mat m = src.m().reshape(1, 1);
263 - dst = Mat(1, m.cols/n, CV_8UC1); 285 + const int step = getStep(m.cols);
  286 + const int offset = getOffset(m.cols);
  287 + dst = Mat(1, getDims(m.cols), CV_8UC1);
264 for (int i=0; i<dst.m().cols; i++) 288 for (int i=0; i<dst.m().cols; i++)
265 - dst.m().at<uchar>(0,i) = getIndex(m.colRange(i*n, (i+1)*n), centers[i]); 289 + dst.m().at<uchar>(0,i) = getIndex(m.colRange(max(0, i*step-offset), (i+1)*step-offset), centers[i]);
266 } 290 }
267 291
268 void store(QDataStream &stream) const 292 void store(QDataStream &stream) const