Commit 89c438fefc3a9174ecf07d7898f7ad1b583985b2

Authored by Josh Klontz
1 parent 85e03bed

progress on Bayesian quantization

sdk/plugins/algorithms.cpp
@@ -42,7 +42,7 @@ class AlgorithmsInitializer : public Initializer @@ -42,7 +42,7 @@ class AlgorithmsInitializer : public Initializer
42 Globals->abbreviations.insert("OpenBR", "FaceRecognition"); 42 Globals->abbreviations.insert("OpenBR", "FaceRecognition");
43 Globals->abbreviations.insert("GenderEstimation", "GenderClassification"); 43 Globals->abbreviations.insert("GenderEstimation", "GenderClassification");
44 Globals->abbreviations.insert("AgeEstimation", "AgeRegression"); 44 Globals->abbreviations.insert("AgeEstimation", "AgeRegression");
45 - Globals->abbreviations.insert("FaceRecognitionHoG", "Open+Cvt(Gray)+Cascade(FrontalFace)+ASEFEyes+Affine(64,64,0.25,0.35)+Gradient+Bin(0,360,8,true)+Merge+Integral+IntegralSampler+RootNorm+ProductQuantization:ProductQuantization"); 45 + Globals->abbreviations.insert("FaceRecognitionHoG", "Open+Cvt(Gray)+Cascade(FrontalFace)+ASEFEyes+Affine(64,64,0.25,0.35)+Gradient+Bin(0,360,8,true)+Merge+Integral+IntegralSampler+RootNorm+ProductQuantization(2,true):ProductQuantization(true)");
46 46
47 // Generic Image Processing 47 // Generic Image Processing
48 Globals->abbreviations.insert("SIFT", "Open+KeyPointDetector(SIFT)+KeyPointDescriptor(SIFT):KeyPointMatcher(BruteForce)"); 48 Globals->abbreviations.insert("SIFT", "Open+KeyPointDetector(SIFT)+KeyPointDescriptor(SIFT):KeyPointMatcher(BruteForce)");
sdk/plugins/quantize.cpp
@@ -14,6 +14,8 @@ @@ -14,6 +14,8 @@
14 * limitations under the License. * 14 * limitations under the License. *
15 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 15 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
16 16
  17 +#include <QFutureSynchronizer>
  18 +#include <QtConcurrentRun>
17 #include <openbr_plugin.h> 19 #include <openbr_plugin.h>
18 20
19 #include "core/opencvutils.h" 21 #include "core/opencvutils.h"
@@ -131,7 +133,7 @@ class ProductQuantizationDistance : public Distance @@ -131,7 +133,7 @@ class ProductQuantizationDistance : public Distance
131 const uchar *bData = b[i].data; 133 const uchar *bData = b[i].data;
132 const float *lut = (const float*)ProductQuantizationLUTs[i].data; 134 const float *lut = (const float*)ProductQuantizationLUTs[i].data;
133 for (int j=0; j<elements; j++) 135 for (int j=0; j<elements; j++)
134 - distance += lut[i*256*256 + aData[j]*256+bData[j]]; 136 + distance += lut[i*256*256 + aData[j]*256+bData[j]];
135 } 137 }
136 if (!bayesian) distance = -log(distance+1); 138 if (!bayesian) distance = -log(distance+1);
137 return distance; 139 return distance;
@@ -164,22 +166,56 @@ public: @@ -164,22 +166,56 @@ public:
164 } 166 }
165 167
166 private: 168 private:
  169 + static double likelihoodRatio(const QPair<int,int> &totals, const QList<int> &targets, const QList<int> &queries)
  170 + {
  171 + int positives = 0, negatives = 0;
  172 + foreach (int t, targets)
  173 + foreach (int q, queries)
  174 + if (t == q) positives++;
  175 + else negatives++;
  176 + return log(float(positives)/float(totals.first)) / (float(negatives)/float(totals.second));
  177 + }
  178 +
  179 + void _train(const Mat &data, const QPair<int,int> &totals, Mat &lut, int i, const QList<int> &templateLabels)
  180 + {
  181 + Mat labels, center;
  182 + kmeans(data.colRange(i*n,(i+1)*n), 256, labels, TermCriteria(TermCriteria::MAX_ITER, 10, 0), 3, KMEANS_PP_CENTERS, center);
  183 + QList<int> clusterLabels = OpenCVUtils::matrixToVector<int>(labels);
  184 + QHash< int, QList<int> > clusters; // QHash<clusterLabel, QList<templateLabel> >
  185 + for (int i=0; i<clusterLabels.size(); i++)
  186 + clusters[clusterLabels[i]].append(templateLabels[i]);
  187 +
  188 + for (int j=0; j<256; j++)
  189 + for (int k=0; k<256; k++)
  190 + lut.at<float>(i,j*256+k) = bayesian ? likelihoodRatio(totals, clusters[j], clusters[k]) :
  191 + norm(center.row(j), center.row(k), NORM_L2);
  192 + centers[i] = center;
  193 + }
  194 +
167 void train(const TemplateList &src) 195 void train(const TemplateList &src)
168 { 196 {
169 Mat data = OpenCVUtils::toMat(src.data()); 197 Mat data = OpenCVUtils::toMat(src.data());
170 if (data.cols % n != 0) qFatal("Expected dimensionality to be divisible by n."); 198 if (data.cols % n != 0) qFatal("Expected dimensionality to be divisible by n.");
  199 + const QList<int> templateLabels = src.labels<int>();
  200 + int totalPositives = 0, totalNegatives = 0;
  201 + for (int i=0; i<templateLabels.size(); i++)
  202 + for (int j=0; j<templateLabels.size(); j++)
  203 + if (templateLabels[i] == templateLabels[j]) totalPositives++;
  204 + else totalNegatives++;
  205 + QPair<int,int> totals(totalPositives, totalNegatives);
171 206
172 Mat &lut = ProductQuantizationLUTs[index]; 207 Mat &lut = ProductQuantizationLUTs[index];
173 lut = Mat(data.cols/n, 256*256, CV_32FC1); 208 lut = Mat(data.cols/n, 256*256, CV_32FC1);
174 209
  210 + for (int i=0; i<lut.rows; i++)
  211 + centers.append(Mat());
  212 +
  213 + QFutureSynchronizer<void> futures;
175 for (int i=0; i<lut.rows; i++) { 214 for (int i=0; i<lut.rows; i++) {
176 - Mat labels, center;  
177 - kmeans(data.colRange(i*n,(i+1)*n), 256, labels, TermCriteria(TermCriteria::MAX_ITER, 10, 0), 3, KMEANS_PP_CENTERS, center);  
178 - for (int j=0; j<256; j++)  
179 - for (int k=0; k<256; k++)  
180 - lut.at<float>(i,j*256+k) = norm(center.row(j), center.row(k), NORM_L2);  
181 - centers.append(center); 215 + if (Globals->parallelism) futures.addFuture(QtConcurrent::run(this, &ProductQuantizationTransform::_train, data, totals, lut, i, templateLabels));
  216 + else _train (data, totals, lut, i, templateLabels);
182 } 217 }
  218 + futures.waitForFinished();
183 } 219 }
184 220
185 void project(const Template &src, Template &dst) const 221 void project(const Template &src, Template &dst) const