Commit ccee88d8600b615a8e94bdc30fba042bf81fcaf7
Merge pull request #303 from biometrics/classification_api
Update Classification API to support OpenBR object handling
Showing
2 changed files
with
102 additions
and
1 deletions
openbr/openbr_plugin.cpp
| ... | ... | @@ -626,7 +626,11 @@ QStringList Object::prunedArguments(bool expanded) const |
| 626 | 626 | if (className.endsWith(interfaceName)) |
| 627 | 627 | className.chop(interfaceName.size()); |
| 628 | 628 | |
| 629 | - if (interfaceName == "Distance") | |
| 629 | + if (interfaceName == "Representation") | |
| 630 | + shellObject.reset(Factory<Representation>::make(className)); | |
| 631 | + else if (interfaceName == "Classifier") | |
| 632 | + shellObject.reset(Factory<Classifier>::make(className)); | |
| 633 | + else if (interfaceName == "Distance") | |
| 630 | 634 | shellObject.reset(Factory<Distance>::make(className)); |
| 631 | 635 | else if (interfaceName == "Transform") |
| 632 | 636 | shellObject.reset(Factory<Transform>::make(className)); |
| ... | ... | @@ -673,6 +677,12 @@ QString Object::argument(int index, bool expanded) const |
| 673 | 677 | } else if (type == "QList<br::Distance*>") { |
| 674 | 678 | foreach (Distance *distance, variant.value< QList<Distance*> >()) |
| 675 | 679 | strings.append(distance->description(expanded)); |
| 680 | + } else if (type == "QList<br::Representation*>") { | |
| 681 | + foreach (Representation *representation, variant.value< QList<Representation*> >()) | |
| 682 | + strings.append(representation->description(expanded)); | |
| 683 | + } else if (type == "QList<br::Classifier*>") { | |
| 684 | + foreach (Classifier *classifier, variant.value< QList<Classifier*> >()) | |
| 685 | + strings.append(classifier->description(expanded)); | |
| 676 | 686 | } else { |
| 677 | 687 | qFatal("Unrecognized type: %s", qPrintable(type)); |
| 678 | 688 | } |
| ... | ... | @@ -682,6 +692,10 @@ QString Object::argument(int index, bool expanded) const |
| 682 | 692 | return variant.value<Transform*>()->description(expanded); |
| 683 | 693 | } else if (type == "br::Distance*") { |
| 684 | 694 | return variant.value<Distance*>()->description(expanded); |
| 695 | + } else if (type == "br::Representation*") { | |
| 696 | + return variant.value<Representation*>()->description(expanded); | |
| 697 | + } else if (type == "br::Classifier*") { | |
| 698 | + return variant.value<Classifier*>()->description(expanded); | |
| 685 | 699 | } else if (type == "QStringList") { |
| 686 | 700 | return "[" + variant.toStringList().join(",") + "]"; |
| 687 | 701 | } |
| ... | ... | @@ -713,10 +727,20 @@ void Object::store(QDataStream &stream) const |
| 713 | 727 | } else if (type == "QList<br::Distance*>") { |
| 714 | 728 | foreach (Distance *distance, property.read(this).value< QList<Distance*> >()) |
| 715 | 729 | distance->store(stream); |
| 730 | + } else if (type == "QList<br::Representation*>") { | |
| 731 | + foreach (Representation *representation, property.read(this).value< QList<Representation*> >()) | |
| 732 | + representation->store(stream); | |
| 733 | + } else if (type == "QList<br::Classifier*>") { | |
| 734 | + foreach (Classifier *classifier, property.read(this).value< QList<Classifier*> >()) | |
| 735 | + classifier->store(stream); | |
| 716 | 736 | } else if (type == "br::Transform*") { |
| 717 | 737 | property.read(this).value<Transform*>()->store(stream); |
| 718 | 738 | } else if (type == "br::Distance*") { |
| 719 | 739 | property.read(this).value<Distance*>()->store(stream); |
| 740 | + } else if (type == "br::Representation*") { | |
| 741 | + property.read(this).value<Representation*>()->store(stream); | |
| 742 | + } else if (type == "br::Classifier*") { | |
| 743 | + property.read(this).value<Classifier*>()->store(stream); | |
| 720 | 744 | } else if (type == "bool") { |
| 721 | 745 | stream << property.read(this).toBool(); |
| 722 | 746 | } else if (type == "int") { |
| ... | ... | @@ -750,10 +774,20 @@ void Object::load(QDataStream &stream) |
| 750 | 774 | } else if (type == "QList<br::Distance*>") { |
| 751 | 775 | foreach (Distance *distance, property.read(this).value< QList<Distance*> >()) |
| 752 | 776 | distance->load(stream); |
| 777 | + } else if (type == "QList<br::Representation*>") { | |
| 778 | + foreach (Representation *representation, property.read(this).value< QList<Representation*> >()) | |
| 779 | + representation->load(stream); | |
| 780 | + } else if (type == "QList<br::Classifier*>") { | |
| 781 | + foreach (Classifier *classifier, property.read(this).value< QList<Classifier*> >()) | |
| 782 | + classifier->load(stream); | |
| 753 | 783 | } else if (type == "br::Transform*") { |
| 754 | 784 | property.read(this).value<Transform*>()->load(stream); |
| 755 | 785 | } else if (type == "br::Distance*") { |
| 756 | 786 | property.read(this).value<Distance*>()->load(stream); |
| 787 | + } else if (type == "br::Representation*") { | |
| 788 | + property.read(this).value<Representation*>()->load(stream); | |
| 789 | + } else if (type == "br::Classifier*") { | |
| 790 | + property.read(this).value<Classifier*>()->load(stream); | |
| 757 | 791 | } else if (type == "bool") { |
| 758 | 792 | bool value; |
| 759 | 793 | stream >> value; |
| ... | ... | @@ -919,6 +953,18 @@ void Object::setProperty(const QString &name, QVariant value) |
| 919 | 953 | if (element.canConvert<QString>()) parsedValues.append(Distance::make(element.toString(), this)); |
| 920 | 954 | else parsedValues.append(element.value<Distance*>()); |
| 921 | 955 | value.setValue(parsedValues); |
| 956 | + } else if (type == "QList<br::Representation*>") { | |
| 957 | + QList<Representation*> parsedValues; | |
| 958 | + foreach (const QVariant &element, elements) | |
| 959 | + if (element.canConvert<QString>()) parsedValues.append(Representation::make(element.toString(), this)); | |
| 960 | + else parsedValues.append(element.value<Representation*>()); | |
| 961 | + value.setValue(parsedValues); | |
| 962 | + } else if (type == "QList<br::Classifier*>") { | |
| 963 | + QList<Classifier*> parsedValues; | |
| 964 | + foreach (const QVariant &element, elements) | |
| 965 | + if (element.canConvert<QString>()) parsedValues.append(Classifier::make(element.toString(), this)); | |
| 966 | + else parsedValues.append(element.value<Classifier*>()); | |
| 967 | + value.setValue(parsedValues); | |
| 922 | 968 | } else { |
| 923 | 969 | qFatal("Unrecognized type: %s", qPrintable(type)); |
| 924 | 970 | } |
| ... | ... | @@ -928,6 +974,12 @@ void Object::setProperty(const QString &name, QVariant value) |
| 928 | 974 | } else if (type == "br::Distance*") { |
| 929 | 975 | if (value.canConvert<QString>()) |
| 930 | 976 | value.setValue(Distance::make(value.toString(), this)); |
| 977 | + } else if (type == "br::Representation*") { | |
| 978 | + if (value.canConvert<QString>()) | |
| 979 | + value.setValue(Representation::make(value.toString(), this)); | |
| 980 | + } else if (type == "br::Classifier*") { | |
| 981 | + if (value.canConvert<QString>()) | |
| 982 | + value.setValue(Classifier::make(value.toString(), this)); | |
| 931 | 983 | } else if (type == "bool") { |
| 932 | 984 | if (value.isNull()) value = true; |
| 933 | 985 | else if (value == "false") value = false; |
| ... | ... | @@ -1086,10 +1138,14 @@ void br::Context::initialize(int &argc, char *argv[], QString sdkPath, bool useG |
| 1086 | 1138 | qRegisterMetaType<br::TemplateList>(); |
| 1087 | 1139 | qRegisterMetaType< br::Transform* >(); |
| 1088 | 1140 | qRegisterMetaType< br::Distance* >(); |
| 1141 | + qRegisterMetaType< br::Representation* >(); | |
| 1142 | + qRegisterMetaType< br::Classifier* >(); | |
| 1089 | 1143 | qRegisterMetaType< QList<int> >(); |
| 1090 | 1144 | qRegisterMetaType< QList<float> >(); |
| 1091 | 1145 | qRegisterMetaType< QList<br::Transform*> >(); |
| 1092 | 1146 | qRegisterMetaType< QList<br::Distance*> >(); |
| 1147 | + qRegisterMetaType< QList<br::Representation* > >(); | |
| 1148 | + qRegisterMetaType< QList<br::Classifier* > >(); | |
| 1093 | 1149 | qRegisterMetaType< QAbstractSocket::SocketState> (); |
| 1094 | 1150 | qRegisterMetaType< QLocalSocket::LocalSocketState> (); |
| 1095 | 1151 | |
| ... | ... | @@ -1196,6 +1252,15 @@ QStringList br::Context::objects(const char *abstractions, const char *implement |
| 1196 | 1252 | if (implementationsRegExp.exactMatch(name)) |
| 1197 | 1253 | objectList.append(name + (parameters ? "\t" + Factory<Transform>::parameters(name) : "")); |
| 1198 | 1254 | |
| 1255 | + if (abstractionsRegExp.exactMatch("Representation")) | |
| 1256 | + foreach (const QString &name, Factory<Representation>::names()) | |
| 1257 | + if (implementationsRegExp.exactMatch(name)) | |
| 1258 | + objectList.append(name + (parameters ? "\t" + Factory<Representation>::parameters(name) : "")); | |
| 1259 | + | |
| 1260 | + if (abstractionsRegExp.exactMatch("Classifier")) | |
| 1261 | + foreach (const QString &name, Factory<Classifier>::names()) | |
| 1262 | + if (implementationsRegExp.exactMatch(name)) | |
| 1263 | + objectList.append(name + (parameters ? "\t" + Factory<Classifier>::parameters(name) : "")); | |
| 1199 | 1264 | |
| 1200 | 1265 | return objectList; |
| 1201 | 1266 | } |
| ... | ... | @@ -1607,3 +1672,29 @@ Transform *br::pipeTransforms(QList<Transform *> &transforms) |
| 1607 | 1672 | res->setPropertyRecursive("transforms", QVariant::fromValue(transforms)); |
| 1608 | 1673 | return res; |
| 1609 | 1674 | } |
| 1675 | + | |
| 1676 | +Representation *Representation::make(QString str, QObject *parent) | |
| 1677 | +{ | |
| 1678 | + // Check for custom transforms | |
| 1679 | + if (Globals->abbreviations.contains(str)) | |
| 1680 | + return make(Globals->abbreviations[str], parent); | |
| 1681 | + | |
| 1682 | + File f = "." + str; | |
| 1683 | + Representation *rep = Factory<Representation>::make(f); | |
| 1684 | + | |
| 1685 | + rep->setParent(parent); | |
| 1686 | + return rep; | |
| 1687 | +} | |
| 1688 | + | |
| 1689 | +Classifier *Classifier::make(QString str, QObject *parent) | |
| 1690 | +{ | |
| 1691 | + // Check for custom transforms | |
| 1692 | + if (Globals->abbreviations.contains(str)) | |
| 1693 | + return make(Globals->abbreviations[str], parent); | |
| 1694 | + | |
| 1695 | + File f = "." + str; | |
| 1696 | + Classifier *classifier = Factory<Classifier>::make(f); | |
| 1697 | + | |
| 1698 | + classifier->setParent(parent); | |
| 1699 | + return classifier; | |
| 1700 | +} | ... | ... |
openbr/openbr_plugin.h
| ... | ... | @@ -1393,10 +1393,13 @@ class BR_EXPORT Representation : public Object |
| 1393 | 1393 | public: |
| 1394 | 1394 | virtual ~Representation() {} |
| 1395 | 1395 | |
| 1396 | + static Representation *make(QString str, QObject *parent); /*!< \brief Make a representation from a string. */ | |
| 1396 | 1397 | virtual cv::Mat preprocess(const cv::Mat &image) const { return image; } |
| 1398 | + virtual void train(const QList<cv::Mat> &images, const QList<float> &labels) { (void) images; (void)labels; } | |
| 1397 | 1399 | // By convention, an empty indices list will result in all feature responses being calculated |
| 1398 | 1400 | // and returned. |
| 1399 | 1401 | virtual cv::Mat evaluate(const cv::Mat &image, const QList<int> &indices = QList<int>()) const = 0; |
| 1402 | + virtual int numFeatures() const = 0; | |
| 1400 | 1403 | }; |
| 1401 | 1404 | |
| 1402 | 1405 | class BR_EXPORT Classifier : public Object |
| ... | ... | @@ -1406,7 +1409,10 @@ class BR_EXPORT Classifier : public Object |
| 1406 | 1409 | public: |
| 1407 | 1410 | virtual ~Classifier() {} |
| 1408 | 1411 | |
| 1412 | + static Classifier *make(QString str, QObject *parent); /*!< \brief Make a classifier from a string. */ | |
| 1409 | 1413 | virtual void train(const QList<cv::Mat> &images, const QList<float> &labels) = 0; |
| 1414 | + // By convention, classify should return a value normalized such that the threshold is 0. Negative values | |
| 1415 | + // can be interpreted as a negative classification and positive values as a positive classification. | |
| 1410 | 1416 | virtual float classify(const cv::Mat &image) const = 0; |
| 1411 | 1417 | }; |
| 1412 | 1418 | |
| ... | ... | @@ -1497,10 +1503,14 @@ Q_DECLARE_METATYPE(br::Template) |
| 1497 | 1503 | Q_DECLARE_METATYPE(br::TemplateList) |
| 1498 | 1504 | Q_DECLARE_METATYPE(br::Transform*) |
| 1499 | 1505 | Q_DECLARE_METATYPE(br::Distance*) |
| 1506 | +Q_DECLARE_METATYPE(br::Representation*) | |
| 1507 | +Q_DECLARE_METATYPE(br::Classifier*) | |
| 1500 | 1508 | Q_DECLARE_METATYPE(QList<int>) |
| 1501 | 1509 | Q_DECLARE_METATYPE(QList<float>) |
| 1502 | 1510 | Q_DECLARE_METATYPE(QList<br::Transform*>) |
| 1503 | 1511 | Q_DECLARE_METATYPE(QList<br::Distance*>) |
| 1512 | +Q_DECLARE_METATYPE(QList<br::Representation*>) | |
| 1513 | +Q_DECLARE_METATYPE(QList<br::Classifier*>) | |
| 1504 | 1514 | |
| 1505 | 1515 | #endif // __cplusplus |
| 1506 | 1516 | ... | ... |