Commit 66001a7191876851e6420a60b84146aea62249a7

Authored by Josh Klontz
1 parent 82a0ed2f

implemented PipeDistance

sdk/openbr_plugin.cpp
@@ -522,6 +522,9 @@ QString Object::argument(int index) const @@ -522,6 +522,9 @@ QString Object::argument(int index) const
522 } else if (type == "QList<br::Transform*>") { 522 } else if (type == "QList<br::Transform*>") {
523 foreach (Transform *transform, variant.value< QList<Transform*> >()) 523 foreach (Transform *transform, variant.value< QList<Transform*> >())
524 strings.append(transform->description()); 524 strings.append(transform->description());
  525 + } else if (type == "QList<br::Distance*>") {
  526 + foreach (Distance *distance, variant.value< QList<Distance*> >())
  527 + strings.append(distance->description());
525 } else { 528 } else {
526 qFatal("Unrecognized type: %s", qPrintable(type)); 529 qFatal("Unrecognized type: %s", qPrintable(type));
527 } 530 }
@@ -556,6 +559,9 @@ void Object::store(QDataStream &amp;stream) const @@ -556,6 +559,9 @@ void Object::store(QDataStream &amp;stream) const
556 if (type == "QList<br::Transform*>") { 559 if (type == "QList<br::Transform*>") {
557 foreach (Transform *transform, property.read(this).value< QList<Transform*> >()) 560 foreach (Transform *transform, property.read(this).value< QList<Transform*> >())
558 transform->store(stream); 561 transform->store(stream);
  562 + } else if (type == "QList<br::Distance*>") {
  563 + foreach (Distance *distance, property.read(this).value< QList<Distance*> >())
  564 + distance->store(stream);
559 } else if (type == "br::Transform*") { 565 } else if (type == "br::Transform*") {
560 property.read(this).value<Transform*>()->store(stream); 566 property.read(this).value<Transform*>()->store(stream);
561 } else if (type == "br::Distance*") { 567 } else if (type == "br::Distance*") {
@@ -590,6 +596,9 @@ void Object::load(QDataStream &amp;stream) @@ -590,6 +596,9 @@ void Object::load(QDataStream &amp;stream)
590 if (type == "QList<br::Transform*>") { 596 if (type == "QList<br::Transform*>") {
591 foreach (Transform *transform, property.read(this).value< QList<Transform*> >()) 597 foreach (Transform *transform, property.read(this).value< QList<Transform*> >())
592 transform->load(stream); 598 transform->load(stream);
  599 + } else if (type == "QList<br::Distance*>") {
  600 + foreach (Distance *distance, property.read(this).value< QList<Distance*> >())
  601 + distance->load(stream);
593 } else if (type == "br::Transform*") { 602 } else if (type == "br::Transform*") {
594 property.read(this).value<Transform*>()->load(stream); 603 property.read(this).value<Transform*>()->load(stream);
595 } else if (type == "br::Distance*") { 604 } else if (type == "br::Distance*") {
@@ -653,6 +662,11 @@ void Object::setProperty(const QString &amp;name, const QString &amp;value) @@ -653,6 +662,11 @@ void Object::setProperty(const QString &amp;name, const QString &amp;value)
653 foreach (const QString &string, strings) 662 foreach (const QString &string, strings)
654 values.append(Transform::make(string, this)); 663 values.append(Transform::make(string, this));
655 variant.setValue(values); 664 variant.setValue(values);
  665 + } else if (type == "QList<br::Distance*>") {
  666 + QList<Distance*> values;
  667 + foreach (const QString &string, strings)
  668 + values.append(Distance::make(string, this));
  669 + variant.setValue(values);
656 } else { 670 } else {
657 qFatal("Unrecognized type: %s", qPrintable(type)); 671 qFatal("Unrecognized type: %s", qPrintable(type));
658 } 672 }
@@ -828,6 +842,7 @@ void br::Context::initializeQt(QString sdkPath) @@ -828,6 +842,7 @@ void br::Context::initializeQt(QString sdkPath)
828 qRegisterMetaType< br::Transform* >(); 842 qRegisterMetaType< br::Transform* >();
829 qRegisterMetaType< QList<br::Transform*> >(); 843 qRegisterMetaType< QList<br::Transform*> >();
830 qRegisterMetaType< br::Distance* >(); 844 qRegisterMetaType< br::Distance* >();
  845 + qRegisterMetaType< QList<br::Distance*> >();
831 qRegisterMetaType< cv::Mat >(); 846 qRegisterMetaType< cv::Mat >();
832 847
833 qInstallMsgHandler(messageHandler); 848 qInstallMsgHandler(messageHandler);
@@ -1311,15 +1326,21 @@ void Transform::backProject(const TemplateList &amp;dst, TemplateList &amp;src) const @@ -1311,15 +1326,21 @@ void Transform::backProject(const TemplateList &amp;dst, TemplateList &amp;src) const
1311 /* Distance - public methods */ 1326 /* Distance - public methods */
1312 Distance *Distance::make(QString str, QObject *parent) 1327 Distance *Distance::make(QString str, QObject *parent)
1313 { 1328 {
1314 - // Check for custom transforms  
1315 - if (Globals->abbreviations.contains(str))  
1316 - return make(Globals->abbreviations[str], parent); 1329 + // Check for custom transforms
  1330 + if (Globals->abbreviations.contains(str))
  1331 + return make(Globals->abbreviations[str], parent);
  1332 +
  1333 + { // Check for use of '+' as shorthand for Pipe(...)
  1334 + QStringList words = parse(str, '+');
  1335 + if (words.size() > 1)
  1336 + return make("Pipe([" + words.join(",") + "])", parent);
  1337 + }
1317 1338
1318 - File f = "." + str;  
1319 - Distance *distance = Factory<Distance>::make(f); 1339 + File f = "." + str;
  1340 + Distance *distance = Factory<Distance>::make(f);
1320 1341
1321 - distance->setParent(parent);  
1322 - return distance; 1342 + distance->setParent(parent);
  1343 + return distance;
1323 } 1344 }
1324 1345
1325 void Distance::compare(const TemplateList &target, const TemplateList &query, Output *output) const 1346 void Distance::compare(const TemplateList &target, const TemplateList &query, Output *output) const
sdk/openbr_plugin.h
@@ -1136,6 +1136,7 @@ Q_DECLARE_METATYPE(QList&lt;int&gt;) @@ -1136,6 +1136,7 @@ Q_DECLARE_METATYPE(QList&lt;int&gt;)
1136 Q_DECLARE_METATYPE(br::Transform*) 1136 Q_DECLARE_METATYPE(br::Transform*)
1137 Q_DECLARE_METATYPE(QList<br::Transform*>) 1137 Q_DECLARE_METATYPE(QList<br::Transform*>)
1138 Q_DECLARE_METATYPE(br::Distance*) 1138 Q_DECLARE_METATYPE(br::Distance*)
  1139 +Q_DECLARE_METATYPE(QList<br::Distance*>)
1139 Q_DECLARE_METATYPE(cv::Mat) 1140 Q_DECLARE_METATYPE(cv::Mat)
1140 1141
1141 #endif // __OPENBR_PLUGIN_H 1142 #endif // __OPENBR_PLUGIN_H
sdk/plugins/algorithms.cpp
@@ -32,7 +32,7 @@ class AlgorithmsInitializer : public Initializer @@ -32,7 +32,7 @@ class AlgorithmsInitializer : public Initializer
32 { 32 {
33 // Face 33 // Face
34 Globals->abbreviations.insert("FaceRecognition", "FaceDetection!<FaceRecognitionRegistration>!<FaceRecognitionExtraction>+<FaceRecognitionEmbedding>+<FaceRecognitionQuantization>:UCharL1"); 34 Globals->abbreviations.insert("FaceRecognition", "FaceDetection!<FaceRecognitionRegistration>!<FaceRecognitionExtraction>+<FaceRecognitionEmbedding>+<FaceRecognitionQuantization>:UCharL1");
35 - Globals->abbreviations.insert("FaceRecognitionNoTraining", "FaceDetection!ASEFEyes+Affine(86,86,0.25,0.35)!Blur(1.1)+Gamma(0.2)+DoG(1,2)+ContrastEq(0.1,10)+Mask+LBP(1,2)+RectRegions(8,8,6,6)+Hist(59)+Cat:Dist(ChiSquared)"); 35 + Globals->abbreviations.insert("FaceRecognitionNoTraining", "FaceDetection!ASEFEyes+Affine(86,86,0.25,0.35)!Blur(1.1)+Gamma(0.2)+DoG(1,2)+ContrastEq(0.1,10)+Mask+LBP(1,2)+RectRegions(8,8,6,6)+Hist(59)+Cat:ChiSquared");
36 Globals->abbreviations.insert("GenderClassification", "FaceDetection!<FaceClassificationRegistration>!<FaceClassificationExtraction>+<GenderClassifier>+Discard"); 36 Globals->abbreviations.insert("GenderClassification", "FaceDetection!<FaceClassificationRegistration>!<FaceClassificationExtraction>+<GenderClassifier>+Discard");
37 Globals->abbreviations.insert("AgeRegression", "FaceDetection!<FaceClassificationRegistration>!<FaceClassificationExtraction>+<AgeRegressor>+Discard"); 37 Globals->abbreviations.insert("AgeRegression", "FaceDetection!<FaceClassificationRegistration>!<FaceClassificationExtraction>+<AgeRegressor>+Discard");
38 Globals->abbreviations.insert("FaceQuality", "Open!Cascade(FrontalFace)+ASEFEyes+Affine(64,64,0.25,0.35)+ImageQuality+Cvt(Gray)+DFFS+Discard"); 38 Globals->abbreviations.insert("FaceQuality", "Open!Cascade(FrontalFace)+ASEFEyes+Affine(64,64,0.25,0.35)+ImageQuality+Cvt(Gray)+DFFS+Discard");
@@ -45,7 +45,7 @@ class AlgorithmsInitializer : public Initializer @@ -45,7 +45,7 @@ class AlgorithmsInitializer : public Initializer
45 Globals->abbreviations.insert("SURF", "Open+KeyPointDetector(SURF)+KeyPointDescriptor(SURF):KeyPointMatcher(BruteForce)"); 45 Globals->abbreviations.insert("SURF", "Open+KeyPointDetector(SURF)+KeyPointDescriptor(SURF):KeyPointMatcher(BruteForce)");
46 Globals->abbreviations.insert("SmallSIFT", "Open+LimitSize(512)+KeyPointDetector(SIFT)+KeyPointDescriptor(SIFT):KeyPointMatcher(BruteForce)"); 46 Globals->abbreviations.insert("SmallSIFT", "Open+LimitSize(512)+KeyPointDetector(SIFT)+KeyPointDescriptor(SIFT):KeyPointMatcher(BruteForce)");
47 Globals->abbreviations.insert("SmallSURF", "Open+LimitSize(512)+KeyPointDetector(SURF)+KeyPointDescriptor(SURF):KeyPointMatcher(BruteForce)"); 47 Globals->abbreviations.insert("SmallSURF", "Open+LimitSize(512)+KeyPointDetector(SURF)+KeyPointDescriptor(SURF):KeyPointMatcher(BruteForce)");
48 - Globals->abbreviations.insert("ColorHist", "Open+LimitSize(512)!EnsureChannels(3)+SplitChannels+Hist(256,0,8)+Cat+Normalize(L1):Dist(L2)"); 48 + Globals->abbreviations.insert("ColorHist", "Open+LimitSize(512)!EnsureChannels(3)+SplitChannels+Hist(256,0,8)+Cat+Normalize(L1):L2");
49 49
50 // Hash 50 // Hash
51 Globals->abbreviations.insert("FileName", "Name+Identity:Identical"); 51 Globals->abbreviations.insert("FileName", "Name+Identity:Identical");
sdk/plugins/distance.cpp
@@ -14,6 +14,7 @@ @@ -14,6 +14,7 @@
14 * limitations under the License. * 14 * limitations under the License. *
15 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ 15 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
16 16
  17 +#include <QtConcurrentRun>
17 #include <opencv2/imgproc/imgproc.hpp> 18 #include <opencv2/imgproc/imgproc.hpp>
18 #include <openbr_plugin.h> 19 #include <openbr_plugin.h>
19 20
@@ -101,7 +102,6 @@ private: @@ -101,7 +102,6 @@ private:
101 for (int col=0; col<a.cols; col++) { 102 for (int col=0; col<a.cols; col++) {
102 const float target = a.at<float>(row,col); 103 const float target = a.at<float>(row,col);
103 const float query = b.at<float>(row,col); 104 const float query = b.at<float>(row,col);
104 -  
105 dot += target * query; 105 dot += target * query;
106 magA += target * target; 106 magA += target * target;
107 magB += query * query; 107 magB += query * query;
@@ -139,6 +139,44 @@ BR_REGISTER(Distance, DefaultDistance) @@ -139,6 +139,44 @@ BR_REGISTER(Distance, DefaultDistance)
139 139
140 /*! 140 /*!
141 * \ingroup distances 141 * \ingroup distances
  142 + * \brief Distances in series.
  143 + * \author Josh Klontz \cite jklontz
  144 + *
  145 + * The templates are compared using each br::Distance in order.
  146 + * If the result of the comparison with any given distance is -std::numeric_limits<float>::max() then this result is returned early.
  147 + * Otherwise the returned result is the value of comparing the templates using the last br::Distance.
  148 + */
  149 +class PipeDistance : public Distance
  150 +{
  151 + Q_OBJECT
  152 + Q_PROPERTY(QList<br::Distance*> distances READ get_distances WRITE set_distances RESET reset_distances)
  153 + BR_PROPERTY(QList<br::Distance*>, distances, QList<br::Distance*>())
  154 +
  155 + void train(const TemplateList &data)
  156 + {
  157 + QList< QFuture<void> > futures;
  158 + foreach (br::Distance *distance, distances)
  159 + if (Globals->parallelism) futures.append(QtConcurrent::run(distance, &Distance::train, data));
  160 + else distance->train(data);
  161 + Globals->trackFutures(futures);
  162 + }
  163 +
  164 + float compare(const Template &a, const Template &b) const
  165 + {
  166 + float result = -std::numeric_limits<float>::max();
  167 + foreach (br::Distance *distance, distances) {
  168 + result = distance->compare(a, b);
  169 + if (result == -std::numeric_limits<float>::max())
  170 + return result;
  171 + }
  172 + return result;
  173 + }
  174 +};
  175 +
  176 +BR_REGISTER(Distance, PipeDistance)
  177 +
  178 +/*!
  179 + * \ingroup distances
142 * \brief Fast 8-bit L1 distance 180 * \brief Fast 8-bit L1 distance
143 * \author Josh Klontz \cite jklontz 181 * \author Josh Klontz \cite jklontz
144 */ 182 */
sdk/plugins/quality.cpp
@@ -260,37 +260,6 @@ class UnitDistance : public Distance @@ -260,37 +260,6 @@ class UnitDistance : public Distance
260 260
261 BR_REGISTER(Distance, UnitDistance) 261 BR_REGISTER(Distance, UnitDistance)
262 262
263 -/*!  
264 - * \ingroup distances  
265 - * \brief Check target metadata before matching templates.  
266 - * \author Josh Klontz \cite jklontz  
267 - */  
268 -class MetadataDistance : public Distance  
269 -{  
270 - Q_OBJECT  
271 - Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance)  
272 - BR_PROPERTY(br::Distance*, distance, make("Dist(L2)"))  
273 -  
274 - void train(const TemplateList &src)  
275 - {  
276 - distance->train(src);  
277 - }  
278 -  
279 - float compare(const Template &a, const Template &b) const  
280 - {  
281 - foreach (const QString &filter, Globals->demographicFilters.keys()) {  
282 - const QString metadata = a.file.getString(filter, "");  
283 - if (metadata.isEmpty()) continue;  
284 - const QRegExp re(Globals->demographicFilters[filter]);  
285 - if (re.indexIn(metadata) == -1)  
286 - return -std::numeric_limits<float>::max();  
287 - }  
288 - return distance->compare(a, b);  
289 - }  
290 -};  
291 -  
292 -BR_REGISTER(Distance, MetadataDistance)  
293 -  
294 } // namespace br 263 } // namespace br
295 264
296 #include "quality.moc" 265 #include "quality.moc"
sdk/plugins/validate.cpp
@@ -79,25 +79,42 @@ BR_REGISTER(Transform, CrossValidateTransform) @@ -79,25 +79,42 @@ BR_REGISTER(Transform, CrossValidateTransform)
79 class CrossValidateDistance : public Distance 79 class CrossValidateDistance : public Distance
80 { 80 {
81 Q_OBJECT 81 Q_OBJECT
82 - Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance)  
83 - BR_PROPERTY(br::Distance*, distance, make("Dist(L2)"))  
84 -  
85 - void train(const TemplateList &src)  
86 - {  
87 - distance->train(src);  
88 - }  
89 82
90 float compare(const Template &a, const Template &b) const 83 float compare(const Template &a, const Template &b) const
91 { 84 {
92 const int partitionA = a.file.getInt("Cross_Validation_Partition", 0); 85 const int partitionA = a.file.getInt("Cross_Validation_Partition", 0);
93 const int partitionB = b.file.getInt("Cross_Validation_Partition", 0); 86 const int partitionB = b.file.getInt("Cross_Validation_Partition", 0);
94 - if (partitionA != partitionB) return -std::numeric_limits<float>::max();  
95 - return distance->compare(a, b); 87 + return (partitionA != partitionB) ? -std::numeric_limits<float>::max() : 0;
96 } 88 }
97 }; 89 };
98 90
99 BR_REGISTER(Distance, CrossValidateDistance) 91 BR_REGISTER(Distance, CrossValidateDistance)
100 92
  93 +/*!
  94 + * \ingroup distances
  95 + * \brief Checks target metadata.
  96 + * \author Josh Klontz \cite jklontz
  97 + */
  98 +class MetadataDistance : public Distance
  99 +{
  100 + Q_OBJECT
  101 +
  102 + float compare(const Template &a, const Template &b) const
  103 + {
  104 + (void) b;
  105 + foreach (const QString &filter, Globals->demographicFilters.keys()) {
  106 + const QString metadata = a.file.getString(filter, "");
  107 + if (metadata.isEmpty()) continue;
  108 + const QRegExp re(Globals->demographicFilters[filter]);
  109 + if (re.indexIn(metadata) == -1)
  110 + return -std::numeric_limits<float>::max();
  111 + }
  112 + return 0;
  113 + }
  114 +};
  115 +
  116 +BR_REGISTER(Distance, MetadataDistance)
  117 +
101 } // namespace br 118 } // namespace br
102 119
103 #include "validate.moc" 120 #include "validate.moc"