From 89c438fefc3a9174ecf07d7898f7ad1b583985b2 Mon Sep 17 00:00:00 2001 From: Josh Klontz Date: Sun, 17 Mar 2013 23:32:49 -0400 Subject: [PATCH] progress on Bayesian quantization --- sdk/plugins/algorithms.cpp | 2 +- sdk/plugins/quantize.cpp | 50 +++++++++++++++++++++++++++++++++++++++++++------- 2 files changed, 44 insertions(+), 8 deletions(-) diff --git a/sdk/plugins/algorithms.cpp b/sdk/plugins/algorithms.cpp index a46caad..80a9a8c 100644 --- a/sdk/plugins/algorithms.cpp +++ b/sdk/plugins/algorithms.cpp @@ -42,7 +42,7 @@ class AlgorithmsInitializer : public Initializer Globals->abbreviations.insert("OpenBR", "FaceRecognition"); Globals->abbreviations.insert("GenderEstimation", "GenderClassification"); Globals->abbreviations.insert("AgeEstimation", "AgeRegression"); - Globals->abbreviations.insert("FaceRecognitionHoG", "Open+Cvt(Gray)+Cascade(FrontalFace)+ASEFEyes+Affine(64,64,0.25,0.35)+Gradient+Bin(0,360,8,true)+Merge+Integral+IntegralSampler+RootNorm+ProductQuantization:ProductQuantization"); + Globals->abbreviations.insert("FaceRecognitionHoG", "Open+Cvt(Gray)+Cascade(FrontalFace)+ASEFEyes+Affine(64,64,0.25,0.35)+Gradient+Bin(0,360,8,true)+Merge+Integral+IntegralSampler+RootNorm+ProductQuantization(2,true):ProductQuantization(true)"); // Generic Image Processing Globals->abbreviations.insert("SIFT", "Open+KeyPointDetector(SIFT)+KeyPointDescriptor(SIFT):KeyPointMatcher(BruteForce)"); diff --git a/sdk/plugins/quantize.cpp b/sdk/plugins/quantize.cpp index 6411e78..6624a17 100644 --- a/sdk/plugins/quantize.cpp +++ b/sdk/plugins/quantize.cpp @@ -14,6 +14,8 @@ * limitations under the License. * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ +#include +#include #include #include "core/opencvutils.h" @@ -131,7 +133,7 @@ class ProductQuantizationDistance : public Distance const uchar *bData = b[i].data; const float *lut = (const float*)ProductQuantizationLUTs[i].data; for (int j=0; j &totals, const QList &targets, const QList &queries) + { + int positives = 0, negatives = 0; + foreach (int t, targets) + foreach (int q, queries) + if (t == q) positives++; + else negatives++; + return log(float(positives)/float(totals.first)) / (float(negatives)/float(totals.second)); + } + + void _train(const Mat &data, const QPair &totals, Mat &lut, int i, const QList &templateLabels) + { + Mat labels, center; + kmeans(data.colRange(i*n,(i+1)*n), 256, labels, TermCriteria(TermCriteria::MAX_ITER, 10, 0), 3, KMEANS_PP_CENTERS, center); + QList clusterLabels = OpenCVUtils::matrixToVector(labels); + QHash< int, QList > clusters; // QHash > + for (int i=0; i(i,j*256+k) = bayesian ? likelihoodRatio(totals, clusters[j], clusters[k]) : + norm(center.row(j), center.row(k), NORM_L2); + centers[i] = center; + } + void train(const TemplateList &src) { Mat data = OpenCVUtils::toMat(src.data()); if (data.cols % n != 0) qFatal("Expected dimensionality to be divisible by n."); + const QList templateLabels = src.labels(); + int totalPositives = 0, totalNegatives = 0; + for (int i=0; i totals(totalPositives, totalNegatives); Mat &lut = ProductQuantizationLUTs[index]; lut = Mat(data.cols/n, 256*256, CV_32FC1); + for (int i=0; i futures; for (int i=0; i(i,j*256+k) = norm(center.row(j), center.row(k), NORM_L2); - centers.append(center); + if (Globals->parallelism) futures.addFuture(QtConcurrent::run(this, &ProductQuantizationTransform::_train, data, totals, lut, i, templateLabels)); + else _train (data, totals, lut, i, templateLabels); } + futures.waitForFinished(); } void project(const Template &src, Template &dst) const -- libgit2 0.21.4