Commit ccee88d8600b615a8e94bdc30fba042bf81fcaf7

Authored by JordanCheney
2 parents 9fdb4ced fe5bc4ab

Merge pull request #303 from biometrics/classification_api

Update Classification API to support OpenBR object handling
openbr/openbr_plugin.cpp
... ... @@ -626,7 +626,11 @@ QStringList Object::prunedArguments(bool expanded) const
626 626 if (className.endsWith(interfaceName))
627 627 className.chop(interfaceName.size());
628 628  
629   - if (interfaceName == "Distance")
  629 + if (interfaceName == "Representation")
  630 + shellObject.reset(Factory<Representation>::make(className));
  631 + else if (interfaceName == "Classifier")
  632 + shellObject.reset(Factory<Classifier>::make(className));
  633 + else if (interfaceName == "Distance")
630 634 shellObject.reset(Factory<Distance>::make(className));
631 635 else if (interfaceName == "Transform")
632 636 shellObject.reset(Factory<Transform>::make(className));
... ... @@ -673,6 +677,12 @@ QString Object::argument(int index, bool expanded) const
673 677 } else if (type == "QList<br::Distance*>") {
674 678 foreach (Distance *distance, variant.value< QList<Distance*> >())
675 679 strings.append(distance->description(expanded));
  680 + } else if (type == "QList<br::Representation*>") {
  681 + foreach (Representation *representation, variant.value< QList<Representation*> >())
  682 + strings.append(representation->description(expanded));
  683 + } else if (type == "QList<br::Classifier*>") {
  684 + foreach (Classifier *classifier, variant.value< QList<Classifier*> >())
  685 + strings.append(classifier->description(expanded));
676 686 } else {
677 687 qFatal("Unrecognized type: %s", qPrintable(type));
678 688 }
... ... @@ -682,6 +692,10 @@ QString Object::argument(int index, bool expanded) const
682 692 return variant.value<Transform*>()->description(expanded);
683 693 } else if (type == "br::Distance*") {
684 694 return variant.value<Distance*>()->description(expanded);
  695 + } else if (type == "br::Representation*") {
  696 + return variant.value<Representation*>()->description(expanded);
  697 + } else if (type == "br::Classifier*") {
  698 + return variant.value<Classifier*>()->description(expanded);
685 699 } else if (type == "QStringList") {
686 700 return "[" + variant.toStringList().join(",") + "]";
687 701 }
... ... @@ -713,10 +727,20 @@ void Object::store(QDataStream &amp;stream) const
713 727 } else if (type == "QList<br::Distance*>") {
714 728 foreach (Distance *distance, property.read(this).value< QList<Distance*> >())
715 729 distance->store(stream);
  730 + } else if (type == "QList<br::Representation*>") {
  731 + foreach (Representation *representation, property.read(this).value< QList<Representation*> >())
  732 + representation->store(stream);
  733 + } else if (type == "QList<br::Classifier*>") {
  734 + foreach (Classifier *classifier, property.read(this).value< QList<Classifier*> >())
  735 + classifier->store(stream);
716 736 } else if (type == "br::Transform*") {
717 737 property.read(this).value<Transform*>()->store(stream);
718 738 } else if (type == "br::Distance*") {
719 739 property.read(this).value<Distance*>()->store(stream);
  740 + } else if (type == "br::Representation*") {
  741 + property.read(this).value<Representation*>()->store(stream);
  742 + } else if (type == "br::Classifier*") {
  743 + property.read(this).value<Classifier*>()->store(stream);
720 744 } else if (type == "bool") {
721 745 stream << property.read(this).toBool();
722 746 } else if (type == "int") {
... ... @@ -750,10 +774,20 @@ void Object::load(QDataStream &amp;stream)
750 774 } else if (type == "QList<br::Distance*>") {
751 775 foreach (Distance *distance, property.read(this).value< QList<Distance*> >())
752 776 distance->load(stream);
  777 + } else if (type == "QList<br::Representation*>") {
  778 + foreach (Representation *representation, property.read(this).value< QList<Representation*> >())
  779 + representation->load(stream);
  780 + } else if (type == "QList<br::Classifier*>") {
  781 + foreach (Classifier *classifier, property.read(this).value< QList<Classifier*> >())
  782 + classifier->load(stream);
753 783 } else if (type == "br::Transform*") {
754 784 property.read(this).value<Transform*>()->load(stream);
755 785 } else if (type == "br::Distance*") {
756 786 property.read(this).value<Distance*>()->load(stream);
  787 + } else if (type == "br::Representation*") {
  788 + property.read(this).value<Representation*>()->load(stream);
  789 + } else if (type == "br::Classifier*") {
  790 + property.read(this).value<Classifier*>()->load(stream);
757 791 } else if (type == "bool") {
758 792 bool value;
759 793 stream >> value;
... ... @@ -919,6 +953,18 @@ void Object::setProperty(const QString &amp;name, QVariant value)
919 953 if (element.canConvert<QString>()) parsedValues.append(Distance::make(element.toString(), this));
920 954 else parsedValues.append(element.value<Distance*>());
921 955 value.setValue(parsedValues);
  956 + } else if (type == "QList<br::Representation*>") {
  957 + QList<Representation*> parsedValues;
  958 + foreach (const QVariant &element, elements)
  959 + if (element.canConvert<QString>()) parsedValues.append(Representation::make(element.toString(), this));
  960 + else parsedValues.append(element.value<Representation*>());
  961 + value.setValue(parsedValues);
  962 + } else if (type == "QList<br::Classifier*>") {
  963 + QList<Classifier*> parsedValues;
  964 + foreach (const QVariant &element, elements)
  965 + if (element.canConvert<QString>()) parsedValues.append(Classifier::make(element.toString(), this));
  966 + else parsedValues.append(element.value<Classifier*>());
  967 + value.setValue(parsedValues);
922 968 } else {
923 969 qFatal("Unrecognized type: %s", qPrintable(type));
924 970 }
... ... @@ -928,6 +974,12 @@ void Object::setProperty(const QString &amp;name, QVariant value)
928 974 } else if (type == "br::Distance*") {
929 975 if (value.canConvert<QString>())
930 976 value.setValue(Distance::make(value.toString(), this));
  977 + } else if (type == "br::Representation*") {
  978 + if (value.canConvert<QString>())
  979 + value.setValue(Representation::make(value.toString(), this));
  980 + } else if (type == "br::Classifier*") {
  981 + if (value.canConvert<QString>())
  982 + value.setValue(Classifier::make(value.toString(), this));
931 983 } else if (type == "bool") {
932 984 if (value.isNull()) value = true;
933 985 else if (value == "false") value = false;
... ... @@ -1086,10 +1138,14 @@ void br::Context::initialize(int &amp;argc, char *argv[], QString sdkPath, bool useG
1086 1138 qRegisterMetaType<br::TemplateList>();
1087 1139 qRegisterMetaType< br::Transform* >();
1088 1140 qRegisterMetaType< br::Distance* >();
  1141 + qRegisterMetaType< br::Representation* >();
  1142 + qRegisterMetaType< br::Classifier* >();
1089 1143 qRegisterMetaType< QList<int> >();
1090 1144 qRegisterMetaType< QList<float> >();
1091 1145 qRegisterMetaType< QList<br::Transform*> >();
1092 1146 qRegisterMetaType< QList<br::Distance*> >();
  1147 + qRegisterMetaType< QList<br::Representation* > >();
  1148 + qRegisterMetaType< QList<br::Classifier* > >();
1093 1149 qRegisterMetaType< QAbstractSocket::SocketState> ();
1094 1150 qRegisterMetaType< QLocalSocket::LocalSocketState> ();
1095 1151  
... ... @@ -1196,6 +1252,15 @@ QStringList br::Context::objects(const char *abstractions, const char *implement
1196 1252 if (implementationsRegExp.exactMatch(name))
1197 1253 objectList.append(name + (parameters ? "\t" + Factory<Transform>::parameters(name) : ""));
1198 1254  
  1255 + if (abstractionsRegExp.exactMatch("Representation"))
  1256 + foreach (const QString &name, Factory<Representation>::names())
  1257 + if (implementationsRegExp.exactMatch(name))
  1258 + objectList.append(name + (parameters ? "\t" + Factory<Representation>::parameters(name) : ""));
  1259 +
  1260 + if (abstractionsRegExp.exactMatch("Classifier"))
  1261 + foreach (const QString &name, Factory<Classifier>::names())
  1262 + if (implementationsRegExp.exactMatch(name))
  1263 + objectList.append(name + (parameters ? "\t" + Factory<Classifier>::parameters(name) : ""));
1199 1264  
1200 1265 return objectList;
1201 1266 }
... ... @@ -1607,3 +1672,29 @@ Transform *br::pipeTransforms(QList&lt;Transform *&gt; &amp;transforms)
1607 1672 res->setPropertyRecursive("transforms", QVariant::fromValue(transforms));
1608 1673 return res;
1609 1674 }
  1675 +
  1676 +Representation *Representation::make(QString str, QObject *parent)
  1677 +{
  1678 + // Check for custom transforms
  1679 + if (Globals->abbreviations.contains(str))
  1680 + return make(Globals->abbreviations[str], parent);
  1681 +
  1682 + File f = "." + str;
  1683 + Representation *rep = Factory<Representation>::make(f);
  1684 +
  1685 + rep->setParent(parent);
  1686 + return rep;
  1687 +}
  1688 +
  1689 +Classifier *Classifier::make(QString str, QObject *parent)
  1690 +{
  1691 + // Check for custom transforms
  1692 + if (Globals->abbreviations.contains(str))
  1693 + return make(Globals->abbreviations[str], parent);
  1694 +
  1695 + File f = "." + str;
  1696 + Classifier *classifier = Factory<Classifier>::make(f);
  1697 +
  1698 + classifier->setParent(parent);
  1699 + return classifier;
  1700 +}
... ...
openbr/openbr_plugin.h
... ... @@ -1393,10 +1393,13 @@ class BR_EXPORT Representation : public Object
1393 1393 public:
1394 1394 virtual ~Representation() {}
1395 1395  
  1396 + static Representation *make(QString str, QObject *parent); /*!< \brief Make a representation from a string. */
1396 1397 virtual cv::Mat preprocess(const cv::Mat &image) const { return image; }
  1398 + virtual void train(const QList<cv::Mat> &images, const QList<float> &labels) { (void) images; (void)labels; }
1397 1399 // By convention, an empty indices list will result in all feature responses being calculated
1398 1400 // and returned.
1399 1401 virtual cv::Mat evaluate(const cv::Mat &image, const QList<int> &indices = QList<int>()) const = 0;
  1402 + virtual int numFeatures() const = 0;
1400 1403 };
1401 1404  
1402 1405 class BR_EXPORT Classifier : public Object
... ... @@ -1406,7 +1409,10 @@ class BR_EXPORT Classifier : public Object
1406 1409 public:
1407 1410 virtual ~Classifier() {}
1408 1411  
  1412 + static Classifier *make(QString str, QObject *parent); /*!< \brief Make a classifier from a string. */
1409 1413 virtual void train(const QList<cv::Mat> &images, const QList<float> &labels) = 0;
  1414 + // By convention, classify should return a value normalized such that the threshold is 0. Negative values
  1415 + // can be interpreted as a negative classification and positive values as a positive classification.
1410 1416 virtual float classify(const cv::Mat &image) const = 0;
1411 1417 };
1412 1418  
... ... @@ -1497,10 +1503,14 @@ Q_DECLARE_METATYPE(br::Template)
1497 1503 Q_DECLARE_METATYPE(br::TemplateList)
1498 1504 Q_DECLARE_METATYPE(br::Transform*)
1499 1505 Q_DECLARE_METATYPE(br::Distance*)
  1506 +Q_DECLARE_METATYPE(br::Representation*)
  1507 +Q_DECLARE_METATYPE(br::Classifier*)
1500 1508 Q_DECLARE_METATYPE(QList<int>)
1501 1509 Q_DECLARE_METATYPE(QList<float>)
1502 1510 Q_DECLARE_METATYPE(QList<br::Transform*>)
1503 1511 Q_DECLARE_METATYPE(QList<br::Distance*>)
  1512 +Q_DECLARE_METATYPE(QList<br::Representation*>)
  1513 +Q_DECLARE_METATYPE(QList<br::Classifier*>)
1504 1514  
1505 1515 #endif // __cplusplus
1506 1516  
... ...