Commit 89c438fefc3a9174ecf07d7898f7ad1b583985b2
1 parent
85e03bed
progress on Bayesian quantization
Showing
2 changed files
with
44 additions
and
8 deletions
sdk/plugins/algorithms.cpp
| ... | ... | @@ -42,7 +42,7 @@ class AlgorithmsInitializer : public Initializer |
| 42 | 42 | Globals->abbreviations.insert("OpenBR", "FaceRecognition"); |
| 43 | 43 | Globals->abbreviations.insert("GenderEstimation", "GenderClassification"); |
| 44 | 44 | Globals->abbreviations.insert("AgeEstimation", "AgeRegression"); |
| 45 | - Globals->abbreviations.insert("FaceRecognitionHoG", "Open+Cvt(Gray)+Cascade(FrontalFace)+ASEFEyes+Affine(64,64,0.25,0.35)+Gradient+Bin(0,360,8,true)+Merge+Integral+IntegralSampler+RootNorm+ProductQuantization:ProductQuantization"); | |
| 45 | + Globals->abbreviations.insert("FaceRecognitionHoG", "Open+Cvt(Gray)+Cascade(FrontalFace)+ASEFEyes+Affine(64,64,0.25,0.35)+Gradient+Bin(0,360,8,true)+Merge+Integral+IntegralSampler+RootNorm+ProductQuantization(2,true):ProductQuantization(true)"); | |
| 46 | 46 | |
| 47 | 47 | // Generic Image Processing |
| 48 | 48 | Globals->abbreviations.insert("SIFT", "Open+KeyPointDetector(SIFT)+KeyPointDescriptor(SIFT):KeyPointMatcher(BruteForce)"); | ... | ... |
sdk/plugins/quantize.cpp
| ... | ... | @@ -14,6 +14,8 @@ |
| 14 | 14 | * limitations under the License. * |
| 15 | 15 | * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ |
| 16 | 16 | |
| 17 | +#include <QFutureSynchronizer> | |
| 18 | +#include <QtConcurrentRun> | |
| 17 | 19 | #include <openbr_plugin.h> |
| 18 | 20 | |
| 19 | 21 | #include "core/opencvutils.h" |
| ... | ... | @@ -131,7 +133,7 @@ class ProductQuantizationDistance : public Distance |
| 131 | 133 | const uchar *bData = b[i].data; |
| 132 | 134 | const float *lut = (const float*)ProductQuantizationLUTs[i].data; |
| 133 | 135 | for (int j=0; j<elements; j++) |
| 134 | - distance += lut[i*256*256 + aData[j]*256+bData[j]]; | |
| 136 | + distance += lut[i*256*256 + aData[j]*256+bData[j]]; | |
| 135 | 137 | } |
| 136 | 138 | if (!bayesian) distance = -log(distance+1); |
| 137 | 139 | return distance; |
| ... | ... | @@ -164,22 +166,56 @@ public: |
| 164 | 166 | } |
| 165 | 167 | |
| 166 | 168 | private: |
| 169 | + static double likelihoodRatio(const QPair<int,int> &totals, const QList<int> &targets, const QList<int> &queries) | |
| 170 | + { | |
| 171 | + int positives = 0, negatives = 0; | |
| 172 | + foreach (int t, targets) | |
| 173 | + foreach (int q, queries) | |
| 174 | + if (t == q) positives++; | |
| 175 | + else negatives++; | |
| 176 | + return log(float(positives)/float(totals.first)) / (float(negatives)/float(totals.second)); | |
| 177 | + } | |
| 178 | + | |
| 179 | + void _train(const Mat &data, const QPair<int,int> &totals, Mat &lut, int i, const QList<int> &templateLabels) | |
| 180 | + { | |
| 181 | + Mat labels, center; | |
| 182 | + kmeans(data.colRange(i*n,(i+1)*n), 256, labels, TermCriteria(TermCriteria::MAX_ITER, 10, 0), 3, KMEANS_PP_CENTERS, center); | |
| 183 | + QList<int> clusterLabels = OpenCVUtils::matrixToVector<int>(labels); | |
| 184 | + QHash< int, QList<int> > clusters; // QHash<clusterLabel, QList<templateLabel> > | |
| 185 | + for (int i=0; i<clusterLabels.size(); i++) | |
| 186 | + clusters[clusterLabels[i]].append(templateLabels[i]); | |
| 187 | + | |
| 188 | + for (int j=0; j<256; j++) | |
| 189 | + for (int k=0; k<256; k++) | |
| 190 | + lut.at<float>(i,j*256+k) = bayesian ? likelihoodRatio(totals, clusters[j], clusters[k]) : | |
| 191 | + norm(center.row(j), center.row(k), NORM_L2); | |
| 192 | + centers[i] = center; | |
| 193 | + } | |
| 194 | + | |
| 167 | 195 | void train(const TemplateList &src) |
| 168 | 196 | { |
| 169 | 197 | Mat data = OpenCVUtils::toMat(src.data()); |
| 170 | 198 | if (data.cols % n != 0) qFatal("Expected dimensionality to be divisible by n."); |
| 199 | + const QList<int> templateLabels = src.labels<int>(); | |
| 200 | + int totalPositives = 0, totalNegatives = 0; | |
| 201 | + for (int i=0; i<templateLabels.size(); i++) | |
| 202 | + for (int j=0; j<templateLabels.size(); j++) | |
| 203 | + if (templateLabels[i] == templateLabels[j]) totalPositives++; | |
| 204 | + else totalNegatives++; | |
| 205 | + QPair<int,int> totals(totalPositives, totalNegatives); | |
| 171 | 206 | |
| 172 | 207 | Mat &lut = ProductQuantizationLUTs[index]; |
| 173 | 208 | lut = Mat(data.cols/n, 256*256, CV_32FC1); |
| 174 | 209 | |
| 210 | + for (int i=0; i<lut.rows; i++) | |
| 211 | + centers.append(Mat()); | |
| 212 | + | |
| 213 | + QFutureSynchronizer<void> futures; | |
| 175 | 214 | for (int i=0; i<lut.rows; i++) { |
| 176 | - Mat labels, center; | |
| 177 | - kmeans(data.colRange(i*n,(i+1)*n), 256, labels, TermCriteria(TermCriteria::MAX_ITER, 10, 0), 3, KMEANS_PP_CENTERS, center); | |
| 178 | - for (int j=0; j<256; j++) | |
| 179 | - for (int k=0; k<256; k++) | |
| 180 | - lut.at<float>(i,j*256+k) = norm(center.row(j), center.row(k), NORM_L2); | |
| 181 | - centers.append(center); | |
| 215 | + if (Globals->parallelism) futures.addFuture(QtConcurrent::run(this, &ProductQuantizationTransform::_train, data, totals, lut, i, templateLabels)); | |
| 216 | + else _train (data, totals, lut, i, templateLabels); | |
| 182 | 217 | } |
| 218 | + futures.waitForFinished(); | |
| 183 | 219 | } |
| 184 | 220 | |
| 185 | 221 | void project(const Template &src, Template &dst) const | ... | ... |