Commit 5e02aadf1514acfdf6307d87e8d24d2d80cd215c
1 parent
ccb66cf5
reintroduced product quantization
Showing
2 changed files
with
89 additions
and
1 deletions
sdk/plugins/quantize.cpp
| ... | ... | @@ -109,6 +109,95 @@ class PackTransform : public UntrainableTransform |
| 109 | 109 | |
| 110 | 110 | BR_REGISTER(Transform, PackTransform) |
| 111 | 111 | |
| 112 | +/*! | |
| 113 | + * \ingroup distances | |
| 114 | + * \brief Product quantization distance. | |
| 115 | + * \author Josh Klontz \cite jklontz | |
| 116 | + */ | |
| 117 | +class ProductQuantizationDistance : public Distance | |
| 118 | +{ | |
| 119 | + Q_OBJECT | |
| 120 | + friend class ProductQuantizationTransform; | |
| 121 | + static QList<Mat> luts; | |
| 122 | + | |
| 123 | + float compare(const Template &a, const Template &b) const | |
| 124 | + { | |
| 125 | + float distance = 0; | |
| 126 | + for (int i=0; i<a.m().cols; i++) | |
| 127 | + distance += pow(luts[i].at<float>(a.m().at<uchar>(0,i), b.m().at<uchar>(0,i)),2); | |
| 128 | + return sqrt(distance); | |
| 129 | + } | |
| 130 | +}; | |
| 131 | + | |
| 132 | +QList<Mat> ProductQuantizationDistance::luts; | |
| 133 | + | |
| 134 | +BR_REGISTER(Distance, ProductQuantizationDistance) | |
| 135 | + | |
| 136 | +/*! | |
| 137 | + * \ingroup transforms | |
| 138 | + * \brief Product quantization \cite jegou11 | |
| 139 | + * \author Josh Klonyz \cite jklontz | |
| 140 | + */ | |
| 141 | +class ProductQuantizationTransform : public Transform | |
| 142 | +{ | |
| 143 | + Q_OBJECT | |
| 144 | + static int counter; | |
| 145 | + int index; | |
| 146 | + Mat centers, lut; | |
| 147 | + | |
| 148 | +public: | |
| 149 | + ProductQuantizationTransform() | |
| 150 | + { | |
| 151 | + index = counter++; | |
| 152 | + ProductQuantizationDistance::luts.append(Mat()); | |
| 153 | + } | |
| 154 | + | |
| 155 | +private: | |
| 156 | + void train(const TemplateList &src) | |
| 157 | + { | |
| 158 | + Mat data = OpenCVUtils::toMat(src.data()); | |
| 159 | + Mat labels; | |
| 160 | + kmeans(data, 256, labels, TermCriteria(TermCriteria::MAX_ITER, 10, 0), 3, KMEANS_PP_CENTERS, centers); | |
| 161 | + | |
| 162 | + lut = Mat(256, 256, CV_32FC1); | |
| 163 | + for (int i=0; i<256; i++) | |
| 164 | + for (int j=0; j<256; j++) | |
| 165 | + lut.at<float>(i,j) = norm(centers.row(i), centers.row(j), NORM_L2); | |
| 166 | + ProductQuantizationDistance::luts[index] = lut; | |
| 167 | + } | |
| 168 | + | |
| 169 | + void project(const Template &src, Template &dst) const | |
| 170 | + { | |
| 171 | + uchar bestIndex = -1; | |
| 172 | + double bestDistance = std::numeric_limits<double>::max(); | |
| 173 | + for (uchar i=0; i<256; i++) { | |
| 174 | + double distance = norm(src, centers.row(i), NORM_L2); | |
| 175 | + if (distance < bestDistance) { | |
| 176 | + bestDistance = distance; | |
| 177 | + bestIndex = i; | |
| 178 | + } | |
| 179 | + } | |
| 180 | + assert(bestIndex != -1); | |
| 181 | + dst = Mat(1, 1, CV_8UC1); | |
| 182 | + dst.m().at<uchar>(0,0) = bestIndex; | |
| 183 | + } | |
| 184 | + | |
| 185 | + void store(QDataStream &stream) const | |
| 186 | + { | |
| 187 | + stream << centers << lut; | |
| 188 | + } | |
| 189 | + | |
| 190 | + void load(QDataStream &stream) | |
| 191 | + { | |
| 192 | + stream >> centers >> lut; | |
| 193 | + ProductQuantizationDistance::luts[index] = lut; | |
| 194 | + } | |
| 195 | +}; | |
| 196 | + | |
| 197 | +int ProductQuantizationTransform::counter = 0; | |
| 198 | + | |
| 199 | +BR_REGISTER(Transform, ProductQuantizationTransform) | |
| 200 | + | |
| 112 | 201 | } // namespace br |
| 113 | 202 | |
| 114 | 203 | #include "quantize.moc" | ... | ... |