Commit 66001a7191876851e6420a60b84146aea62249a7
1 parent
82a0ed2f
implemented PipeDistance
Showing
6 changed files
with
96 additions
and
50 deletions
sdk/openbr_plugin.cpp
| @@ -522,6 +522,9 @@ QString Object::argument(int index) const | @@ -522,6 +522,9 @@ QString Object::argument(int index) const | ||
| 522 | } else if (type == "QList<br::Transform*>") { | 522 | } else if (type == "QList<br::Transform*>") { |
| 523 | foreach (Transform *transform, variant.value< QList<Transform*> >()) | 523 | foreach (Transform *transform, variant.value< QList<Transform*> >()) |
| 524 | strings.append(transform->description()); | 524 | strings.append(transform->description()); |
| 525 | + } else if (type == "QList<br::Distance*>") { | ||
| 526 | + foreach (Distance *distance, variant.value< QList<Distance*> >()) | ||
| 527 | + strings.append(distance->description()); | ||
| 525 | } else { | 528 | } else { |
| 526 | qFatal("Unrecognized type: %s", qPrintable(type)); | 529 | qFatal("Unrecognized type: %s", qPrintable(type)); |
| 527 | } | 530 | } |
| @@ -556,6 +559,9 @@ void Object::store(QDataStream &stream) const | @@ -556,6 +559,9 @@ void Object::store(QDataStream &stream) const | ||
| 556 | if (type == "QList<br::Transform*>") { | 559 | if (type == "QList<br::Transform*>") { |
| 557 | foreach (Transform *transform, property.read(this).value< QList<Transform*> >()) | 560 | foreach (Transform *transform, property.read(this).value< QList<Transform*> >()) |
| 558 | transform->store(stream); | 561 | transform->store(stream); |
| 562 | + } else if (type == "QList<br::Distance*>") { | ||
| 563 | + foreach (Distance *distance, property.read(this).value< QList<Distance*> >()) | ||
| 564 | + distance->store(stream); | ||
| 559 | } else if (type == "br::Transform*") { | 565 | } else if (type == "br::Transform*") { |
| 560 | property.read(this).value<Transform*>()->store(stream); | 566 | property.read(this).value<Transform*>()->store(stream); |
| 561 | } else if (type == "br::Distance*") { | 567 | } else if (type == "br::Distance*") { |
| @@ -590,6 +596,9 @@ void Object::load(QDataStream &stream) | @@ -590,6 +596,9 @@ void Object::load(QDataStream &stream) | ||
| 590 | if (type == "QList<br::Transform*>") { | 596 | if (type == "QList<br::Transform*>") { |
| 591 | foreach (Transform *transform, property.read(this).value< QList<Transform*> >()) | 597 | foreach (Transform *transform, property.read(this).value< QList<Transform*> >()) |
| 592 | transform->load(stream); | 598 | transform->load(stream); |
| 599 | + } else if (type == "QList<br::Distance*>") { | ||
| 600 | + foreach (Distance *distance, property.read(this).value< QList<Distance*> >()) | ||
| 601 | + distance->load(stream); | ||
| 593 | } else if (type == "br::Transform*") { | 602 | } else if (type == "br::Transform*") { |
| 594 | property.read(this).value<Transform*>()->load(stream); | 603 | property.read(this).value<Transform*>()->load(stream); |
| 595 | } else if (type == "br::Distance*") { | 604 | } else if (type == "br::Distance*") { |
| @@ -653,6 +662,11 @@ void Object::setProperty(const QString &name, const QString &value) | @@ -653,6 +662,11 @@ void Object::setProperty(const QString &name, const QString &value) | ||
| 653 | foreach (const QString &string, strings) | 662 | foreach (const QString &string, strings) |
| 654 | values.append(Transform::make(string, this)); | 663 | values.append(Transform::make(string, this)); |
| 655 | variant.setValue(values); | 664 | variant.setValue(values); |
| 665 | + } else if (type == "QList<br::Distance*>") { | ||
| 666 | + QList<Distance*> values; | ||
| 667 | + foreach (const QString &string, strings) | ||
| 668 | + values.append(Distance::make(string, this)); | ||
| 669 | + variant.setValue(values); | ||
| 656 | } else { | 670 | } else { |
| 657 | qFatal("Unrecognized type: %s", qPrintable(type)); | 671 | qFatal("Unrecognized type: %s", qPrintable(type)); |
| 658 | } | 672 | } |
| @@ -828,6 +842,7 @@ void br::Context::initializeQt(QString sdkPath) | @@ -828,6 +842,7 @@ void br::Context::initializeQt(QString sdkPath) | ||
| 828 | qRegisterMetaType< br::Transform* >(); | 842 | qRegisterMetaType< br::Transform* >(); |
| 829 | qRegisterMetaType< QList<br::Transform*> >(); | 843 | qRegisterMetaType< QList<br::Transform*> >(); |
| 830 | qRegisterMetaType< br::Distance* >(); | 844 | qRegisterMetaType< br::Distance* >(); |
| 845 | + qRegisterMetaType< QList<br::Distance*> >(); | ||
| 831 | qRegisterMetaType< cv::Mat >(); | 846 | qRegisterMetaType< cv::Mat >(); |
| 832 | 847 | ||
| 833 | qInstallMsgHandler(messageHandler); | 848 | qInstallMsgHandler(messageHandler); |
| @@ -1311,15 +1326,21 @@ void Transform::backProject(const TemplateList &dst, TemplateList &src) const | @@ -1311,15 +1326,21 @@ void Transform::backProject(const TemplateList &dst, TemplateList &src) const | ||
| 1311 | /* Distance - public methods */ | 1326 | /* Distance - public methods */ |
| 1312 | Distance *Distance::make(QString str, QObject *parent) | 1327 | Distance *Distance::make(QString str, QObject *parent) |
| 1313 | { | 1328 | { |
| 1314 | - // Check for custom transforms | ||
| 1315 | - if (Globals->abbreviations.contains(str)) | ||
| 1316 | - return make(Globals->abbreviations[str], parent); | 1329 | + // Check for custom transforms |
| 1330 | + if (Globals->abbreviations.contains(str)) | ||
| 1331 | + return make(Globals->abbreviations[str], parent); | ||
| 1332 | + | ||
| 1333 | + { // Check for use of '+' as shorthand for Pipe(...) | ||
| 1334 | + QStringList words = parse(str, '+'); | ||
| 1335 | + if (words.size() > 1) | ||
| 1336 | + return make("Pipe([" + words.join(",") + "])", parent); | ||
| 1337 | + } | ||
| 1317 | 1338 | ||
| 1318 | - File f = "." + str; | ||
| 1319 | - Distance *distance = Factory<Distance>::make(f); | 1339 | + File f = "." + str; |
| 1340 | + Distance *distance = Factory<Distance>::make(f); | ||
| 1320 | 1341 | ||
| 1321 | - distance->setParent(parent); | ||
| 1322 | - return distance; | 1342 | + distance->setParent(parent); |
| 1343 | + return distance; | ||
| 1323 | } | 1344 | } |
| 1324 | 1345 | ||
| 1325 | void Distance::compare(const TemplateList &target, const TemplateList &query, Output *output) const | 1346 | void Distance::compare(const TemplateList &target, const TemplateList &query, Output *output) const |
sdk/openbr_plugin.h
| @@ -1136,6 +1136,7 @@ Q_DECLARE_METATYPE(QList<int>) | @@ -1136,6 +1136,7 @@ Q_DECLARE_METATYPE(QList<int>) | ||
| 1136 | Q_DECLARE_METATYPE(br::Transform*) | 1136 | Q_DECLARE_METATYPE(br::Transform*) |
| 1137 | Q_DECLARE_METATYPE(QList<br::Transform*>) | 1137 | Q_DECLARE_METATYPE(QList<br::Transform*>) |
| 1138 | Q_DECLARE_METATYPE(br::Distance*) | 1138 | Q_DECLARE_METATYPE(br::Distance*) |
| 1139 | +Q_DECLARE_METATYPE(QList<br::Distance*>) | ||
| 1139 | Q_DECLARE_METATYPE(cv::Mat) | 1140 | Q_DECLARE_METATYPE(cv::Mat) |
| 1140 | 1141 | ||
| 1141 | #endif // __OPENBR_PLUGIN_H | 1142 | #endif // __OPENBR_PLUGIN_H |
sdk/plugins/algorithms.cpp
| @@ -32,7 +32,7 @@ class AlgorithmsInitializer : public Initializer | @@ -32,7 +32,7 @@ class AlgorithmsInitializer : public Initializer | ||
| 32 | { | 32 | { |
| 33 | // Face | 33 | // Face |
| 34 | Globals->abbreviations.insert("FaceRecognition", "FaceDetection!<FaceRecognitionRegistration>!<FaceRecognitionExtraction>+<FaceRecognitionEmbedding>+<FaceRecognitionQuantization>:UCharL1"); | 34 | Globals->abbreviations.insert("FaceRecognition", "FaceDetection!<FaceRecognitionRegistration>!<FaceRecognitionExtraction>+<FaceRecognitionEmbedding>+<FaceRecognitionQuantization>:UCharL1"); |
| 35 | - Globals->abbreviations.insert("FaceRecognitionNoTraining", "FaceDetection!ASEFEyes+Affine(86,86,0.25,0.35)!Blur(1.1)+Gamma(0.2)+DoG(1,2)+ContrastEq(0.1,10)+Mask+LBP(1,2)+RectRegions(8,8,6,6)+Hist(59)+Cat:Dist(ChiSquared)"); | 35 | + Globals->abbreviations.insert("FaceRecognitionNoTraining", "FaceDetection!ASEFEyes+Affine(86,86,0.25,0.35)!Blur(1.1)+Gamma(0.2)+DoG(1,2)+ContrastEq(0.1,10)+Mask+LBP(1,2)+RectRegions(8,8,6,6)+Hist(59)+Cat:ChiSquared"); |
| 36 | Globals->abbreviations.insert("GenderClassification", "FaceDetection!<FaceClassificationRegistration>!<FaceClassificationExtraction>+<GenderClassifier>+Discard"); | 36 | Globals->abbreviations.insert("GenderClassification", "FaceDetection!<FaceClassificationRegistration>!<FaceClassificationExtraction>+<GenderClassifier>+Discard"); |
| 37 | Globals->abbreviations.insert("AgeRegression", "FaceDetection!<FaceClassificationRegistration>!<FaceClassificationExtraction>+<AgeRegressor>+Discard"); | 37 | Globals->abbreviations.insert("AgeRegression", "FaceDetection!<FaceClassificationRegistration>!<FaceClassificationExtraction>+<AgeRegressor>+Discard"); |
| 38 | Globals->abbreviations.insert("FaceQuality", "Open!Cascade(FrontalFace)+ASEFEyes+Affine(64,64,0.25,0.35)+ImageQuality+Cvt(Gray)+DFFS+Discard"); | 38 | Globals->abbreviations.insert("FaceQuality", "Open!Cascade(FrontalFace)+ASEFEyes+Affine(64,64,0.25,0.35)+ImageQuality+Cvt(Gray)+DFFS+Discard"); |
| @@ -45,7 +45,7 @@ class AlgorithmsInitializer : public Initializer | @@ -45,7 +45,7 @@ class AlgorithmsInitializer : public Initializer | ||
| 45 | Globals->abbreviations.insert("SURF", "Open+KeyPointDetector(SURF)+KeyPointDescriptor(SURF):KeyPointMatcher(BruteForce)"); | 45 | Globals->abbreviations.insert("SURF", "Open+KeyPointDetector(SURF)+KeyPointDescriptor(SURF):KeyPointMatcher(BruteForce)"); |
| 46 | Globals->abbreviations.insert("SmallSIFT", "Open+LimitSize(512)+KeyPointDetector(SIFT)+KeyPointDescriptor(SIFT):KeyPointMatcher(BruteForce)"); | 46 | Globals->abbreviations.insert("SmallSIFT", "Open+LimitSize(512)+KeyPointDetector(SIFT)+KeyPointDescriptor(SIFT):KeyPointMatcher(BruteForce)"); |
| 47 | Globals->abbreviations.insert("SmallSURF", "Open+LimitSize(512)+KeyPointDetector(SURF)+KeyPointDescriptor(SURF):KeyPointMatcher(BruteForce)"); | 47 | Globals->abbreviations.insert("SmallSURF", "Open+LimitSize(512)+KeyPointDetector(SURF)+KeyPointDescriptor(SURF):KeyPointMatcher(BruteForce)"); |
| 48 | - Globals->abbreviations.insert("ColorHist", "Open+LimitSize(512)!EnsureChannels(3)+SplitChannels+Hist(256,0,8)+Cat+Normalize(L1):Dist(L2)"); | 48 | + Globals->abbreviations.insert("ColorHist", "Open+LimitSize(512)!EnsureChannels(3)+SplitChannels+Hist(256,0,8)+Cat+Normalize(L1):L2"); |
| 49 | 49 | ||
| 50 | // Hash | 50 | // Hash |
| 51 | Globals->abbreviations.insert("FileName", "Name+Identity:Identical"); | 51 | Globals->abbreviations.insert("FileName", "Name+Identity:Identical"); |
sdk/plugins/distance.cpp
| @@ -14,6 +14,7 @@ | @@ -14,6 +14,7 @@ | ||
| 14 | * limitations under the License. * | 14 | * limitations under the License. * |
| 15 | * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ | 15 | * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ |
| 16 | 16 | ||
| 17 | +#include <QtConcurrentRun> | ||
| 17 | #include <opencv2/imgproc/imgproc.hpp> | 18 | #include <opencv2/imgproc/imgproc.hpp> |
| 18 | #include <openbr_plugin.h> | 19 | #include <openbr_plugin.h> |
| 19 | 20 | ||
| @@ -101,7 +102,6 @@ private: | @@ -101,7 +102,6 @@ private: | ||
| 101 | for (int col=0; col<a.cols; col++) { | 102 | for (int col=0; col<a.cols; col++) { |
| 102 | const float target = a.at<float>(row,col); | 103 | const float target = a.at<float>(row,col); |
| 103 | const float query = b.at<float>(row,col); | 104 | const float query = b.at<float>(row,col); |
| 104 | - | ||
| 105 | dot += target * query; | 105 | dot += target * query; |
| 106 | magA += target * target; | 106 | magA += target * target; |
| 107 | magB += query * query; | 107 | magB += query * query; |
| @@ -139,6 +139,44 @@ BR_REGISTER(Distance, DefaultDistance) | @@ -139,6 +139,44 @@ BR_REGISTER(Distance, DefaultDistance) | ||
| 139 | 139 | ||
| 140 | /*! | 140 | /*! |
| 141 | * \ingroup distances | 141 | * \ingroup distances |
| 142 | + * \brief Distances in series. | ||
| 143 | + * \author Josh Klontz \cite jklontz | ||
| 144 | + * | ||
| 145 | + * The templates are compared using each br::Distance in order. | ||
| 146 | + * If the result of the comparison with any given distance is -std::numeric_limits<float>::max() then this result is returned early. | ||
| 147 | + * Otherwise the returned result is the value of comparing the templates using the last br::Distance. | ||
| 148 | + */ | ||
| 149 | +class PipeDistance : public Distance | ||
| 150 | +{ | ||
| 151 | + Q_OBJECT | ||
| 152 | + Q_PROPERTY(QList<br::Distance*> distances READ get_distances WRITE set_distances RESET reset_distances) | ||
| 153 | + BR_PROPERTY(QList<br::Distance*>, distances, QList<br::Distance*>()) | ||
| 154 | + | ||
| 155 | + void train(const TemplateList &data) | ||
| 156 | + { | ||
| 157 | + QList< QFuture<void> > futures; | ||
| 158 | + foreach (br::Distance *distance, distances) | ||
| 159 | + if (Globals->parallelism) futures.append(QtConcurrent::run(distance, &Distance::train, data)); | ||
| 160 | + else distance->train(data); | ||
| 161 | + Globals->trackFutures(futures); | ||
| 162 | + } | ||
| 163 | + | ||
| 164 | + float compare(const Template &a, const Template &b) const | ||
| 165 | + { | ||
| 166 | + float result = -std::numeric_limits<float>::max(); | ||
| 167 | + foreach (br::Distance *distance, distances) { | ||
| 168 | + result = distance->compare(a, b); | ||
| 169 | + if (result == -std::numeric_limits<float>::max()) | ||
| 170 | + return result; | ||
| 171 | + } | ||
| 172 | + return result; | ||
| 173 | + } | ||
| 174 | +}; | ||
| 175 | + | ||
| 176 | +BR_REGISTER(Distance, PipeDistance) | ||
| 177 | + | ||
| 178 | +/*! | ||
| 179 | + * \ingroup distances | ||
| 142 | * \brief Fast 8-bit L1 distance | 180 | * \brief Fast 8-bit L1 distance |
| 143 | * \author Josh Klontz \cite jklontz | 181 | * \author Josh Klontz \cite jklontz |
| 144 | */ | 182 | */ |
sdk/plugins/quality.cpp
| @@ -260,37 +260,6 @@ class UnitDistance : public Distance | @@ -260,37 +260,6 @@ class UnitDistance : public Distance | ||
| 260 | 260 | ||
| 261 | BR_REGISTER(Distance, UnitDistance) | 261 | BR_REGISTER(Distance, UnitDistance) |
| 262 | 262 | ||
| 263 | -/*! | ||
| 264 | - * \ingroup distances | ||
| 265 | - * \brief Check target metadata before matching templates. | ||
| 266 | - * \author Josh Klontz \cite jklontz | ||
| 267 | - */ | ||
| 268 | -class MetadataDistance : public Distance | ||
| 269 | -{ | ||
| 270 | - Q_OBJECT | ||
| 271 | - Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance) | ||
| 272 | - BR_PROPERTY(br::Distance*, distance, make("Dist(L2)")) | ||
| 273 | - | ||
| 274 | - void train(const TemplateList &src) | ||
| 275 | - { | ||
| 276 | - distance->train(src); | ||
| 277 | - } | ||
| 278 | - | ||
| 279 | - float compare(const Template &a, const Template &b) const | ||
| 280 | - { | ||
| 281 | - foreach (const QString &filter, Globals->demographicFilters.keys()) { | ||
| 282 | - const QString metadata = a.file.getString(filter, ""); | ||
| 283 | - if (metadata.isEmpty()) continue; | ||
| 284 | - const QRegExp re(Globals->demographicFilters[filter]); | ||
| 285 | - if (re.indexIn(metadata) == -1) | ||
| 286 | - return -std::numeric_limits<float>::max(); | ||
| 287 | - } | ||
| 288 | - return distance->compare(a, b); | ||
| 289 | - } | ||
| 290 | -}; | ||
| 291 | - | ||
| 292 | -BR_REGISTER(Distance, MetadataDistance) | ||
| 293 | - | ||
| 294 | } // namespace br | 263 | } // namespace br |
| 295 | 264 | ||
| 296 | #include "quality.moc" | 265 | #include "quality.moc" |
sdk/plugins/validate.cpp
| @@ -79,25 +79,42 @@ BR_REGISTER(Transform, CrossValidateTransform) | @@ -79,25 +79,42 @@ BR_REGISTER(Transform, CrossValidateTransform) | ||
| 79 | class CrossValidateDistance : public Distance | 79 | class CrossValidateDistance : public Distance |
| 80 | { | 80 | { |
| 81 | Q_OBJECT | 81 | Q_OBJECT |
| 82 | - Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance) | ||
| 83 | - BR_PROPERTY(br::Distance*, distance, make("Dist(L2)")) | ||
| 84 | - | ||
| 85 | - void train(const TemplateList &src) | ||
| 86 | - { | ||
| 87 | - distance->train(src); | ||
| 88 | - } | ||
| 89 | 82 | ||
| 90 | float compare(const Template &a, const Template &b) const | 83 | float compare(const Template &a, const Template &b) const |
| 91 | { | 84 | { |
| 92 | const int partitionA = a.file.getInt("Cross_Validation_Partition", 0); | 85 | const int partitionA = a.file.getInt("Cross_Validation_Partition", 0); |
| 93 | const int partitionB = b.file.getInt("Cross_Validation_Partition", 0); | 86 | const int partitionB = b.file.getInt("Cross_Validation_Partition", 0); |
| 94 | - if (partitionA != partitionB) return -std::numeric_limits<float>::max(); | ||
| 95 | - return distance->compare(a, b); | 87 | + return (partitionA != partitionB) ? -std::numeric_limits<float>::max() : 0; |
| 96 | } | 88 | } |
| 97 | }; | 89 | }; |
| 98 | 90 | ||
| 99 | BR_REGISTER(Distance, CrossValidateDistance) | 91 | BR_REGISTER(Distance, CrossValidateDistance) |
| 100 | 92 | ||
| 93 | +/*! | ||
| 94 | + * \ingroup distances | ||
| 95 | + * \brief Checks target metadata. | ||
| 96 | + * \author Josh Klontz \cite jklontz | ||
| 97 | + */ | ||
| 98 | +class MetadataDistance : public Distance | ||
| 99 | +{ | ||
| 100 | + Q_OBJECT | ||
| 101 | + | ||
| 102 | + float compare(const Template &a, const Template &b) const | ||
| 103 | + { | ||
| 104 | + (void) b; | ||
| 105 | + foreach (const QString &filter, Globals->demographicFilters.keys()) { | ||
| 106 | + const QString metadata = a.file.getString(filter, ""); | ||
| 107 | + if (metadata.isEmpty()) continue; | ||
| 108 | + const QRegExp re(Globals->demographicFilters[filter]); | ||
| 109 | + if (re.indexIn(metadata) == -1) | ||
| 110 | + return -std::numeric_limits<float>::max(); | ||
| 111 | + } | ||
| 112 | + return 0; | ||
| 113 | + } | ||
| 114 | +}; | ||
| 115 | + | ||
| 116 | +BR_REGISTER(Distance, MetadataDistance) | ||
| 117 | + | ||
| 101 | } // namespace br | 118 | } // namespace br |
| 102 | 119 | ||
| 103 | #include "validate.moc" | 120 | #include "validate.moc" |