Commit 66001a7191876851e6420a60b84146aea62249a7
1 parent
82a0ed2f
implemented PipeDistance
Showing
6 changed files
with
96 additions
and
50 deletions
sdk/openbr_plugin.cpp
| ... | ... | @@ -522,6 +522,9 @@ QString Object::argument(int index) const |
| 522 | 522 | } else if (type == "QList<br::Transform*>") { |
| 523 | 523 | foreach (Transform *transform, variant.value< QList<Transform*> >()) |
| 524 | 524 | strings.append(transform->description()); |
| 525 | + } else if (type == "QList<br::Distance*>") { | |
| 526 | + foreach (Distance *distance, variant.value< QList<Distance*> >()) | |
| 527 | + strings.append(distance->description()); | |
| 525 | 528 | } else { |
| 526 | 529 | qFatal("Unrecognized type: %s", qPrintable(type)); |
| 527 | 530 | } |
| ... | ... | @@ -556,6 +559,9 @@ void Object::store(QDataStream &stream) const |
| 556 | 559 | if (type == "QList<br::Transform*>") { |
| 557 | 560 | foreach (Transform *transform, property.read(this).value< QList<Transform*> >()) |
| 558 | 561 | transform->store(stream); |
| 562 | + } else if (type == "QList<br::Distance*>") { | |
| 563 | + foreach (Distance *distance, property.read(this).value< QList<Distance*> >()) | |
| 564 | + distance->store(stream); | |
| 559 | 565 | } else if (type == "br::Transform*") { |
| 560 | 566 | property.read(this).value<Transform*>()->store(stream); |
| 561 | 567 | } else if (type == "br::Distance*") { |
| ... | ... | @@ -590,6 +596,9 @@ void Object::load(QDataStream &stream) |
| 590 | 596 | if (type == "QList<br::Transform*>") { |
| 591 | 597 | foreach (Transform *transform, property.read(this).value< QList<Transform*> >()) |
| 592 | 598 | transform->load(stream); |
| 599 | + } else if (type == "QList<br::Distance*>") { | |
| 600 | + foreach (Distance *distance, property.read(this).value< QList<Distance*> >()) | |
| 601 | + distance->load(stream); | |
| 593 | 602 | } else if (type == "br::Transform*") { |
| 594 | 603 | property.read(this).value<Transform*>()->load(stream); |
| 595 | 604 | } else if (type == "br::Distance*") { |
| ... | ... | @@ -653,6 +662,11 @@ void Object::setProperty(const QString &name, const QString &value) |
| 653 | 662 | foreach (const QString &string, strings) |
| 654 | 663 | values.append(Transform::make(string, this)); |
| 655 | 664 | variant.setValue(values); |
| 665 | + } else if (type == "QList<br::Distance*>") { | |
| 666 | + QList<Distance*> values; | |
| 667 | + foreach (const QString &string, strings) | |
| 668 | + values.append(Distance::make(string, this)); | |
| 669 | + variant.setValue(values); | |
| 656 | 670 | } else { |
| 657 | 671 | qFatal("Unrecognized type: %s", qPrintable(type)); |
| 658 | 672 | } |
| ... | ... | @@ -828,6 +842,7 @@ void br::Context::initializeQt(QString sdkPath) |
| 828 | 842 | qRegisterMetaType< br::Transform* >(); |
| 829 | 843 | qRegisterMetaType< QList<br::Transform*> >(); |
| 830 | 844 | qRegisterMetaType< br::Distance* >(); |
| 845 | + qRegisterMetaType< QList<br::Distance*> >(); | |
| 831 | 846 | qRegisterMetaType< cv::Mat >(); |
| 832 | 847 | |
| 833 | 848 | qInstallMsgHandler(messageHandler); |
| ... | ... | @@ -1311,15 +1326,21 @@ void Transform::backProject(const TemplateList &dst, TemplateList &src) const |
| 1311 | 1326 | /* Distance - public methods */ |
| 1312 | 1327 | Distance *Distance::make(QString str, QObject *parent) |
| 1313 | 1328 | { |
| 1314 | - // Check for custom transforms | |
| 1315 | - if (Globals->abbreviations.contains(str)) | |
| 1316 | - return make(Globals->abbreviations[str], parent); | |
| 1329 | + // Check for custom transforms | |
| 1330 | + if (Globals->abbreviations.contains(str)) | |
| 1331 | + return make(Globals->abbreviations[str], parent); | |
| 1332 | + | |
| 1333 | + { // Check for use of '+' as shorthand for Pipe(...) | |
| 1334 | + QStringList words = parse(str, '+'); | |
| 1335 | + if (words.size() > 1) | |
| 1336 | + return make("Pipe([" + words.join(",") + "])", parent); | |
| 1337 | + } | |
| 1317 | 1338 | |
| 1318 | - File f = "." + str; | |
| 1319 | - Distance *distance = Factory<Distance>::make(f); | |
| 1339 | + File f = "." + str; | |
| 1340 | + Distance *distance = Factory<Distance>::make(f); | |
| 1320 | 1341 | |
| 1321 | - distance->setParent(parent); | |
| 1322 | - return distance; | |
| 1342 | + distance->setParent(parent); | |
| 1343 | + return distance; | |
| 1323 | 1344 | } |
| 1324 | 1345 | |
| 1325 | 1346 | void Distance::compare(const TemplateList &target, const TemplateList &query, Output *output) const | ... | ... |
sdk/openbr_plugin.h
| ... | ... | @@ -1136,6 +1136,7 @@ Q_DECLARE_METATYPE(QList<int>) |
| 1136 | 1136 | Q_DECLARE_METATYPE(br::Transform*) |
| 1137 | 1137 | Q_DECLARE_METATYPE(QList<br::Transform*>) |
| 1138 | 1138 | Q_DECLARE_METATYPE(br::Distance*) |
| 1139 | +Q_DECLARE_METATYPE(QList<br::Distance*>) | |
| 1139 | 1140 | Q_DECLARE_METATYPE(cv::Mat) |
| 1140 | 1141 | |
| 1141 | 1142 | #endif // __OPENBR_PLUGIN_H | ... | ... |
sdk/plugins/algorithms.cpp
| ... | ... | @@ -32,7 +32,7 @@ class AlgorithmsInitializer : public Initializer |
| 32 | 32 | { |
| 33 | 33 | // Face |
| 34 | 34 | Globals->abbreviations.insert("FaceRecognition", "FaceDetection!<FaceRecognitionRegistration>!<FaceRecognitionExtraction>+<FaceRecognitionEmbedding>+<FaceRecognitionQuantization>:UCharL1"); |
| 35 | - Globals->abbreviations.insert("FaceRecognitionNoTraining", "FaceDetection!ASEFEyes+Affine(86,86,0.25,0.35)!Blur(1.1)+Gamma(0.2)+DoG(1,2)+ContrastEq(0.1,10)+Mask+LBP(1,2)+RectRegions(8,8,6,6)+Hist(59)+Cat:Dist(ChiSquared)"); | |
| 35 | + Globals->abbreviations.insert("FaceRecognitionNoTraining", "FaceDetection!ASEFEyes+Affine(86,86,0.25,0.35)!Blur(1.1)+Gamma(0.2)+DoG(1,2)+ContrastEq(0.1,10)+Mask+LBP(1,2)+RectRegions(8,8,6,6)+Hist(59)+Cat:ChiSquared"); | |
| 36 | 36 | Globals->abbreviations.insert("GenderClassification", "FaceDetection!<FaceClassificationRegistration>!<FaceClassificationExtraction>+<GenderClassifier>+Discard"); |
| 37 | 37 | Globals->abbreviations.insert("AgeRegression", "FaceDetection!<FaceClassificationRegistration>!<FaceClassificationExtraction>+<AgeRegressor>+Discard"); |
| 38 | 38 | Globals->abbreviations.insert("FaceQuality", "Open!Cascade(FrontalFace)+ASEFEyes+Affine(64,64,0.25,0.35)+ImageQuality+Cvt(Gray)+DFFS+Discard"); |
| ... | ... | @@ -45,7 +45,7 @@ class AlgorithmsInitializer : public Initializer |
| 45 | 45 | Globals->abbreviations.insert("SURF", "Open+KeyPointDetector(SURF)+KeyPointDescriptor(SURF):KeyPointMatcher(BruteForce)"); |
| 46 | 46 | Globals->abbreviations.insert("SmallSIFT", "Open+LimitSize(512)+KeyPointDetector(SIFT)+KeyPointDescriptor(SIFT):KeyPointMatcher(BruteForce)"); |
| 47 | 47 | Globals->abbreviations.insert("SmallSURF", "Open+LimitSize(512)+KeyPointDetector(SURF)+KeyPointDescriptor(SURF):KeyPointMatcher(BruteForce)"); |
| 48 | - Globals->abbreviations.insert("ColorHist", "Open+LimitSize(512)!EnsureChannels(3)+SplitChannels+Hist(256,0,8)+Cat+Normalize(L1):Dist(L2)"); | |
| 48 | + Globals->abbreviations.insert("ColorHist", "Open+LimitSize(512)!EnsureChannels(3)+SplitChannels+Hist(256,0,8)+Cat+Normalize(L1):L2"); | |
| 49 | 49 | |
| 50 | 50 | // Hash |
| 51 | 51 | Globals->abbreviations.insert("FileName", "Name+Identity:Identical"); | ... | ... |
sdk/plugins/distance.cpp
| ... | ... | @@ -14,6 +14,7 @@ |
| 14 | 14 | * limitations under the License. * |
| 15 | 15 | * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ |
| 16 | 16 | |
| 17 | +#include <QtConcurrentRun> | |
| 17 | 18 | #include <opencv2/imgproc/imgproc.hpp> |
| 18 | 19 | #include <openbr_plugin.h> |
| 19 | 20 | |
| ... | ... | @@ -101,7 +102,6 @@ private: |
| 101 | 102 | for (int col=0; col<a.cols; col++) { |
| 102 | 103 | const float target = a.at<float>(row,col); |
| 103 | 104 | const float query = b.at<float>(row,col); |
| 104 | - | |
| 105 | 105 | dot += target * query; |
| 106 | 106 | magA += target * target; |
| 107 | 107 | magB += query * query; |
| ... | ... | @@ -139,6 +139,44 @@ BR_REGISTER(Distance, DefaultDistance) |
| 139 | 139 | |
| 140 | 140 | /*! |
| 141 | 141 | * \ingroup distances |
| 142 | + * \brief Distances in series. | |
| 143 | + * \author Josh Klontz \cite jklontz | |
| 144 | + * | |
| 145 | + * The templates are compared using each br::Distance in order. | |
| 146 | + * If the result of the comparison with any given distance is -std::numeric_limits<float>::max() then this result is returned early. | |
| 147 | + * Otherwise the returned result is the value of comparing the templates using the last br::Distance. | |
| 148 | + */ | |
| 149 | +class PipeDistance : public Distance | |
| 150 | +{ | |
| 151 | + Q_OBJECT | |
| 152 | + Q_PROPERTY(QList<br::Distance*> distances READ get_distances WRITE set_distances RESET reset_distances) | |
| 153 | + BR_PROPERTY(QList<br::Distance*>, distances, QList<br::Distance*>()) | |
| 154 | + | |
| 155 | + void train(const TemplateList &data) | |
| 156 | + { | |
| 157 | + QList< QFuture<void> > futures; | |
| 158 | + foreach (br::Distance *distance, distances) | |
| 159 | + if (Globals->parallelism) futures.append(QtConcurrent::run(distance, &Distance::train, data)); | |
| 160 | + else distance->train(data); | |
| 161 | + Globals->trackFutures(futures); | |
| 162 | + } | |
| 163 | + | |
| 164 | + float compare(const Template &a, const Template &b) const | |
| 165 | + { | |
| 166 | + float result = -std::numeric_limits<float>::max(); | |
| 167 | + foreach (br::Distance *distance, distances) { | |
| 168 | + result = distance->compare(a, b); | |
| 169 | + if (result == -std::numeric_limits<float>::max()) | |
| 170 | + return result; | |
| 171 | + } | |
| 172 | + return result; | |
| 173 | + } | |
| 174 | +}; | |
| 175 | + | |
| 176 | +BR_REGISTER(Distance, PipeDistance) | |
| 177 | + | |
| 178 | +/*! | |
| 179 | + * \ingroup distances | |
| 142 | 180 | * \brief Fast 8-bit L1 distance |
| 143 | 181 | * \author Josh Klontz \cite jklontz |
| 144 | 182 | */ | ... | ... |
sdk/plugins/quality.cpp
| ... | ... | @@ -260,37 +260,6 @@ class UnitDistance : public Distance |
| 260 | 260 | |
| 261 | 261 | BR_REGISTER(Distance, UnitDistance) |
| 262 | 262 | |
| 263 | -/*! | |
| 264 | - * \ingroup distances | |
| 265 | - * \brief Check target metadata before matching templates. | |
| 266 | - * \author Josh Klontz \cite jklontz | |
| 267 | - */ | |
| 268 | -class MetadataDistance : public Distance | |
| 269 | -{ | |
| 270 | - Q_OBJECT | |
| 271 | - Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance) | |
| 272 | - BR_PROPERTY(br::Distance*, distance, make("Dist(L2)")) | |
| 273 | - | |
| 274 | - void train(const TemplateList &src) | |
| 275 | - { | |
| 276 | - distance->train(src); | |
| 277 | - } | |
| 278 | - | |
| 279 | - float compare(const Template &a, const Template &b) const | |
| 280 | - { | |
| 281 | - foreach (const QString &filter, Globals->demographicFilters.keys()) { | |
| 282 | - const QString metadata = a.file.getString(filter, ""); | |
| 283 | - if (metadata.isEmpty()) continue; | |
| 284 | - const QRegExp re(Globals->demographicFilters[filter]); | |
| 285 | - if (re.indexIn(metadata) == -1) | |
| 286 | - return -std::numeric_limits<float>::max(); | |
| 287 | - } | |
| 288 | - return distance->compare(a, b); | |
| 289 | - } | |
| 290 | -}; | |
| 291 | - | |
| 292 | -BR_REGISTER(Distance, MetadataDistance) | |
| 293 | - | |
| 294 | 263 | } // namespace br |
| 295 | 264 | |
| 296 | 265 | #include "quality.moc" | ... | ... |
sdk/plugins/validate.cpp
| ... | ... | @@ -79,25 +79,42 @@ BR_REGISTER(Transform, CrossValidateTransform) |
| 79 | 79 | class CrossValidateDistance : public Distance |
| 80 | 80 | { |
| 81 | 81 | Q_OBJECT |
| 82 | - Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance) | |
| 83 | - BR_PROPERTY(br::Distance*, distance, make("Dist(L2)")) | |
| 84 | - | |
| 85 | - void train(const TemplateList &src) | |
| 86 | - { | |
| 87 | - distance->train(src); | |
| 88 | - } | |
| 89 | 82 | |
| 90 | 83 | float compare(const Template &a, const Template &b) const |
| 91 | 84 | { |
| 92 | 85 | const int partitionA = a.file.getInt("Cross_Validation_Partition", 0); |
| 93 | 86 | const int partitionB = b.file.getInt("Cross_Validation_Partition", 0); |
| 94 | - if (partitionA != partitionB) return -std::numeric_limits<float>::max(); | |
| 95 | - return distance->compare(a, b); | |
| 87 | + return (partitionA != partitionB) ? -std::numeric_limits<float>::max() : 0; | |
| 96 | 88 | } |
| 97 | 89 | }; |
| 98 | 90 | |
| 99 | 91 | BR_REGISTER(Distance, CrossValidateDistance) |
| 100 | 92 | |
| 93 | +/*! | |
| 94 | + * \ingroup distances | |
| 95 | + * \brief Checks target metadata. | |
| 96 | + * \author Josh Klontz \cite jklontz | |
| 97 | + */ | |
| 98 | +class MetadataDistance : public Distance | |
| 99 | +{ | |
| 100 | + Q_OBJECT | |
| 101 | + | |
| 102 | + float compare(const Template &a, const Template &b) const | |
| 103 | + { | |
| 104 | + (void) b; | |
| 105 | + foreach (const QString &filter, Globals->demographicFilters.keys()) { | |
| 106 | + const QString metadata = a.file.getString(filter, ""); | |
| 107 | + if (metadata.isEmpty()) continue; | |
| 108 | + const QRegExp re(Globals->demographicFilters[filter]); | |
| 109 | + if (re.indexIn(metadata) == -1) | |
| 110 | + return -std::numeric_limits<float>::max(); | |
| 111 | + } | |
| 112 | + return 0; | |
| 113 | + } | |
| 114 | +}; | |
| 115 | + | |
| 116 | +BR_REGISTER(Distance, MetadataDistance) | |
| 117 | + | |
| 101 | 118 | } // namespace br |
| 102 | 119 | |
| 103 | 120 | #include "validate.moc" | ... | ... |