Commit 3048ffa036c16e2f44ede04e92477b05285f19f6
1 parent
14c7b0fa
implemented Hellinger normalization
Showing
2 changed files
with
38 additions
and
11 deletions
openbr/plugins/algorithms.cpp
| ... | ... | @@ -42,7 +42,7 @@ class AlgorithmsInitializer : public Initializer |
| 42 | 42 | Globals->abbreviations.insert("OpenBR", "FaceRecognition"); |
| 43 | 43 | Globals->abbreviations.insert("GenderEstimation", "GenderClassification"); |
| 44 | 44 | Globals->abbreviations.insert("AgeEstimation", "AgeRegression"); |
| 45 | - Globals->abbreviations.insert("FaceRecognitionHoG", "{PP5Register+Affine(128,128,0.25,0.35)+Cvt(Gray)}+Gradient+Bin(0,360,9,true)+Merge+Integral+RecursiveIntegralSampler(4,2,8,LDA(.95)+Normalize(L1)+Div(3)+ProductQuantization(3,L1,true)[fraction=0.2]):ProductQuantization(true)"); | |
| 45 | + Globals->abbreviations.insert("FaceRecognitionHoG", "{PP5Register+Affine(128,128,0.25,0.35)+Cvt(Gray)}+Gradient+Bin(0,360,9,true)+Merge+Integral+RecursiveIntegralSampler(4,2,8,Center(Hellinger)+LDA(.95)+Normalize(L1)+Div(3)+ProductQuantization(3,L1,true)[fraction=0.2]):RecursiveProductQuantization"); | |
| 46 | 46 | |
| 47 | 47 | // Generic Image Processing |
| 48 | 48 | Globals->abbreviations.insert("SIFT", "Open+KeyPointDetector(SIFT)+KeyPointDescriptor(SIFT):KeyPointMatcher(BruteForce)"); | ... | ... |
openbr/plugins/normalize.cpp
| ... | ... | @@ -90,28 +90,31 @@ public: |
| 90 | 90 | /*!< */ |
| 91 | 91 | enum Method { Mean, |
| 92 | 92 | Median, |
| 93 | - Range }; | |
| 93 | + Range, | |
| 94 | + Hellinger }; | |
| 94 | 95 | |
| 95 | 96 | private: |
| 96 | 97 | BR_PROPERTY(Method, method, Mean) |
| 97 | 98 | |
| 98 | 99 | Mat a, b; // dst = (src - b) / a |
| 99 | 100 | |
| 100 | - static void _train(Method method, const cv::Mat &m, Mat *ca, Mat *cb, int i) | |
| 101 | + static void _train(Method method, const cv::Mat &m, const QList<int> &labels, double *ca, double *cb) | |
| 101 | 102 | { |
| 102 | 103 | double A = 1, B = 0; |
| 103 | - if (method == Mean) mean(m.col(i), &A, &B); | |
| 104 | - else if (method == Median) median(m.col(i), &A, &B); | |
| 105 | - else if (method == Range) range(m.col(i), &A, &B); | |
| 106 | - else qFatal("Invalid method."); | |
| 107 | - ca->at<double>(0, i) = A; | |
| 108 | - cb->at<double>(0, i) = B; | |
| 104 | + if (method == Mean) mean(m, &A, &B); | |
| 105 | + else if (method == Median) median(m, &A, &B); | |
| 106 | + else if (method == Range) range(m, &A, &B); | |
| 107 | + else if (method == Hellinger) hellinger(m, labels, &A, &B); | |
| 108 | + else qFatal("Invalid method."); | |
| 109 | + *ca = A; | |
| 110 | + *cb = B; | |
| 109 | 111 | } |
| 110 | 112 | |
| 111 | 113 | void train(const TemplateList &data) |
| 112 | 114 | { |
| 113 | 115 | Mat m; |
| 114 | 116 | OpenCVUtils::toMat(data.data()).convertTo(m, CV_64F); |
| 117 | + const QList<int> labels = data.labels<int>(); | |
| 115 | 118 | const int dims = m.cols; |
| 116 | 119 | |
| 117 | 120 | vector<Mat> mv, av, bv; |
| ... | ... | @@ -125,8 +128,8 @@ private: |
| 125 | 128 | const bool parallel = (data.size() > 1000) && Globals->parallelism; |
| 126 | 129 | for (size_t c = 0; c < mv.size(); c++) { |
| 127 | 130 | for (int i=0; i<dims; i++) |
| 128 | - if (parallel) futures.addFuture(QtConcurrent::run(_train, method, mv[c], &av[c], &bv[c], i)); | |
| 129 | - else _train (method, mv[c], &av[c], &bv[c], i); | |
| 131 | + if (parallel) futures.addFuture(QtConcurrent::run(_train, method, mv[c].col(i), labels, &av[c].at<double>(0, i), &bv[c].at<double>(0, i))); | |
| 132 | + else _train (method, mv[c].col(i), labels, &av[c].at<double>(0, i), &bv[c].at<double>(0, i)); | |
| 130 | 133 | av[c] = av[c].reshape(1, data.first().m().rows); |
| 131 | 134 | bv[c] = bv[c].reshape(1, data.first().m().rows); |
| 132 | 135 | } |
| ... | ... | @@ -181,6 +184,30 @@ private: |
| 181 | 184 | *a = max - min; |
| 182 | 185 | *b = min; |
| 183 | 186 | } |
| 187 | + | |
| 188 | + static void hellinger(const Mat &src, const QList<int> &labels, double *a, double *b) | |
| 189 | + { | |
| 190 | + const QList<float> vals = OpenCVUtils::matrixToVector<float>(src); | |
| 191 | + if (vals.size() != labels.size()) | |
| 192 | + qFatal("Logic error."); | |
| 193 | + | |
| 194 | + QVector<float> genuineScores; genuineScores.reserve(vals.size()); | |
| 195 | + QVector<float> impostorScores; impostorScores.reserve(vals.size()*vals.size()/2); | |
| 196 | + for (int i=0; i<vals.size(); i++) | |
| 197 | + for (int j=i+1; j<vals.size(); j++) | |
| 198 | + if (labels[i] == labels[j]) genuineScores.append(vals[i]-vals[j]); | |
| 199 | + else impostorScores.append(vals[i]-vals[j]); | |
| 200 | + | |
| 201 | + float min, max; | |
| 202 | + Common::MinMax(vals, &min, &max); | |
| 203 | + | |
| 204 | + double gm, gs, im, is; | |
| 205 | + Common::MeanStdDev(genuineScores, &gm, &gs); | |
| 206 | + Common::MeanStdDev(impostorScores, &im, &is); | |
| 207 | + | |
| 208 | + *a = (max-min)/sqrt(1-sqrt(2*gs*is/(gs*gs+is*is))*exp(-0.25*pow(gm-im,2.0)/(gs*gs+is*is))); | |
| 209 | + *b = min; | |
| 210 | + } | |
| 184 | 211 | }; |
| 185 | 212 | |
| 186 | 213 | BR_REGISTER(Transform, CenterTransform) | ... | ... |