Commit ccee88d8600b615a8e94bdc30fba042bf81fcaf7
Merge pull request #303 from biometrics/classification_api
Update Classification API to support OpenBR object handling
Showing
2 changed files
with
102 additions
and
1 deletions
openbr/openbr_plugin.cpp
| @@ -626,7 +626,11 @@ QStringList Object::prunedArguments(bool expanded) const | @@ -626,7 +626,11 @@ QStringList Object::prunedArguments(bool expanded) const | ||
| 626 | if (className.endsWith(interfaceName)) | 626 | if (className.endsWith(interfaceName)) |
| 627 | className.chop(interfaceName.size()); | 627 | className.chop(interfaceName.size()); |
| 628 | 628 | ||
| 629 | - if (interfaceName == "Distance") | 629 | + if (interfaceName == "Representation") |
| 630 | + shellObject.reset(Factory<Representation>::make(className)); | ||
| 631 | + else if (interfaceName == "Classifier") | ||
| 632 | + shellObject.reset(Factory<Classifier>::make(className)); | ||
| 633 | + else if (interfaceName == "Distance") | ||
| 630 | shellObject.reset(Factory<Distance>::make(className)); | 634 | shellObject.reset(Factory<Distance>::make(className)); |
| 631 | else if (interfaceName == "Transform") | 635 | else if (interfaceName == "Transform") |
| 632 | shellObject.reset(Factory<Transform>::make(className)); | 636 | shellObject.reset(Factory<Transform>::make(className)); |
| @@ -673,6 +677,12 @@ QString Object::argument(int index, bool expanded) const | @@ -673,6 +677,12 @@ QString Object::argument(int index, bool expanded) const | ||
| 673 | } else if (type == "QList<br::Distance*>") { | 677 | } else if (type == "QList<br::Distance*>") { |
| 674 | foreach (Distance *distance, variant.value< QList<Distance*> >()) | 678 | foreach (Distance *distance, variant.value< QList<Distance*> >()) |
| 675 | strings.append(distance->description(expanded)); | 679 | strings.append(distance->description(expanded)); |
| 680 | + } else if (type == "QList<br::Representation*>") { | ||
| 681 | + foreach (Representation *representation, variant.value< QList<Representation*> >()) | ||
| 682 | + strings.append(representation->description(expanded)); | ||
| 683 | + } else if (type == "QList<br::Classifier*>") { | ||
| 684 | + foreach (Classifier *classifier, variant.value< QList<Classifier*> >()) | ||
| 685 | + strings.append(classifier->description(expanded)); | ||
| 676 | } else { | 686 | } else { |
| 677 | qFatal("Unrecognized type: %s", qPrintable(type)); | 687 | qFatal("Unrecognized type: %s", qPrintable(type)); |
| 678 | } | 688 | } |
| @@ -682,6 +692,10 @@ QString Object::argument(int index, bool expanded) const | @@ -682,6 +692,10 @@ QString Object::argument(int index, bool expanded) const | ||
| 682 | return variant.value<Transform*>()->description(expanded); | 692 | return variant.value<Transform*>()->description(expanded); |
| 683 | } else if (type == "br::Distance*") { | 693 | } else if (type == "br::Distance*") { |
| 684 | return variant.value<Distance*>()->description(expanded); | 694 | return variant.value<Distance*>()->description(expanded); |
| 695 | + } else if (type == "br::Representation*") { | ||
| 696 | + return variant.value<Representation*>()->description(expanded); | ||
| 697 | + } else if (type == "br::Classifier*") { | ||
| 698 | + return variant.value<Classifier*>()->description(expanded); | ||
| 685 | } else if (type == "QStringList") { | 699 | } else if (type == "QStringList") { |
| 686 | return "[" + variant.toStringList().join(",") + "]"; | 700 | return "[" + variant.toStringList().join(",") + "]"; |
| 687 | } | 701 | } |
| @@ -713,10 +727,20 @@ void Object::store(QDataStream &stream) const | @@ -713,10 +727,20 @@ void Object::store(QDataStream &stream) const | ||
| 713 | } else if (type == "QList<br::Distance*>") { | 727 | } else if (type == "QList<br::Distance*>") { |
| 714 | foreach (Distance *distance, property.read(this).value< QList<Distance*> >()) | 728 | foreach (Distance *distance, property.read(this).value< QList<Distance*> >()) |
| 715 | distance->store(stream); | 729 | distance->store(stream); |
| 730 | + } else if (type == "QList<br::Representation*>") { | ||
| 731 | + foreach (Representation *representation, property.read(this).value< QList<Representation*> >()) | ||
| 732 | + representation->store(stream); | ||
| 733 | + } else if (type == "QList<br::Classifier*>") { | ||
| 734 | + foreach (Classifier *classifier, property.read(this).value< QList<Classifier*> >()) | ||
| 735 | + classifier->store(stream); | ||
| 716 | } else if (type == "br::Transform*") { | 736 | } else if (type == "br::Transform*") { |
| 717 | property.read(this).value<Transform*>()->store(stream); | 737 | property.read(this).value<Transform*>()->store(stream); |
| 718 | } else if (type == "br::Distance*") { | 738 | } else if (type == "br::Distance*") { |
| 719 | property.read(this).value<Distance*>()->store(stream); | 739 | property.read(this).value<Distance*>()->store(stream); |
| 740 | + } else if (type == "br::Representation*") { | ||
| 741 | + property.read(this).value<Representation*>()->store(stream); | ||
| 742 | + } else if (type == "br::Classifier*") { | ||
| 743 | + property.read(this).value<Classifier*>()->store(stream); | ||
| 720 | } else if (type == "bool") { | 744 | } else if (type == "bool") { |
| 721 | stream << property.read(this).toBool(); | 745 | stream << property.read(this).toBool(); |
| 722 | } else if (type == "int") { | 746 | } else if (type == "int") { |
| @@ -750,10 +774,20 @@ void Object::load(QDataStream &stream) | @@ -750,10 +774,20 @@ void Object::load(QDataStream &stream) | ||
| 750 | } else if (type == "QList<br::Distance*>") { | 774 | } else if (type == "QList<br::Distance*>") { |
| 751 | foreach (Distance *distance, property.read(this).value< QList<Distance*> >()) | 775 | foreach (Distance *distance, property.read(this).value< QList<Distance*> >()) |
| 752 | distance->load(stream); | 776 | distance->load(stream); |
| 777 | + } else if (type == "QList<br::Representation*>") { | ||
| 778 | + foreach (Representation *representation, property.read(this).value< QList<Representation*> >()) | ||
| 779 | + representation->load(stream); | ||
| 780 | + } else if (type == "QList<br::Classifier*>") { | ||
| 781 | + foreach (Classifier *classifier, property.read(this).value< QList<Classifier*> >()) | ||
| 782 | + classifier->load(stream); | ||
| 753 | } else if (type == "br::Transform*") { | 783 | } else if (type == "br::Transform*") { |
| 754 | property.read(this).value<Transform*>()->load(stream); | 784 | property.read(this).value<Transform*>()->load(stream); |
| 755 | } else if (type == "br::Distance*") { | 785 | } else if (type == "br::Distance*") { |
| 756 | property.read(this).value<Distance*>()->load(stream); | 786 | property.read(this).value<Distance*>()->load(stream); |
| 787 | + } else if (type == "br::Representation*") { | ||
| 788 | + property.read(this).value<Representation*>()->load(stream); | ||
| 789 | + } else if (type == "br::Classifier*") { | ||
| 790 | + property.read(this).value<Classifier*>()->load(stream); | ||
| 757 | } else if (type == "bool") { | 791 | } else if (type == "bool") { |
| 758 | bool value; | 792 | bool value; |
| 759 | stream >> value; | 793 | stream >> value; |
| @@ -919,6 +953,18 @@ void Object::setProperty(const QString &name, QVariant value) | @@ -919,6 +953,18 @@ void Object::setProperty(const QString &name, QVariant value) | ||
| 919 | if (element.canConvert<QString>()) parsedValues.append(Distance::make(element.toString(), this)); | 953 | if (element.canConvert<QString>()) parsedValues.append(Distance::make(element.toString(), this)); |
| 920 | else parsedValues.append(element.value<Distance*>()); | 954 | else parsedValues.append(element.value<Distance*>()); |
| 921 | value.setValue(parsedValues); | 955 | value.setValue(parsedValues); |
| 956 | + } else if (type == "QList<br::Representation*>") { | ||
| 957 | + QList<Representation*> parsedValues; | ||
| 958 | + foreach (const QVariant &element, elements) | ||
| 959 | + if (element.canConvert<QString>()) parsedValues.append(Representation::make(element.toString(), this)); | ||
| 960 | + else parsedValues.append(element.value<Representation*>()); | ||
| 961 | + value.setValue(parsedValues); | ||
| 962 | + } else if (type == "QList<br::Classifier*>") { | ||
| 963 | + QList<Classifier*> parsedValues; | ||
| 964 | + foreach (const QVariant &element, elements) | ||
| 965 | + if (element.canConvert<QString>()) parsedValues.append(Classifier::make(element.toString(), this)); | ||
| 966 | + else parsedValues.append(element.value<Classifier*>()); | ||
| 967 | + value.setValue(parsedValues); | ||
| 922 | } else { | 968 | } else { |
| 923 | qFatal("Unrecognized type: %s", qPrintable(type)); | 969 | qFatal("Unrecognized type: %s", qPrintable(type)); |
| 924 | } | 970 | } |
| @@ -928,6 +974,12 @@ void Object::setProperty(const QString &name, QVariant value) | @@ -928,6 +974,12 @@ void Object::setProperty(const QString &name, QVariant value) | ||
| 928 | } else if (type == "br::Distance*") { | 974 | } else if (type == "br::Distance*") { |
| 929 | if (value.canConvert<QString>()) | 975 | if (value.canConvert<QString>()) |
| 930 | value.setValue(Distance::make(value.toString(), this)); | 976 | value.setValue(Distance::make(value.toString(), this)); |
| 977 | + } else if (type == "br::Representation*") { | ||
| 978 | + if (value.canConvert<QString>()) | ||
| 979 | + value.setValue(Representation::make(value.toString(), this)); | ||
| 980 | + } else if (type == "br::Classifier*") { | ||
| 981 | + if (value.canConvert<QString>()) | ||
| 982 | + value.setValue(Classifier::make(value.toString(), this)); | ||
| 931 | } else if (type == "bool") { | 983 | } else if (type == "bool") { |
| 932 | if (value.isNull()) value = true; | 984 | if (value.isNull()) value = true; |
| 933 | else if (value == "false") value = false; | 985 | else if (value == "false") value = false; |
| @@ -1086,10 +1138,14 @@ void br::Context::initialize(int &argc, char *argv[], QString sdkPath, bool useG | @@ -1086,10 +1138,14 @@ void br::Context::initialize(int &argc, char *argv[], QString sdkPath, bool useG | ||
| 1086 | qRegisterMetaType<br::TemplateList>(); | 1138 | qRegisterMetaType<br::TemplateList>(); |
| 1087 | qRegisterMetaType< br::Transform* >(); | 1139 | qRegisterMetaType< br::Transform* >(); |
| 1088 | qRegisterMetaType< br::Distance* >(); | 1140 | qRegisterMetaType< br::Distance* >(); |
| 1141 | + qRegisterMetaType< br::Representation* >(); | ||
| 1142 | + qRegisterMetaType< br::Classifier* >(); | ||
| 1089 | qRegisterMetaType< QList<int> >(); | 1143 | qRegisterMetaType< QList<int> >(); |
| 1090 | qRegisterMetaType< QList<float> >(); | 1144 | qRegisterMetaType< QList<float> >(); |
| 1091 | qRegisterMetaType< QList<br::Transform*> >(); | 1145 | qRegisterMetaType< QList<br::Transform*> >(); |
| 1092 | qRegisterMetaType< QList<br::Distance*> >(); | 1146 | qRegisterMetaType< QList<br::Distance*> >(); |
| 1147 | + qRegisterMetaType< QList<br::Representation* > >(); | ||
| 1148 | + qRegisterMetaType< QList<br::Classifier* > >(); | ||
| 1093 | qRegisterMetaType< QAbstractSocket::SocketState> (); | 1149 | qRegisterMetaType< QAbstractSocket::SocketState> (); |
| 1094 | qRegisterMetaType< QLocalSocket::LocalSocketState> (); | 1150 | qRegisterMetaType< QLocalSocket::LocalSocketState> (); |
| 1095 | 1151 | ||
| @@ -1196,6 +1252,15 @@ QStringList br::Context::objects(const char *abstractions, const char *implement | @@ -1196,6 +1252,15 @@ QStringList br::Context::objects(const char *abstractions, const char *implement | ||
| 1196 | if (implementationsRegExp.exactMatch(name)) | 1252 | if (implementationsRegExp.exactMatch(name)) |
| 1197 | objectList.append(name + (parameters ? "\t" + Factory<Transform>::parameters(name) : "")); | 1253 | objectList.append(name + (parameters ? "\t" + Factory<Transform>::parameters(name) : "")); |
| 1198 | 1254 | ||
| 1255 | + if (abstractionsRegExp.exactMatch("Representation")) | ||
| 1256 | + foreach (const QString &name, Factory<Representation>::names()) | ||
| 1257 | + if (implementationsRegExp.exactMatch(name)) | ||
| 1258 | + objectList.append(name + (parameters ? "\t" + Factory<Representation>::parameters(name) : "")); | ||
| 1259 | + | ||
| 1260 | + if (abstractionsRegExp.exactMatch("Classifier")) | ||
| 1261 | + foreach (const QString &name, Factory<Classifier>::names()) | ||
| 1262 | + if (implementationsRegExp.exactMatch(name)) | ||
| 1263 | + objectList.append(name + (parameters ? "\t" + Factory<Classifier>::parameters(name) : "")); | ||
| 1199 | 1264 | ||
| 1200 | return objectList; | 1265 | return objectList; |
| 1201 | } | 1266 | } |
| @@ -1607,3 +1672,29 @@ Transform *br::pipeTransforms(QList<Transform *> &transforms) | @@ -1607,3 +1672,29 @@ Transform *br::pipeTransforms(QList<Transform *> &transforms) | ||
| 1607 | res->setPropertyRecursive("transforms", QVariant::fromValue(transforms)); | 1672 | res->setPropertyRecursive("transforms", QVariant::fromValue(transforms)); |
| 1608 | return res; | 1673 | return res; |
| 1609 | } | 1674 | } |
| 1675 | + | ||
| 1676 | +Representation *Representation::make(QString str, QObject *parent) | ||
| 1677 | +{ | ||
| 1678 | + // Check for custom transforms | ||
| 1679 | + if (Globals->abbreviations.contains(str)) | ||
| 1680 | + return make(Globals->abbreviations[str], parent); | ||
| 1681 | + | ||
| 1682 | + File f = "." + str; | ||
| 1683 | + Representation *rep = Factory<Representation>::make(f); | ||
| 1684 | + | ||
| 1685 | + rep->setParent(parent); | ||
| 1686 | + return rep; | ||
| 1687 | +} | ||
| 1688 | + | ||
| 1689 | +Classifier *Classifier::make(QString str, QObject *parent) | ||
| 1690 | +{ | ||
| 1691 | + // Check for custom transforms | ||
| 1692 | + if (Globals->abbreviations.contains(str)) | ||
| 1693 | + return make(Globals->abbreviations[str], parent); | ||
| 1694 | + | ||
| 1695 | + File f = "." + str; | ||
| 1696 | + Classifier *classifier = Factory<Classifier>::make(f); | ||
| 1697 | + | ||
| 1698 | + classifier->setParent(parent); | ||
| 1699 | + return classifier; | ||
| 1700 | +} |
openbr/openbr_plugin.h
| @@ -1393,10 +1393,13 @@ class BR_EXPORT Representation : public Object | @@ -1393,10 +1393,13 @@ class BR_EXPORT Representation : public Object | ||
| 1393 | public: | 1393 | public: |
| 1394 | virtual ~Representation() {} | 1394 | virtual ~Representation() {} |
| 1395 | 1395 | ||
| 1396 | + static Representation *make(QString str, QObject *parent); /*!< \brief Make a representation from a string. */ | ||
| 1396 | virtual cv::Mat preprocess(const cv::Mat &image) const { return image; } | 1397 | virtual cv::Mat preprocess(const cv::Mat &image) const { return image; } |
| 1398 | + virtual void train(const QList<cv::Mat> &images, const QList<float> &labels) { (void) images; (void)labels; } | ||
| 1397 | // By convention, an empty indices list will result in all feature responses being calculated | 1399 | // By convention, an empty indices list will result in all feature responses being calculated |
| 1398 | // and returned. | 1400 | // and returned. |
| 1399 | virtual cv::Mat evaluate(const cv::Mat &image, const QList<int> &indices = QList<int>()) const = 0; | 1401 | virtual cv::Mat evaluate(const cv::Mat &image, const QList<int> &indices = QList<int>()) const = 0; |
| 1402 | + virtual int numFeatures() const = 0; | ||
| 1400 | }; | 1403 | }; |
| 1401 | 1404 | ||
| 1402 | class BR_EXPORT Classifier : public Object | 1405 | class BR_EXPORT Classifier : public Object |
| @@ -1406,7 +1409,10 @@ class BR_EXPORT Classifier : public Object | @@ -1406,7 +1409,10 @@ class BR_EXPORT Classifier : public Object | ||
| 1406 | public: | 1409 | public: |
| 1407 | virtual ~Classifier() {} | 1410 | virtual ~Classifier() {} |
| 1408 | 1411 | ||
| 1412 | + static Classifier *make(QString str, QObject *parent); /*!< \brief Make a classifier from a string. */ | ||
| 1409 | virtual void train(const QList<cv::Mat> &images, const QList<float> &labels) = 0; | 1413 | virtual void train(const QList<cv::Mat> &images, const QList<float> &labels) = 0; |
| 1414 | + // By convention, classify should return a value normalized such that the threshold is 0. Negative values | ||
| 1415 | + // can be interpreted as a negative classification and positive values as a positive classification. | ||
| 1410 | virtual float classify(const cv::Mat &image) const = 0; | 1416 | virtual float classify(const cv::Mat &image) const = 0; |
| 1411 | }; | 1417 | }; |
| 1412 | 1418 | ||
| @@ -1497,10 +1503,14 @@ Q_DECLARE_METATYPE(br::Template) | @@ -1497,10 +1503,14 @@ Q_DECLARE_METATYPE(br::Template) | ||
| 1497 | Q_DECLARE_METATYPE(br::TemplateList) | 1503 | Q_DECLARE_METATYPE(br::TemplateList) |
| 1498 | Q_DECLARE_METATYPE(br::Transform*) | 1504 | Q_DECLARE_METATYPE(br::Transform*) |
| 1499 | Q_DECLARE_METATYPE(br::Distance*) | 1505 | Q_DECLARE_METATYPE(br::Distance*) |
| 1506 | +Q_DECLARE_METATYPE(br::Representation*) | ||
| 1507 | +Q_DECLARE_METATYPE(br::Classifier*) | ||
| 1500 | Q_DECLARE_METATYPE(QList<int>) | 1508 | Q_DECLARE_METATYPE(QList<int>) |
| 1501 | Q_DECLARE_METATYPE(QList<float>) | 1509 | Q_DECLARE_METATYPE(QList<float>) |
| 1502 | Q_DECLARE_METATYPE(QList<br::Transform*>) | 1510 | Q_DECLARE_METATYPE(QList<br::Transform*>) |
| 1503 | Q_DECLARE_METATYPE(QList<br::Distance*>) | 1511 | Q_DECLARE_METATYPE(QList<br::Distance*>) |
| 1512 | +Q_DECLARE_METATYPE(QList<br::Representation*>) | ||
| 1513 | +Q_DECLARE_METATYPE(QList<br::Classifier*>) | ||
| 1504 | 1514 | ||
| 1505 | #endif // __cplusplus | 1515 | #endif // __cplusplus |
| 1506 | 1516 |