Commit 3048ffa036c16e2f44ede04e92477b05285f19f6

Authored by Josh Klontz
1 parent 14c7b0fa

implemented Hellinger normalization

openbr/plugins/algorithms.cpp
@@ -42,7 +42,7 @@ class AlgorithmsInitializer : public Initializer @@ -42,7 +42,7 @@ class AlgorithmsInitializer : public Initializer
42 Globals->abbreviations.insert("OpenBR", "FaceRecognition"); 42 Globals->abbreviations.insert("OpenBR", "FaceRecognition");
43 Globals->abbreviations.insert("GenderEstimation", "GenderClassification"); 43 Globals->abbreviations.insert("GenderEstimation", "GenderClassification");
44 Globals->abbreviations.insert("AgeEstimation", "AgeRegression"); 44 Globals->abbreviations.insert("AgeEstimation", "AgeRegression");
45 - Globals->abbreviations.insert("FaceRecognitionHoG", "{PP5Register+Affine(128,128,0.25,0.35)+Cvt(Gray)}+Gradient+Bin(0,360,9,true)+Merge+Integral+RecursiveIntegralSampler(4,2,8,LDA(.95)+Normalize(L1)+Div(3)+ProductQuantization(3,L1,true)[fraction=0.2]):ProductQuantization(true)"); 45 + Globals->abbreviations.insert("FaceRecognitionHoG", "{PP5Register+Affine(128,128,0.25,0.35)+Cvt(Gray)}+Gradient+Bin(0,360,9,true)+Merge+Integral+RecursiveIntegralSampler(4,2,8,Center(Hellinger)+LDA(.95)+Normalize(L1)+Div(3)+ProductQuantization(3,L1,true)[fraction=0.2]):RecursiveProductQuantization");
46 46
47 // Generic Image Processing 47 // Generic Image Processing
48 Globals->abbreviations.insert("SIFT", "Open+KeyPointDetector(SIFT)+KeyPointDescriptor(SIFT):KeyPointMatcher(BruteForce)"); 48 Globals->abbreviations.insert("SIFT", "Open+KeyPointDetector(SIFT)+KeyPointDescriptor(SIFT):KeyPointMatcher(BruteForce)");
openbr/plugins/normalize.cpp
@@ -90,28 +90,31 @@ public: @@ -90,28 +90,31 @@ public:
90 /*!< */ 90 /*!< */
91 enum Method { Mean, 91 enum Method { Mean,
92 Median, 92 Median,
93 - Range }; 93 + Range,
  94 + Hellinger };
94 95
95 private: 96 private:
96 BR_PROPERTY(Method, method, Mean) 97 BR_PROPERTY(Method, method, Mean)
97 98
98 Mat a, b; // dst = (src - b) / a 99 Mat a, b; // dst = (src - b) / a
99 100
100 - static void _train(Method method, const cv::Mat &m, Mat *ca, Mat *cb, int i) 101 + static void _train(Method method, const cv::Mat &m, const QList<int> &labels, double *ca, double *cb)
101 { 102 {
102 double A = 1, B = 0; 103 double A = 1, B = 0;
103 - if (method == Mean) mean(m.col(i), &A, &B);  
104 - else if (method == Median) median(m.col(i), &A, &B);  
105 - else if (method == Range) range(m.col(i), &A, &B);  
106 - else qFatal("Invalid method.");  
107 - ca->at<double>(0, i) = A;  
108 - cb->at<double>(0, i) = B; 104 + if (method == Mean) mean(m, &A, &B);
  105 + else if (method == Median) median(m, &A, &B);
  106 + else if (method == Range) range(m, &A, &B);
  107 + else if (method == Hellinger) hellinger(m, labels, &A, &B);
  108 + else qFatal("Invalid method.");
  109 + *ca = A;
  110 + *cb = B;
109 } 111 }
110 112
111 void train(const TemplateList &data) 113 void train(const TemplateList &data)
112 { 114 {
113 Mat m; 115 Mat m;
114 OpenCVUtils::toMat(data.data()).convertTo(m, CV_64F); 116 OpenCVUtils::toMat(data.data()).convertTo(m, CV_64F);
  117 + const QList<int> labels = data.labels<int>();
115 const int dims = m.cols; 118 const int dims = m.cols;
116 119
117 vector<Mat> mv, av, bv; 120 vector<Mat> mv, av, bv;
@@ -125,8 +128,8 @@ private: @@ -125,8 +128,8 @@ private:
125 const bool parallel = (data.size() > 1000) && Globals->parallelism; 128 const bool parallel = (data.size() > 1000) && Globals->parallelism;
126 for (size_t c = 0; c < mv.size(); c++) { 129 for (size_t c = 0; c < mv.size(); c++) {
127 for (int i=0; i<dims; i++) 130 for (int i=0; i<dims; i++)
128 - if (parallel) futures.addFuture(QtConcurrent::run(_train, method, mv[c], &av[c], &bv[c], i));  
129 - else _train (method, mv[c], &av[c], &bv[c], i); 131 + if (parallel) futures.addFuture(QtConcurrent::run(_train, method, mv[c].col(i), labels, &av[c].at<double>(0, i), &bv[c].at<double>(0, i)));
  132 + else _train (method, mv[c].col(i), labels, &av[c].at<double>(0, i), &bv[c].at<double>(0, i));
130 av[c] = av[c].reshape(1, data.first().m().rows); 133 av[c] = av[c].reshape(1, data.first().m().rows);
131 bv[c] = bv[c].reshape(1, data.first().m().rows); 134 bv[c] = bv[c].reshape(1, data.first().m().rows);
132 } 135 }
@@ -181,6 +184,30 @@ private: @@ -181,6 +184,30 @@ private:
181 *a = max - min; 184 *a = max - min;
182 *b = min; 185 *b = min;
183 } 186 }
  187 +
  188 + static void hellinger(const Mat &src, const QList<int> &labels, double *a, double *b)
  189 + {
  190 + const QList<float> vals = OpenCVUtils::matrixToVector<float>(src);
  191 + if (vals.size() != labels.size())
  192 + qFatal("Logic error.");
  193 +
  194 + QVector<float> genuineScores; genuineScores.reserve(vals.size());
  195 + QVector<float> impostorScores; impostorScores.reserve(vals.size()*vals.size()/2);
  196 + for (int i=0; i<vals.size(); i++)
  197 + for (int j=i+1; j<vals.size(); j++)
  198 + if (labels[i] == labels[j]) genuineScores.append(vals[i]-vals[j]);
  199 + else impostorScores.append(vals[i]-vals[j]);
  200 +
  201 + float min, max;
  202 + Common::MinMax(vals, &min, &max);
  203 +
  204 + double gm, gs, im, is;
  205 + Common::MeanStdDev(genuineScores, &gm, &gs);
  206 + Common::MeanStdDev(impostorScores, &im, &is);
  207 +
  208 + *a = (max-min)/sqrt(1-sqrt(2*gs*is/(gs*gs+is*is))*exp(-0.25*pow(gm-im,2.0)/(gs*gs+is*is)));
  209 + *b = min;
  210 + }
184 }; 211 };
185 212
186 BR_REGISTER(Transform, CenterTransform) 213 BR_REGISTER(Transform, CenterTransform)