Commit ccee88d8600b615a8e94bdc30fba042bf81fcaf7

Authored by JordanCheney
2 parents 9fdb4ced fe5bc4ab

Merge pull request #303 from biometrics/classification_api

Update Classification API to support OpenBR object handling
openbr/openbr_plugin.cpp
@@ -626,7 +626,11 @@ QStringList Object::prunedArguments(bool expanded) const @@ -626,7 +626,11 @@ QStringList Object::prunedArguments(bool expanded) const
626 if (className.endsWith(interfaceName)) 626 if (className.endsWith(interfaceName))
627 className.chop(interfaceName.size()); 627 className.chop(interfaceName.size());
628 628
629 - if (interfaceName == "Distance") 629 + if (interfaceName == "Representation")
  630 + shellObject.reset(Factory<Representation>::make(className));
  631 + else if (interfaceName == "Classifier")
  632 + shellObject.reset(Factory<Classifier>::make(className));
  633 + else if (interfaceName == "Distance")
630 shellObject.reset(Factory<Distance>::make(className)); 634 shellObject.reset(Factory<Distance>::make(className));
631 else if (interfaceName == "Transform") 635 else if (interfaceName == "Transform")
632 shellObject.reset(Factory<Transform>::make(className)); 636 shellObject.reset(Factory<Transform>::make(className));
@@ -673,6 +677,12 @@ QString Object::argument(int index, bool expanded) const @@ -673,6 +677,12 @@ QString Object::argument(int index, bool expanded) const
673 } else if (type == "QList<br::Distance*>") { 677 } else if (type == "QList<br::Distance*>") {
674 foreach (Distance *distance, variant.value< QList<Distance*> >()) 678 foreach (Distance *distance, variant.value< QList<Distance*> >())
675 strings.append(distance->description(expanded)); 679 strings.append(distance->description(expanded));
  680 + } else if (type == "QList<br::Representation*>") {
  681 + foreach (Representation *representation, variant.value< QList<Representation*> >())
  682 + strings.append(representation->description(expanded));
  683 + } else if (type == "QList<br::Classifier*>") {
  684 + foreach (Classifier *classifier, variant.value< QList<Classifier*> >())
  685 + strings.append(classifier->description(expanded));
676 } else { 686 } else {
677 qFatal("Unrecognized type: %s", qPrintable(type)); 687 qFatal("Unrecognized type: %s", qPrintable(type));
678 } 688 }
@@ -682,6 +692,10 @@ QString Object::argument(int index, bool expanded) const @@ -682,6 +692,10 @@ QString Object::argument(int index, bool expanded) const
682 return variant.value<Transform*>()->description(expanded); 692 return variant.value<Transform*>()->description(expanded);
683 } else if (type == "br::Distance*") { 693 } else if (type == "br::Distance*") {
684 return variant.value<Distance*>()->description(expanded); 694 return variant.value<Distance*>()->description(expanded);
  695 + } else if (type == "br::Representation*") {
  696 + return variant.value<Representation*>()->description(expanded);
  697 + } else if (type == "br::Classifier*") {
  698 + return variant.value<Classifier*>()->description(expanded);
685 } else if (type == "QStringList") { 699 } else if (type == "QStringList") {
686 return "[" + variant.toStringList().join(",") + "]"; 700 return "[" + variant.toStringList().join(",") + "]";
687 } 701 }
@@ -713,10 +727,20 @@ void Object::store(QDataStream &amp;stream) const @@ -713,10 +727,20 @@ void Object::store(QDataStream &amp;stream) const
713 } else if (type == "QList<br::Distance*>") { 727 } else if (type == "QList<br::Distance*>") {
714 foreach (Distance *distance, property.read(this).value< QList<Distance*> >()) 728 foreach (Distance *distance, property.read(this).value< QList<Distance*> >())
715 distance->store(stream); 729 distance->store(stream);
  730 + } else if (type == "QList<br::Representation*>") {
  731 + foreach (Representation *representation, property.read(this).value< QList<Representation*> >())
  732 + representation->store(stream);
  733 + } else if (type == "QList<br::Classifier*>") {
  734 + foreach (Classifier *classifier, property.read(this).value< QList<Classifier*> >())
  735 + classifier->store(stream);
716 } else if (type == "br::Transform*") { 736 } else if (type == "br::Transform*") {
717 property.read(this).value<Transform*>()->store(stream); 737 property.read(this).value<Transform*>()->store(stream);
718 } else if (type == "br::Distance*") { 738 } else if (type == "br::Distance*") {
719 property.read(this).value<Distance*>()->store(stream); 739 property.read(this).value<Distance*>()->store(stream);
  740 + } else if (type == "br::Representation*") {
  741 + property.read(this).value<Representation*>()->store(stream);
  742 + } else if (type == "br::Classifier*") {
  743 + property.read(this).value<Classifier*>()->store(stream);
720 } else if (type == "bool") { 744 } else if (type == "bool") {
721 stream << property.read(this).toBool(); 745 stream << property.read(this).toBool();
722 } else if (type == "int") { 746 } else if (type == "int") {
@@ -750,10 +774,20 @@ void Object::load(QDataStream &amp;stream) @@ -750,10 +774,20 @@ void Object::load(QDataStream &amp;stream)
750 } else if (type == "QList<br::Distance*>") { 774 } else if (type == "QList<br::Distance*>") {
751 foreach (Distance *distance, property.read(this).value< QList<Distance*> >()) 775 foreach (Distance *distance, property.read(this).value< QList<Distance*> >())
752 distance->load(stream); 776 distance->load(stream);
  777 + } else if (type == "QList<br::Representation*>") {
  778 + foreach (Representation *representation, property.read(this).value< QList<Representation*> >())
  779 + representation->load(stream);
  780 + } else if (type == "QList<br::Classifier*>") {
  781 + foreach (Classifier *classifier, property.read(this).value< QList<Classifier*> >())
  782 + classifier->load(stream);
753 } else if (type == "br::Transform*") { 783 } else if (type == "br::Transform*") {
754 property.read(this).value<Transform*>()->load(stream); 784 property.read(this).value<Transform*>()->load(stream);
755 } else if (type == "br::Distance*") { 785 } else if (type == "br::Distance*") {
756 property.read(this).value<Distance*>()->load(stream); 786 property.read(this).value<Distance*>()->load(stream);
  787 + } else if (type == "br::Representation*") {
  788 + property.read(this).value<Representation*>()->load(stream);
  789 + } else if (type == "br::Classifier*") {
  790 + property.read(this).value<Classifier*>()->load(stream);
757 } else if (type == "bool") { 791 } else if (type == "bool") {
758 bool value; 792 bool value;
759 stream >> value; 793 stream >> value;
@@ -919,6 +953,18 @@ void Object::setProperty(const QString &amp;name, QVariant value) @@ -919,6 +953,18 @@ void Object::setProperty(const QString &amp;name, QVariant value)
919 if (element.canConvert<QString>()) parsedValues.append(Distance::make(element.toString(), this)); 953 if (element.canConvert<QString>()) parsedValues.append(Distance::make(element.toString(), this));
920 else parsedValues.append(element.value<Distance*>()); 954 else parsedValues.append(element.value<Distance*>());
921 value.setValue(parsedValues); 955 value.setValue(parsedValues);
  956 + } else if (type == "QList<br::Representation*>") {
  957 + QList<Representation*> parsedValues;
  958 + foreach (const QVariant &element, elements)
  959 + if (element.canConvert<QString>()) parsedValues.append(Representation::make(element.toString(), this));
  960 + else parsedValues.append(element.value<Representation*>());
  961 + value.setValue(parsedValues);
  962 + } else if (type == "QList<br::Classifier*>") {
  963 + QList<Classifier*> parsedValues;
  964 + foreach (const QVariant &element, elements)
  965 + if (element.canConvert<QString>()) parsedValues.append(Classifier::make(element.toString(), this));
  966 + else parsedValues.append(element.value<Classifier*>());
  967 + value.setValue(parsedValues);
922 } else { 968 } else {
923 qFatal("Unrecognized type: %s", qPrintable(type)); 969 qFatal("Unrecognized type: %s", qPrintable(type));
924 } 970 }
@@ -928,6 +974,12 @@ void Object::setProperty(const QString &amp;name, QVariant value) @@ -928,6 +974,12 @@ void Object::setProperty(const QString &amp;name, QVariant value)
928 } else if (type == "br::Distance*") { 974 } else if (type == "br::Distance*") {
929 if (value.canConvert<QString>()) 975 if (value.canConvert<QString>())
930 value.setValue(Distance::make(value.toString(), this)); 976 value.setValue(Distance::make(value.toString(), this));
  977 + } else if (type == "br::Representation*") {
  978 + if (value.canConvert<QString>())
  979 + value.setValue(Representation::make(value.toString(), this));
  980 + } else if (type == "br::Classifier*") {
  981 + if (value.canConvert<QString>())
  982 + value.setValue(Classifier::make(value.toString(), this));
931 } else if (type == "bool") { 983 } else if (type == "bool") {
932 if (value.isNull()) value = true; 984 if (value.isNull()) value = true;
933 else if (value == "false") value = false; 985 else if (value == "false") value = false;
@@ -1086,10 +1138,14 @@ void br::Context::initialize(int &amp;argc, char *argv[], QString sdkPath, bool useG @@ -1086,10 +1138,14 @@ void br::Context::initialize(int &amp;argc, char *argv[], QString sdkPath, bool useG
1086 qRegisterMetaType<br::TemplateList>(); 1138 qRegisterMetaType<br::TemplateList>();
1087 qRegisterMetaType< br::Transform* >(); 1139 qRegisterMetaType< br::Transform* >();
1088 qRegisterMetaType< br::Distance* >(); 1140 qRegisterMetaType< br::Distance* >();
  1141 + qRegisterMetaType< br::Representation* >();
  1142 + qRegisterMetaType< br::Classifier* >();
1089 qRegisterMetaType< QList<int> >(); 1143 qRegisterMetaType< QList<int> >();
1090 qRegisterMetaType< QList<float> >(); 1144 qRegisterMetaType< QList<float> >();
1091 qRegisterMetaType< QList<br::Transform*> >(); 1145 qRegisterMetaType< QList<br::Transform*> >();
1092 qRegisterMetaType< QList<br::Distance*> >(); 1146 qRegisterMetaType< QList<br::Distance*> >();
  1147 + qRegisterMetaType< QList<br::Representation* > >();
  1148 + qRegisterMetaType< QList<br::Classifier* > >();
1093 qRegisterMetaType< QAbstractSocket::SocketState> (); 1149 qRegisterMetaType< QAbstractSocket::SocketState> ();
1094 qRegisterMetaType< QLocalSocket::LocalSocketState> (); 1150 qRegisterMetaType< QLocalSocket::LocalSocketState> ();
1095 1151
@@ -1196,6 +1252,15 @@ QStringList br::Context::objects(const char *abstractions, const char *implement @@ -1196,6 +1252,15 @@ QStringList br::Context::objects(const char *abstractions, const char *implement
1196 if (implementationsRegExp.exactMatch(name)) 1252 if (implementationsRegExp.exactMatch(name))
1197 objectList.append(name + (parameters ? "\t" + Factory<Transform>::parameters(name) : "")); 1253 objectList.append(name + (parameters ? "\t" + Factory<Transform>::parameters(name) : ""));
1198 1254
  1255 + if (abstractionsRegExp.exactMatch("Representation"))
  1256 + foreach (const QString &name, Factory<Representation>::names())
  1257 + if (implementationsRegExp.exactMatch(name))
  1258 + objectList.append(name + (parameters ? "\t" + Factory<Representation>::parameters(name) : ""));
  1259 +
  1260 + if (abstractionsRegExp.exactMatch("Classifier"))
  1261 + foreach (const QString &name, Factory<Classifier>::names())
  1262 + if (implementationsRegExp.exactMatch(name))
  1263 + objectList.append(name + (parameters ? "\t" + Factory<Classifier>::parameters(name) : ""));
1199 1264
1200 return objectList; 1265 return objectList;
1201 } 1266 }
@@ -1607,3 +1672,29 @@ Transform *br::pipeTransforms(QList&lt;Transform *&gt; &amp;transforms) @@ -1607,3 +1672,29 @@ Transform *br::pipeTransforms(QList&lt;Transform *&gt; &amp;transforms)
1607 res->setPropertyRecursive("transforms", QVariant::fromValue(transforms)); 1672 res->setPropertyRecursive("transforms", QVariant::fromValue(transforms));
1608 return res; 1673 return res;
1609 } 1674 }
  1675 +
  1676 +Representation *Representation::make(QString str, QObject *parent)
  1677 +{
  1678 + // Check for custom transforms
  1679 + if (Globals->abbreviations.contains(str))
  1680 + return make(Globals->abbreviations[str], parent);
  1681 +
  1682 + File f = "." + str;
  1683 + Representation *rep = Factory<Representation>::make(f);
  1684 +
  1685 + rep->setParent(parent);
  1686 + return rep;
  1687 +}
  1688 +
  1689 +Classifier *Classifier::make(QString str, QObject *parent)
  1690 +{
  1691 + // Check for custom transforms
  1692 + if (Globals->abbreviations.contains(str))
  1693 + return make(Globals->abbreviations[str], parent);
  1694 +
  1695 + File f = "." + str;
  1696 + Classifier *classifier = Factory<Classifier>::make(f);
  1697 +
  1698 + classifier->setParent(parent);
  1699 + return classifier;
  1700 +}
openbr/openbr_plugin.h
@@ -1393,10 +1393,13 @@ class BR_EXPORT Representation : public Object @@ -1393,10 +1393,13 @@ class BR_EXPORT Representation : public Object
1393 public: 1393 public:
1394 virtual ~Representation() {} 1394 virtual ~Representation() {}
1395 1395
  1396 + static Representation *make(QString str, QObject *parent); /*!< \brief Make a representation from a string. */
1396 virtual cv::Mat preprocess(const cv::Mat &image) const { return image; } 1397 virtual cv::Mat preprocess(const cv::Mat &image) const { return image; }
  1398 + virtual void train(const QList<cv::Mat> &images, const QList<float> &labels) { (void) images; (void)labels; }
1397 // By convention, an empty indices list will result in all feature responses being calculated 1399 // By convention, an empty indices list will result in all feature responses being calculated
1398 // and returned. 1400 // and returned.
1399 virtual cv::Mat evaluate(const cv::Mat &image, const QList<int> &indices = QList<int>()) const = 0; 1401 virtual cv::Mat evaluate(const cv::Mat &image, const QList<int> &indices = QList<int>()) const = 0;
  1402 + virtual int numFeatures() const = 0;
1400 }; 1403 };
1401 1404
1402 class BR_EXPORT Classifier : public Object 1405 class BR_EXPORT Classifier : public Object
@@ -1406,7 +1409,10 @@ class BR_EXPORT Classifier : public Object @@ -1406,7 +1409,10 @@ class BR_EXPORT Classifier : public Object
1406 public: 1409 public:
1407 virtual ~Classifier() {} 1410 virtual ~Classifier() {}
1408 1411
  1412 + static Classifier *make(QString str, QObject *parent); /*!< \brief Make a classifier from a string. */
1409 virtual void train(const QList<cv::Mat> &images, const QList<float> &labels) = 0; 1413 virtual void train(const QList<cv::Mat> &images, const QList<float> &labels) = 0;
  1414 + // By convention, classify should return a value normalized such that the threshold is 0. Negative values
  1415 + // can be interpreted as a negative classification and positive values as a positive classification.
1410 virtual float classify(const cv::Mat &image) const = 0; 1416 virtual float classify(const cv::Mat &image) const = 0;
1411 }; 1417 };
1412 1418
@@ -1497,10 +1503,14 @@ Q_DECLARE_METATYPE(br::Template) @@ -1497,10 +1503,14 @@ Q_DECLARE_METATYPE(br::Template)
1497 Q_DECLARE_METATYPE(br::TemplateList) 1503 Q_DECLARE_METATYPE(br::TemplateList)
1498 Q_DECLARE_METATYPE(br::Transform*) 1504 Q_DECLARE_METATYPE(br::Transform*)
1499 Q_DECLARE_METATYPE(br::Distance*) 1505 Q_DECLARE_METATYPE(br::Distance*)
  1506 +Q_DECLARE_METATYPE(br::Representation*)
  1507 +Q_DECLARE_METATYPE(br::Classifier*)
1500 Q_DECLARE_METATYPE(QList<int>) 1508 Q_DECLARE_METATYPE(QList<int>)
1501 Q_DECLARE_METATYPE(QList<float>) 1509 Q_DECLARE_METATYPE(QList<float>)
1502 Q_DECLARE_METATYPE(QList<br::Transform*>) 1510 Q_DECLARE_METATYPE(QList<br::Transform*>)
1503 Q_DECLARE_METATYPE(QList<br::Distance*>) 1511 Q_DECLARE_METATYPE(QList<br::Distance*>)
  1512 +Q_DECLARE_METATYPE(QList<br::Representation*>)
  1513 +Q_DECLARE_METATYPE(QList<br::Classifier*>)
1504 1514
1505 #endif // __cplusplus 1515 #endif // __cplusplus
1506 1516