Commit 2437da4aa35abb4f1946370c184b04b220e93a10

Authored by Josh Klontz
2 parents b4cac0b6 f3f7b589

Merge branch 'master' of https://github.com/biometrics/openbr

openbr/plugins/svm.cpp
@@ -26,10 +26,8 @@ using namespace cv; @@ -26,10 +26,8 @@ using namespace cv;
26 namespace br 26 namespace br
27 { 27 {
28 28
29 -static void storeSVM(float a, float b, const SVM &svm, QDataStream &stream) 29 +static void storeSVM(const SVM &svm, QDataStream &stream)
30 { 30 {
31 - stream << a << b;  
32 -  
33 // Create local file 31 // Create local file
34 QTemporaryFile tempFile; 32 QTemporaryFile tempFile;
35 tempFile.open(); 33 tempFile.open();
@@ -45,10 +43,8 @@ static void storeSVM(float a, float b, const SVM &amp;svm, QDataStream &amp;stream) @@ -45,10 +43,8 @@ static void storeSVM(float a, float b, const SVM &amp;svm, QDataStream &amp;stream)
45 stream << data; 43 stream << data;
46 } 44 }
47 45
48 -static void loadSVM(float &a, float &b, SVM &svm, QDataStream &stream) 46 +static void loadSVM(SVM &svm, QDataStream &stream)
49 { 47 {
50 - stream >> a >> b;  
51 -  
52 // Copy local file contents from stream 48 // Copy local file contents from stream
53 QByteArray data; 49 QByteArray data;
54 stream >> data; 50 stream >> data;
@@ -63,22 +59,8 @@ static void loadSVM(float &amp;a, float &amp;b, SVM &amp;svm, QDataStream &amp;stream) @@ -63,22 +59,8 @@ static void loadSVM(float &amp;a, float &amp;b, SVM &amp;svm, QDataStream &amp;stream)
63 svm.load(qPrintable(tempFile.fileName())); 59 svm.load(qPrintable(tempFile.fileName()));
64 } 60 }
65 61
66 -static void trainSVM(float &a, float &b, SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma) 62 +static void trainSVM(SVM &svm, Mat data, Mat lab, int kernel, int type, float C, float gamma)
67 { 63 {
68 - if ((type == CvSVM::EPS_SVR) || (type == CvSVM::NU_SVR)) {  
69 - // Scale labels to [-1,1]  
70 - double min, max;  
71 - minMaxLoc(lab, &min, &max);  
72 - if (max > min) {  
73 - a = 2.0/(max-min);  
74 - b = -(min*a+1);  
75 - lab = (lab * a) + b;  
76 - }  
77 - } else {  
78 - a = 1;  
79 - b = 0;  
80 - }  
81 -  
82 if (data.type() != CV_32FC1) 64 if (data.type() != CV_32FC1)
83 qFatal("Expected single channel floating point training data."); 65 qFatal("Expected single channel floating point training data.");
84 66
@@ -139,29 +121,28 @@ private: @@ -139,29 +121,28 @@ private:
139 BR_PROPERTY(float, gamma, -1) 121 BR_PROPERTY(float, gamma, -1)
140 122
141 SVM svm; 123 SVM svm;
142 - float a, b;  
143 124
144 void train(const TemplateList &_data) 125 void train(const TemplateList &_data)
145 { 126 {
146 Mat data = OpenCVUtils::toMat(_data.data()); 127 Mat data = OpenCVUtils::toMat(_data.data());
147 Mat lab = OpenCVUtils::toMat(_data.labels<float>()); 128 Mat lab = OpenCVUtils::toMat(_data.labels<float>());
148 - trainSVM(a, b, svm, data, lab, kernel, type, C, gamma); 129 + trainSVM(svm, data, lab, kernel, type, C, gamma);
149 } 130 }
150 131
151 void project(const Template &src, Template &dst) const 132 void project(const Template &src, Template &dst) const
152 { 133 {
153 dst = src; 134 dst = src;
154 - dst.file.set("Label", ((svm.predict(src.m().reshape(1, 1)) - b)/a)); 135 + dst.file.set("Label", svm.predict(src.m().reshape(1, 1)));
155 } 136 }
156 137
157 void store(QDataStream &stream) const 138 void store(QDataStream &stream) const
158 { 139 {
159 - storeSVM(a, b, svm, stream); 140 + storeSVM(svm, stream);
160 } 141 }
161 142
162 void load(QDataStream &stream) 143 void load(QDataStream &stream)
163 { 144 {
164 - loadSVM(a, b, svm, stream); 145 + loadSVM(svm, stream);
165 } 146 }
166 }; 147 };
167 148
@@ -197,7 +178,6 @@ private: @@ -197,7 +178,6 @@ private:
197 BR_PROPERTY(Type, type, EPS_SVR) 178 BR_PROPERTY(Type, type, EPS_SVR)
198 179
199 SVM svm; 180 SVM svm;
200 - float a, b;  
201 181
202 void train(const TemplateList &src) 182 void train(const TemplateList &src)
203 { 183 {
@@ -220,24 +200,24 @@ private: @@ -220,24 +200,24 @@ private:
220 deltaData = deltaData.rowRange(0, index); 200 deltaData = deltaData.rowRange(0, index);
221 deltaLab = deltaLab.rowRange(0, index); 201 deltaLab = deltaLab.rowRange(0, index);
222 202
223 - trainSVM(a, b, svm, deltaData, deltaLab, kernel, type, -1, -1); 203 + trainSVM(svm, deltaData, deltaLab, kernel, type, -1, -1);
224 } 204 }
225 205
226 float compare(const Template &ta, const Template &tb) const 206 float compare(const Template &ta, const Template &tb) const
227 { 207 {
228 Mat delta; 208 Mat delta;
229 absdiff(ta, tb, delta); 209 absdiff(ta, tb, delta);
230 - return (svm.predict(delta.reshape(1, 1)) - b)/a; 210 + return svm.predict(delta.reshape(1, 1));
231 } 211 }
232 212
233 void store(QDataStream &stream) const 213 void store(QDataStream &stream) const
234 { 214 {
235 - storeSVM(a, b, svm, stream); 215 + storeSVM(svm, stream);
236 } 216 }
237 217
238 void load(QDataStream &stream) 218 void load(QDataStream &stream)
239 { 219 {
240 - loadSVM(a, b, svm, stream); 220 + loadSVM(svm, stream);
241 } 221 }
242 }; 222 };
243 223
1 -Subproject commit 1394fbec42897c1a53373443c32a6f62a010294f 1 +Subproject commit cddd8fc2231191ab5d62dc29573f73e97d5065c5