Commit 66001a7191876851e6420a60b84146aea62249a7

Authored by Josh Klontz
1 parent 82a0ed2f

implemented PipeDistance

sdk/openbr_plugin.cpp
... ... @@ -522,6 +522,9 @@ QString Object::argument(int index) const
522 522 } else if (type == "QList<br::Transform*>") {
523 523 foreach (Transform *transform, variant.value< QList<Transform*> >())
524 524 strings.append(transform->description());
  525 + } else if (type == "QList<br::Distance*>") {
  526 + foreach (Distance *distance, variant.value< QList<Distance*> >())
  527 + strings.append(distance->description());
525 528 } else {
526 529 qFatal("Unrecognized type: %s", qPrintable(type));
527 530 }
... ... @@ -556,6 +559,9 @@ void Object::store(QDataStream &amp;stream) const
556 559 if (type == "QList<br::Transform*>") {
557 560 foreach (Transform *transform, property.read(this).value< QList<Transform*> >())
558 561 transform->store(stream);
  562 + } else if (type == "QList<br::Distance*>") {
  563 + foreach (Distance *distance, property.read(this).value< QList<Distance*> >())
  564 + distance->store(stream);
559 565 } else if (type == "br::Transform*") {
560 566 property.read(this).value<Transform*>()->store(stream);
561 567 } else if (type == "br::Distance*") {
... ... @@ -590,6 +596,9 @@ void Object::load(QDataStream &amp;stream)
590 596 if (type == "QList<br::Transform*>") {
591 597 foreach (Transform *transform, property.read(this).value< QList<Transform*> >())
592 598 transform->load(stream);
  599 + } else if (type == "QList<br::Distance*>") {
  600 + foreach (Distance *distance, property.read(this).value< QList<Distance*> >())
  601 + distance->load(stream);
593 602 } else if (type == "br::Transform*") {
594 603 property.read(this).value<Transform*>()->load(stream);
595 604 } else if (type == "br::Distance*") {
... ... @@ -653,6 +662,11 @@ void Object::setProperty(const QString &amp;name, const QString &amp;value)
653 662 foreach (const QString &string, strings)
654 663 values.append(Transform::make(string, this));
655 664 variant.setValue(values);
  665 + } else if (type == "QList<br::Distance*>") {
  666 + QList<Distance*> values;
  667 + foreach (const QString &string, strings)
  668 + values.append(Distance::make(string, this));
  669 + variant.setValue(values);
656 670 } else {
657 671 qFatal("Unrecognized type: %s", qPrintable(type));
658 672 }
... ... @@ -828,6 +842,7 @@ void br::Context::initializeQt(QString sdkPath)
828 842 qRegisterMetaType< br::Transform* >();
829 843 qRegisterMetaType< QList<br::Transform*> >();
830 844 qRegisterMetaType< br::Distance* >();
  845 + qRegisterMetaType< QList<br::Distance*> >();
831 846 qRegisterMetaType< cv::Mat >();
832 847  
833 848 qInstallMsgHandler(messageHandler);
... ... @@ -1311,15 +1326,21 @@ void Transform::backProject(const TemplateList &amp;dst, TemplateList &amp;src) const
1311 1326 /* Distance - public methods */
1312 1327 Distance *Distance::make(QString str, QObject *parent)
1313 1328 {
1314   - // Check for custom transforms
1315   - if (Globals->abbreviations.contains(str))
1316   - return make(Globals->abbreviations[str], parent);
  1329 + // Check for custom transforms
  1330 + if (Globals->abbreviations.contains(str))
  1331 + return make(Globals->abbreviations[str], parent);
  1332 +
  1333 + { // Check for use of '+' as shorthand for Pipe(...)
  1334 + QStringList words = parse(str, '+');
  1335 + if (words.size() > 1)
  1336 + return make("Pipe([" + words.join(",") + "])", parent);
  1337 + }
1317 1338  
1318   - File f = "." + str;
1319   - Distance *distance = Factory<Distance>::make(f);
  1339 + File f = "." + str;
  1340 + Distance *distance = Factory<Distance>::make(f);
1320 1341  
1321   - distance->setParent(parent);
1322   - return distance;
  1342 + distance->setParent(parent);
  1343 + return distance;
1323 1344 }
1324 1345  
1325 1346 void Distance::compare(const TemplateList &target, const TemplateList &query, Output *output) const
... ...
sdk/openbr_plugin.h
... ... @@ -1136,6 +1136,7 @@ Q_DECLARE_METATYPE(QList&lt;int&gt;)
1136 1136 Q_DECLARE_METATYPE(br::Transform*)
1137 1137 Q_DECLARE_METATYPE(QList<br::Transform*>)
1138 1138 Q_DECLARE_METATYPE(br::Distance*)
  1139 +Q_DECLARE_METATYPE(QList<br::Distance*>)
1139 1140 Q_DECLARE_METATYPE(cv::Mat)
1140 1141  
1141 1142 #endif // __OPENBR_PLUGIN_H
... ...
sdk/plugins/algorithms.cpp
... ... @@ -32,7 +32,7 @@ class AlgorithmsInitializer : public Initializer
32 32 {
33 33 // Face
34 34 Globals->abbreviations.insert("FaceRecognition", "FaceDetection!<FaceRecognitionRegistration>!<FaceRecognitionExtraction>+<FaceRecognitionEmbedding>+<FaceRecognitionQuantization>:UCharL1");
35   - Globals->abbreviations.insert("FaceRecognitionNoTraining", "FaceDetection!ASEFEyes+Affine(86,86,0.25,0.35)!Blur(1.1)+Gamma(0.2)+DoG(1,2)+ContrastEq(0.1,10)+Mask+LBP(1,2)+RectRegions(8,8,6,6)+Hist(59)+Cat:Dist(ChiSquared)");
  35 + Globals->abbreviations.insert("FaceRecognitionNoTraining", "FaceDetection!ASEFEyes+Affine(86,86,0.25,0.35)!Blur(1.1)+Gamma(0.2)+DoG(1,2)+ContrastEq(0.1,10)+Mask+LBP(1,2)+RectRegions(8,8,6,6)+Hist(59)+Cat:ChiSquared");
36 36 Globals->abbreviations.insert("GenderClassification", "FaceDetection!<FaceClassificationRegistration>!<FaceClassificationExtraction>+<GenderClassifier>+Discard");
37 37 Globals->abbreviations.insert("AgeRegression", "FaceDetection!<FaceClassificationRegistration>!<FaceClassificationExtraction>+<AgeRegressor>+Discard");
38 38 Globals->abbreviations.insert("FaceQuality", "Open!Cascade(FrontalFace)+ASEFEyes+Affine(64,64,0.25,0.35)+ImageQuality+Cvt(Gray)+DFFS+Discard");
... ... @@ -45,7 +45,7 @@ class AlgorithmsInitializer : public Initializer
45 45 Globals->abbreviations.insert("SURF", "Open+KeyPointDetector(SURF)+KeyPointDescriptor(SURF):KeyPointMatcher(BruteForce)");
46 46 Globals->abbreviations.insert("SmallSIFT", "Open+LimitSize(512)+KeyPointDetector(SIFT)+KeyPointDescriptor(SIFT):KeyPointMatcher(BruteForce)");
47 47 Globals->abbreviations.insert("SmallSURF", "Open+LimitSize(512)+KeyPointDetector(SURF)+KeyPointDescriptor(SURF):KeyPointMatcher(BruteForce)");
48   - Globals->abbreviations.insert("ColorHist", "Open+LimitSize(512)!EnsureChannels(3)+SplitChannels+Hist(256,0,8)+Cat+Normalize(L1):Dist(L2)");
  48 + Globals->abbreviations.insert("ColorHist", "Open+LimitSize(512)!EnsureChannels(3)+SplitChannels+Hist(256,0,8)+Cat+Normalize(L1):L2");
49 49  
50 50 // Hash
51 51 Globals->abbreviations.insert("FileName", "Name+Identity:Identical");
... ...
sdk/plugins/distance.cpp
... ... @@ -14,6 +14,7 @@
14 14 * limitations under the License. *
15 15 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */
16 16  
  17 +#include <QtConcurrentRun>
17 18 #include <opencv2/imgproc/imgproc.hpp>
18 19 #include <openbr_plugin.h>
19 20  
... ... @@ -101,7 +102,6 @@ private:
101 102 for (int col=0; col<a.cols; col++) {
102 103 const float target = a.at<float>(row,col);
103 104 const float query = b.at<float>(row,col);
104   -
105 105 dot += target * query;
106 106 magA += target * target;
107 107 magB += query * query;
... ... @@ -139,6 +139,44 @@ BR_REGISTER(Distance, DefaultDistance)
139 139  
140 140 /*!
141 141 * \ingroup distances
  142 + * \brief Distances in series.
  143 + * \author Josh Klontz \cite jklontz
  144 + *
  145 + * The templates are compared using each br::Distance in order.
  146 + * If the result of the comparison with any given distance is -std::numeric_limits<float>::max() then this result is returned early.
  147 + * Otherwise the returned result is the value of comparing the templates using the last br::Distance.
  148 + */
  149 +class PipeDistance : public Distance
  150 +{
  151 + Q_OBJECT
  152 + Q_PROPERTY(QList<br::Distance*> distances READ get_distances WRITE set_distances RESET reset_distances)
  153 + BR_PROPERTY(QList<br::Distance*>, distances, QList<br::Distance*>())
  154 +
  155 + void train(const TemplateList &data)
  156 + {
  157 + QList< QFuture<void> > futures;
  158 + foreach (br::Distance *distance, distances)
  159 + if (Globals->parallelism) futures.append(QtConcurrent::run(distance, &Distance::train, data));
  160 + else distance->train(data);
  161 + Globals->trackFutures(futures);
  162 + }
  163 +
  164 + float compare(const Template &a, const Template &b) const
  165 + {
  166 + float result = -std::numeric_limits<float>::max();
  167 + foreach (br::Distance *distance, distances) {
  168 + result = distance->compare(a, b);
  169 + if (result == -std::numeric_limits<float>::max())
  170 + return result;
  171 + }
  172 + return result;
  173 + }
  174 +};
  175 +
  176 +BR_REGISTER(Distance, PipeDistance)
  177 +
  178 +/*!
  179 + * \ingroup distances
142 180 * \brief Fast 8-bit L1 distance
143 181 * \author Josh Klontz \cite jklontz
144 182 */
... ...
sdk/plugins/quality.cpp
... ... @@ -260,37 +260,6 @@ class UnitDistance : public Distance
260 260  
261 261 BR_REGISTER(Distance, UnitDistance)
262 262  
263   -/*!
264   - * \ingroup distances
265   - * \brief Check target metadata before matching templates.
266   - * \author Josh Klontz \cite jklontz
267   - */
268   -class MetadataDistance : public Distance
269   -{
270   - Q_OBJECT
271   - Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance)
272   - BR_PROPERTY(br::Distance*, distance, make("Dist(L2)"))
273   -
274   - void train(const TemplateList &src)
275   - {
276   - distance->train(src);
277   - }
278   -
279   - float compare(const Template &a, const Template &b) const
280   - {
281   - foreach (const QString &filter, Globals->demographicFilters.keys()) {
282   - const QString metadata = a.file.getString(filter, "");
283   - if (metadata.isEmpty()) continue;
284   - const QRegExp re(Globals->demographicFilters[filter]);
285   - if (re.indexIn(metadata) == -1)
286   - return -std::numeric_limits<float>::max();
287   - }
288   - return distance->compare(a, b);
289   - }
290   -};
291   -
292   -BR_REGISTER(Distance, MetadataDistance)
293   -
294 263 } // namespace br
295 264  
296 265 #include "quality.moc"
... ...
sdk/plugins/validate.cpp
... ... @@ -79,25 +79,42 @@ BR_REGISTER(Transform, CrossValidateTransform)
79 79 class CrossValidateDistance : public Distance
80 80 {
81 81 Q_OBJECT
82   - Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance)
83   - BR_PROPERTY(br::Distance*, distance, make("Dist(L2)"))
84   -
85   - void train(const TemplateList &src)
86   - {
87   - distance->train(src);
88   - }
89 82  
90 83 float compare(const Template &a, const Template &b) const
91 84 {
92 85 const int partitionA = a.file.getInt("Cross_Validation_Partition", 0);
93 86 const int partitionB = b.file.getInt("Cross_Validation_Partition", 0);
94   - if (partitionA != partitionB) return -std::numeric_limits<float>::max();
95   - return distance->compare(a, b);
  87 + return (partitionA != partitionB) ? -std::numeric_limits<float>::max() : 0;
96 88 }
97 89 };
98 90  
99 91 BR_REGISTER(Distance, CrossValidateDistance)
100 92  
  93 +/*!
  94 + * \ingroup distances
  95 + * \brief Checks target metadata.
  96 + * \author Josh Klontz \cite jklontz
  97 + */
  98 +class MetadataDistance : public Distance
  99 +{
  100 + Q_OBJECT
  101 +
  102 + float compare(const Template &a, const Template &b) const
  103 + {
  104 + (void) b;
  105 + foreach (const QString &filter, Globals->demographicFilters.keys()) {
  106 + const QString metadata = a.file.getString(filter, "");
  107 + if (metadata.isEmpty()) continue;
  108 + const QRegExp re(Globals->demographicFilters[filter]);
  109 + if (re.indexIn(metadata) == -1)
  110 + return -std::numeric_limits<float>::max();
  111 + }
  112 + return 0;
  113 + }
  114 +};
  115 +
  116 +BR_REGISTER(Distance, MetadataDistance)
  117 +
101 118 } // namespace br
102 119  
103 120 #include "validate.moc"
... ...