Commit 4e9226852a5a7c26b39ef8fefab4c1fc58ef215b

Authored by Josh Klontz
1 parent 81e9c1c2

generalized BayesianQuantizationDistance

Showing 1 changed file with 36 additions and 22 deletions
openbr/plugins/quantize.cpp
... ... @@ -64,40 +64,54 @@ BR_REGISTER(Transform, QuantizeTransform)
64 64 class BayesianQuantizationDistance : public Distance
65 65 {
66 66 Q_OBJECT
67   - QVector<float> loglikelihood;
  67 + Q_PROPERTY(int K READ get_K WRITE set_K RESET reset_K STORED false)
  68 + BR_PROPERTY(int, K, 1)
68 69  
69   - void train(const TemplateList &src)
  70 + Mat labels, centers;
  71 +
  72 + static void computeLogLikelihood(const Mat &data, const QList<int> &labels, float *loglikelihood)
70 73 {
71   - if (src.first().size() > 1)
72   - qFatal("Expected sigle matrix templates.");
  74 + const QList<uchar> values = OpenCVUtils::matrixToVector<uchar>(data);
  75 + if (values.size() != labels.size())
  76 + qFatal("Logic error.");
  77 +
  78 + QVector<int> genuines(256*256,0), impostors(256*256,0);
  79 + for (int i=0; i<labels.size(); i++)
  80 + for (int j=0; j<labels.size(); j++)
  81 + if (labels[i] == labels[j]) genuines[256*values[i]+values[j]]++;
  82 + else impostors[256*values[i]+values[j]]++;
73 83  
74   - Mat data = OpenCVUtils::toMat(src.data());
75   - QList<int> labels = src.labels<int>();
76   -
77   - QVector<qint64> genuines(256*256,0), impostors(256*256,0);
78   - for (int i=0; i<labels.size(); i++) {
79   - const uchar *a = data.ptr(i);
80   - for (int j=0; j<labels.size(); j++) {
81   - const uchar *b = data.ptr(j);
82   - const bool genuine = (labels[i] == labels[j]);
83   - for (int k=0; k<data.cols; k++)
84   - genuine ? genuines[256*a[k]+b[k]]++ : impostors[256*a[k]+b[k]]++;
85   - }
86   - }
87 84  
88   - qint64 totalGenuines(0), totalImpostors(0);
  85 + int totalGenuines(0), totalImpostors(0);
89 86 for (int i=0; i<256*256; i++) {
90 87 totalGenuines += genuines[i];
91 88 totalImpostors += impostors[i];
92 89 }
93 90  
94   - loglikelihood = QVector<float>(256*256);
95 91 for (int i=0; i<256; i++)
96 92 for (int j=0; j<256; j++)
97 93 loglikelihood[i*256+j] = log((double(genuines[i*256+j]+genuines[j*256+i]+1)/totalGenuines)/
98 94 (double(impostors[i*256+j]+impostors[j*256+i]+1)/totalImpostors));
99 95 }
100 96  
  97 + void train(const TemplateList &src)
  98 + {
  99 + if ((src.first().size() > 1) || (src.first().m().type() != CV_8UC1))
  100 + qFatal("Expected sigle matrix templates of type CV_8UC1.");
  101 +
  102 + const Mat data = OpenCVUtils::toMat(src.data());
  103 + const QList<int> templateLabels = src.labels<int>();
  104 + Mat loglikelihoods(data.cols, 256*256, CV_32FC1);
  105 +
  106 + QFutureSynchronizer<void> futures;
  107 + for (int i=0; i<data.cols; i++)
  108 + if (Globals->parallelism) futures.addFuture(QtConcurrent::run(&BayesianQuantizationDistance::computeLogLikelihood, data.col(i), templateLabels, loglikelihoods.ptr<float>(i)));
  109 + else computeLogLikelihood( data.col(i), templateLabels, loglikelihoods.ptr<float>(i));
  110 + futures.waitForFinished();
  111 +
  112 + kmeans(loglikelihoods, K, labels, TermCriteria(TermCriteria::MAX_ITER, 10, 0), 3, KMEANS_PP_CENTERS, centers);
  113 + }
  114 +
101 115 float compare(const Template &a, const Template &b) const
102 116 {
103 117 const uchar *aData = a.m().data;
... ... @@ -105,18 +119,18 @@ class BayesianQuantizationDistance : public Distance
105 119 const int size = a.m().rows * a.m().cols;
106 120 float likelihood = 0;
107 121 for (int i=0; i<size; i++)
108   - likelihood += loglikelihood[256*aData[i]+bData[i]];
  122 + likelihood += centers.ptr<float>(labels.at<int>(i))[256*aData[i]+bData[i]];
109 123 return likelihood;
110 124 }
111 125  
112 126 void load(QDataStream &stream)
113 127 {
114   - stream >> loglikelihood;
  128 + stream >> labels >> centers;
115 129 }
116 130  
117 131 void store(QDataStream &stream) const
118 132 {
119   - stream << loglikelihood;
  133 + stream << labels << centers;
120 134 }
121 135 };
122 136  
... ...