Commit 4e9226852a5a7c26b39ef8fefab4c1fc58ef215b
1 parent
81e9c1c2
generalized BayesianQuantizationDistance
Showing
1 changed file
with
36 additions
and
22 deletions
openbr/plugins/quantize.cpp
| @@ -64,40 +64,54 @@ BR_REGISTER(Transform, QuantizeTransform) | @@ -64,40 +64,54 @@ BR_REGISTER(Transform, QuantizeTransform) | ||
| 64 | class BayesianQuantizationDistance : public Distance | 64 | class BayesianQuantizationDistance : public Distance |
| 65 | { | 65 | { |
| 66 | Q_OBJECT | 66 | Q_OBJECT |
| 67 | - QVector<float> loglikelihood; | 67 | + Q_PROPERTY(int K READ get_K WRITE set_K RESET reset_K STORED false) |
| 68 | + BR_PROPERTY(int, K, 1) | ||
| 68 | 69 | ||
| 69 | - void train(const TemplateList &src) | 70 | + Mat labels, centers; |
| 71 | + | ||
| 72 | + static void computeLogLikelihood(const Mat &data, const QList<int> &labels, float *loglikelihood) | ||
| 70 | { | 73 | { |
| 71 | - if (src.first().size() > 1) | ||
| 72 | - qFatal("Expected sigle matrix templates."); | 74 | + const QList<uchar> values = OpenCVUtils::matrixToVector<uchar>(data); |
| 75 | + if (values.size() != labels.size()) | ||
| 76 | + qFatal("Logic error."); | ||
| 77 | + | ||
| 78 | + QVector<int> genuines(256*256,0), impostors(256*256,0); | ||
| 79 | + for (int i=0; i<labels.size(); i++) | ||
| 80 | + for (int j=0; j<labels.size(); j++) | ||
| 81 | + if (labels[i] == labels[j]) genuines[256*values[i]+values[j]]++; | ||
| 82 | + else impostors[256*values[i]+values[j]]++; | ||
| 73 | 83 | ||
| 74 | - Mat data = OpenCVUtils::toMat(src.data()); | ||
| 75 | - QList<int> labels = src.labels<int>(); | ||
| 76 | - | ||
| 77 | - QVector<qint64> genuines(256*256,0), impostors(256*256,0); | ||
| 78 | - for (int i=0; i<labels.size(); i++) { | ||
| 79 | - const uchar *a = data.ptr(i); | ||
| 80 | - for (int j=0; j<labels.size(); j++) { | ||
| 81 | - const uchar *b = data.ptr(j); | ||
| 82 | - const bool genuine = (labels[i] == labels[j]); | ||
| 83 | - for (int k=0; k<data.cols; k++) | ||
| 84 | - genuine ? genuines[256*a[k]+b[k]]++ : impostors[256*a[k]+b[k]]++; | ||
| 85 | - } | ||
| 86 | - } | ||
| 87 | 84 | ||
| 88 | - qint64 totalGenuines(0), totalImpostors(0); | 85 | + int totalGenuines(0), totalImpostors(0); |
| 89 | for (int i=0; i<256*256; i++) { | 86 | for (int i=0; i<256*256; i++) { |
| 90 | totalGenuines += genuines[i]; | 87 | totalGenuines += genuines[i]; |
| 91 | totalImpostors += impostors[i]; | 88 | totalImpostors += impostors[i]; |
| 92 | } | 89 | } |
| 93 | 90 | ||
| 94 | - loglikelihood = QVector<float>(256*256); | ||
| 95 | for (int i=0; i<256; i++) | 91 | for (int i=0; i<256; i++) |
| 96 | for (int j=0; j<256; j++) | 92 | for (int j=0; j<256; j++) |
| 97 | loglikelihood[i*256+j] = log((double(genuines[i*256+j]+genuines[j*256+i]+1)/totalGenuines)/ | 93 | loglikelihood[i*256+j] = log((double(genuines[i*256+j]+genuines[j*256+i]+1)/totalGenuines)/ |
| 98 | (double(impostors[i*256+j]+impostors[j*256+i]+1)/totalImpostors)); | 94 | (double(impostors[i*256+j]+impostors[j*256+i]+1)/totalImpostors)); |
| 99 | } | 95 | } |
| 100 | 96 | ||
| 97 | + void train(const TemplateList &src) | ||
| 98 | + { | ||
| 99 | + if ((src.first().size() > 1) || (src.first().m().type() != CV_8UC1)) | ||
| 100 | + qFatal("Expected sigle matrix templates of type CV_8UC1."); | ||
| 101 | + | ||
| 102 | + const Mat data = OpenCVUtils::toMat(src.data()); | ||
| 103 | + const QList<int> templateLabels = src.labels<int>(); | ||
| 104 | + Mat loglikelihoods(data.cols, 256*256, CV_32FC1); | ||
| 105 | + | ||
| 106 | + QFutureSynchronizer<void> futures; | ||
| 107 | + for (int i=0; i<data.cols; i++) | ||
| 108 | + if (Globals->parallelism) futures.addFuture(QtConcurrent::run(&BayesianQuantizationDistance::computeLogLikelihood, data.col(i), templateLabels, loglikelihoods.ptr<float>(i))); | ||
| 109 | + else computeLogLikelihood( data.col(i), templateLabels, loglikelihoods.ptr<float>(i)); | ||
| 110 | + futures.waitForFinished(); | ||
| 111 | + | ||
| 112 | + kmeans(loglikelihoods, K, labels, TermCriteria(TermCriteria::MAX_ITER, 10, 0), 3, KMEANS_PP_CENTERS, centers); | ||
| 113 | + } | ||
| 114 | + | ||
| 101 | float compare(const Template &a, const Template &b) const | 115 | float compare(const Template &a, const Template &b) const |
| 102 | { | 116 | { |
| 103 | const uchar *aData = a.m().data; | 117 | const uchar *aData = a.m().data; |
| @@ -105,18 +119,18 @@ class BayesianQuantizationDistance : public Distance | @@ -105,18 +119,18 @@ class BayesianQuantizationDistance : public Distance | ||
| 105 | const int size = a.m().rows * a.m().cols; | 119 | const int size = a.m().rows * a.m().cols; |
| 106 | float likelihood = 0; | 120 | float likelihood = 0; |
| 107 | for (int i=0; i<size; i++) | 121 | for (int i=0; i<size; i++) |
| 108 | - likelihood += loglikelihood[256*aData[i]+bData[i]]; | 122 | + likelihood += centers.ptr<float>(labels.at<int>(i))[256*aData[i]+bData[i]]; |
| 109 | return likelihood; | 123 | return likelihood; |
| 110 | } | 124 | } |
| 111 | 125 | ||
| 112 | void load(QDataStream &stream) | 126 | void load(QDataStream &stream) |
| 113 | { | 127 | { |
| 114 | - stream >> loglikelihood; | 128 | + stream >> labels >> centers; |
| 115 | } | 129 | } |
| 116 | 130 | ||
| 117 | void store(QDataStream &stream) const | 131 | void store(QDataStream &stream) const |
| 118 | { | 132 | { |
| 119 | - stream << loglikelihood; | 133 | + stream << labels << centers; |
| 120 | } | 134 | } |
| 121 | }; | 135 | }; |
| 122 | 136 |