Commit 4e9226852a5a7c26b39ef8fefab4c1fc58ef215b

Authored by Josh Klontz
1 parent 81e9c1c2

generalized BayesianQuantizationDistance

Showing 1 changed file with 36 additions and 22 deletions
openbr/plugins/quantize.cpp
@@ -64,40 +64,54 @@ BR_REGISTER(Transform, QuantizeTransform) @@ -64,40 +64,54 @@ BR_REGISTER(Transform, QuantizeTransform)
64 class BayesianQuantizationDistance : public Distance 64 class BayesianQuantizationDistance : public Distance
65 { 65 {
66 Q_OBJECT 66 Q_OBJECT
67 - QVector<float> loglikelihood; 67 + Q_PROPERTY(int K READ get_K WRITE set_K RESET reset_K STORED false)
  68 + BR_PROPERTY(int, K, 1)
68 69
69 - void train(const TemplateList &src) 70 + Mat labels, centers;
  71 +
  72 + static void computeLogLikelihood(const Mat &data, const QList<int> &labels, float *loglikelihood)
70 { 73 {
71 - if (src.first().size() > 1)  
72 - qFatal("Expected sigle matrix templates."); 74 + const QList<uchar> values = OpenCVUtils::matrixToVector<uchar>(data);
  75 + if (values.size() != labels.size())
  76 + qFatal("Logic error.");
  77 +
  78 + QVector<int> genuines(256*256,0), impostors(256*256,0);
  79 + for (int i=0; i<labels.size(); i++)
  80 + for (int j=0; j<labels.size(); j++)
  81 + if (labels[i] == labels[j]) genuines[256*values[i]+values[j]]++;
  82 + else impostors[256*values[i]+values[j]]++;
73 83
74 - Mat data = OpenCVUtils::toMat(src.data());  
75 - QList<int> labels = src.labels<int>();  
76 -  
77 - QVector<qint64> genuines(256*256,0), impostors(256*256,0);  
78 - for (int i=0; i<labels.size(); i++) {  
79 - const uchar *a = data.ptr(i);  
80 - for (int j=0; j<labels.size(); j++) {  
81 - const uchar *b = data.ptr(j);  
82 - const bool genuine = (labels[i] == labels[j]);  
83 - for (int k=0; k<data.cols; k++)  
84 - genuine ? genuines[256*a[k]+b[k]]++ : impostors[256*a[k]+b[k]]++;  
85 - }  
86 - }  
87 84
88 - qint64 totalGenuines(0), totalImpostors(0); 85 + int totalGenuines(0), totalImpostors(0);
89 for (int i=0; i<256*256; i++) { 86 for (int i=0; i<256*256; i++) {
90 totalGenuines += genuines[i]; 87 totalGenuines += genuines[i];
91 totalImpostors += impostors[i]; 88 totalImpostors += impostors[i];
92 } 89 }
93 90
94 - loglikelihood = QVector<float>(256*256);  
95 for (int i=0; i<256; i++) 91 for (int i=0; i<256; i++)
96 for (int j=0; j<256; j++) 92 for (int j=0; j<256; j++)
97 loglikelihood[i*256+j] = log((double(genuines[i*256+j]+genuines[j*256+i]+1)/totalGenuines)/ 93 loglikelihood[i*256+j] = log((double(genuines[i*256+j]+genuines[j*256+i]+1)/totalGenuines)/
98 (double(impostors[i*256+j]+impostors[j*256+i]+1)/totalImpostors)); 94 (double(impostors[i*256+j]+impostors[j*256+i]+1)/totalImpostors));
99 } 95 }
100 96
  97 + void train(const TemplateList &src)
  98 + {
  99 + if ((src.first().size() > 1) || (src.first().m().type() != CV_8UC1))
  100 + qFatal("Expected sigle matrix templates of type CV_8UC1.");
  101 +
  102 + const Mat data = OpenCVUtils::toMat(src.data());
  103 + const QList<int> templateLabels = src.labels<int>();
  104 + Mat loglikelihoods(data.cols, 256*256, CV_32FC1);
  105 +
  106 + QFutureSynchronizer<void> futures;
  107 + for (int i=0; i<data.cols; i++)
  108 + if (Globals->parallelism) futures.addFuture(QtConcurrent::run(&BayesianQuantizationDistance::computeLogLikelihood, data.col(i), templateLabels, loglikelihoods.ptr<float>(i)));
  109 + else computeLogLikelihood( data.col(i), templateLabels, loglikelihoods.ptr<float>(i));
  110 + futures.waitForFinished();
  111 +
  112 + kmeans(loglikelihoods, K, labels, TermCriteria(TermCriteria::MAX_ITER, 10, 0), 3, KMEANS_PP_CENTERS, centers);
  113 + }
  114 +
101 float compare(const Template &a, const Template &b) const 115 float compare(const Template &a, const Template &b) const
102 { 116 {
103 const uchar *aData = a.m().data; 117 const uchar *aData = a.m().data;
@@ -105,18 +119,18 @@ class BayesianQuantizationDistance : public Distance @@ -105,18 +119,18 @@ class BayesianQuantizationDistance : public Distance
105 const int size = a.m().rows * a.m().cols; 119 const int size = a.m().rows * a.m().cols;
106 float likelihood = 0; 120 float likelihood = 0;
107 for (int i=0; i<size; i++) 121 for (int i=0; i<size; i++)
108 - likelihood += loglikelihood[256*aData[i]+bData[i]]; 122 + likelihood += centers.ptr<float>(labels.at<int>(i))[256*aData[i]+bData[i]];
109 return likelihood; 123 return likelihood;
110 } 124 }
111 125
112 void load(QDataStream &stream) 126 void load(QDataStream &stream)
113 { 127 {
114 - stream >> loglikelihood; 128 + stream >> labels >> centers;
115 } 129 }
116 130
117 void store(QDataStream &stream) const 131 void store(QDataStream &stream) const
118 { 132 {
119 - stream << loglikelihood; 133 + stream << labels << centers;
120 } 134 }
121 }; 135 };
122 136