Commit 9e374684426a57ae37b6792fb2c9ae64f8c42857

Authored by JordanCheney
1 parent b32f0957

Update Classification API to support OpenBR object handling

openbr/openbr_plugin.cpp
@@ -625,7 +625,11 @@ QStringList Object::prunedArguments(bool expanded) const @@ -625,7 +625,11 @@ QStringList Object::prunedArguments(bool expanded) const
625 if (className.endsWith(interfaceName)) 625 if (className.endsWith(interfaceName))
626 className.chop(interfaceName.size()); 626 className.chop(interfaceName.size());
627 627
628 - if (interfaceName == "Distance") 628 + if (interfaceName == "Representation")
  629 + shellObject.reset(Factory<Representation>::make(className));
  630 + else if (interfaceName == "Classifier")
  631 + shellObject.reset(Factory<Classifier>::make(className));
  632 + else if (interfaceName == "Distance")
629 shellObject.reset(Factory<Distance>::make(className)); 633 shellObject.reset(Factory<Distance>::make(className));
630 else if (interfaceName == "Transform") 634 else if (interfaceName == "Transform")
631 shellObject.reset(Factory<Transform>::make(className)); 635 shellObject.reset(Factory<Transform>::make(className));
@@ -672,6 +676,12 @@ QString Object::argument(int index, bool expanded) const @@ -672,6 +676,12 @@ QString Object::argument(int index, bool expanded) const
672 } else if (type == "QList<br::Distance*>") { 676 } else if (type == "QList<br::Distance*>") {
673 foreach (Distance *distance, variant.value< QList<Distance*> >()) 677 foreach (Distance *distance, variant.value< QList<Distance*> >())
674 strings.append(distance->description(expanded)); 678 strings.append(distance->description(expanded));
  679 + } else if (type == "QList<br::Representation*>") {
  680 + foreach (Representation *representation, variant.value< QList<Representation*> >())
  681 + strings.append(representation->description(expanded));
  682 + } else if (type == "QList<br::Classifier*>") {
  683 + foreach (Classifier *classifier, variant.value< QList<Classifier*> >())
  684 + strings.append(classifier->description(expanded));
675 } else { 685 } else {
676 qFatal("Unrecognized type: %s", qPrintable(type)); 686 qFatal("Unrecognized type: %s", qPrintable(type));
677 } 687 }
@@ -681,6 +691,10 @@ QString Object::argument(int index, bool expanded) const @@ -681,6 +691,10 @@ QString Object::argument(int index, bool expanded) const
681 return variant.value<Transform*>()->description(expanded); 691 return variant.value<Transform*>()->description(expanded);
682 } else if (type == "br::Distance*") { 692 } else if (type == "br::Distance*") {
683 return variant.value<Distance*>()->description(expanded); 693 return variant.value<Distance*>()->description(expanded);
  694 + } else if (type == "br::Representation*") {
  695 + return variant.value<Representation*>()->description(expanded);
  696 + } else if (type == "br::Classifier*") {
  697 + return variant.value<Classifier*>()->description(expanded);
684 } else if (type == "QStringList") { 698 } else if (type == "QStringList") {
685 return "[" + variant.toStringList().join(",") + "]"; 699 return "[" + variant.toStringList().join(",") + "]";
686 } 700 }
@@ -712,10 +726,20 @@ void Object::store(QDataStream &amp;stream) const @@ -712,10 +726,20 @@ void Object::store(QDataStream &amp;stream) const
712 } else if (type == "QList<br::Distance*>") { 726 } else if (type == "QList<br::Distance*>") {
713 foreach (Distance *distance, property.read(this).value< QList<Distance*> >()) 727 foreach (Distance *distance, property.read(this).value< QList<Distance*> >())
714 distance->store(stream); 728 distance->store(stream);
  729 + } else if (type == "QList<br::Representation*>") {
  730 + foreach (Representation *representation, property.read(this).value< QList<Representation*> >())
  731 + representation->store(stream);
  732 + } else if (type == "QList<br::Classifier*>") {
  733 + foreach (Classifier *classifier, property.read(this).value< QList<Classifier*> >())
  734 + classifier->store(stream);
715 } else if (type == "br::Transform*") { 735 } else if (type == "br::Transform*") {
716 property.read(this).value<Transform*>()->store(stream); 736 property.read(this).value<Transform*>()->store(stream);
717 } else if (type == "br::Distance*") { 737 } else if (type == "br::Distance*") {
718 property.read(this).value<Distance*>()->store(stream); 738 property.read(this).value<Distance*>()->store(stream);
  739 + } else if (type == "br::Representation*") {
  740 + property.read(this).value<Representation*>()->store(stream);
  741 + } else if (type == "br::Classifier*") {
  742 + property.read(this).value<Classifier*>()->store(stream);
719 } else if (type == "bool") { 743 } else if (type == "bool") {
720 stream << property.read(this).toBool(); 744 stream << property.read(this).toBool();
721 } else if (type == "int") { 745 } else if (type == "int") {
@@ -749,10 +773,20 @@ void Object::load(QDataStream &amp;stream) @@ -749,10 +773,20 @@ void Object::load(QDataStream &amp;stream)
749 } else if (type == "QList<br::Distance*>") { 773 } else if (type == "QList<br::Distance*>") {
750 foreach (Distance *distance, property.read(this).value< QList<Distance*> >()) 774 foreach (Distance *distance, property.read(this).value< QList<Distance*> >())
751 distance->load(stream); 775 distance->load(stream);
  776 + } else if (type == "QList<br::Representation*>") {
  777 + foreach (Representation *representation, property.read(this).value< QList<Representation*> >())
  778 + representation->load(stream);
  779 + } else if (type == "QList<br::Classifier*>") {
  780 + foreach (Classifier *classifier, property.read(this).value< QList<Classifier*> >())
  781 + classifier->load(stream);
752 } else if (type == "br::Transform*") { 782 } else if (type == "br::Transform*") {
753 property.read(this).value<Transform*>()->load(stream); 783 property.read(this).value<Transform*>()->load(stream);
754 } else if (type == "br::Distance*") { 784 } else if (type == "br::Distance*") {
755 property.read(this).value<Distance*>()->load(stream); 785 property.read(this).value<Distance*>()->load(stream);
  786 + } else if (type == "br::Representation*") {
  787 + property.read(this).value<Representation*>()->load(stream);
  788 + } else if (type == "br::Classifier*") {
  789 + property.read(this).value<Classifier*>()->load(stream);
756 } else if (type == "bool") { 790 } else if (type == "bool") {
757 bool value; 791 bool value;
758 stream >> value; 792 stream >> value;
@@ -868,6 +902,18 @@ void Object::setProperty(const QString &amp;name, QVariant value) @@ -868,6 +902,18 @@ void Object::setProperty(const QString &amp;name, QVariant value)
868 if (element.canConvert<QString>()) parsedValues.append(Distance::make(element.toString(), this)); 902 if (element.canConvert<QString>()) parsedValues.append(Distance::make(element.toString(), this));
869 else parsedValues.append(element.value<Distance*>()); 903 else parsedValues.append(element.value<Distance*>());
870 value.setValue(parsedValues); 904 value.setValue(parsedValues);
  905 + } else if (type == "QList<br::Representation*>") {
  906 + QList<Representation*> parsedValues;
  907 + foreach (const QVariant &element, elements)
  908 + if (element.canConvert<QString>()) parsedValues.append(Representation::make(element.toString()), this);
  909 + else parsedValues.append(element.value<Representation*>());
  910 + value.setValue(parsedValues);
  911 + } else if (type == "QList<br::Classifier*>") {
  912 + QList<Classifier*> parsedValues;
  913 + foreach (const QVariant &element, elements)
  914 + if (element.canConvert<QString>()) parsedValues.append(Classifier::make(element.toString()), this);
  915 + else parsedValues.append(element.value<Classifier*>());
  916 + value.setValue(parsedValues);
871 } else { 917 } else {
872 qFatal("Unrecognized type: %s", qPrintable(type)); 918 qFatal("Unrecognized type: %s", qPrintable(type));
873 } 919 }
@@ -877,6 +923,12 @@ void Object::setProperty(const QString &amp;name, QVariant value) @@ -877,6 +923,12 @@ void Object::setProperty(const QString &amp;name, QVariant value)
877 } else if (type == "br::Distance*") { 923 } else if (type == "br::Distance*") {
878 if (value.canConvert<QString>()) 924 if (value.canConvert<QString>())
879 value.setValue(Distance::make(value.toString(), this)); 925 value.setValue(Distance::make(value.toString(), this));
  926 + } else if (type == "br::Representation*") {
  927 + if (value.canConvert<QString>())
  928 + value.setValue(Representation::make(value.toString(), this));
  929 + } else if (type == "br::Classifier*") {
  930 + if (value.canConvert<QString>())
  931 + value.setValue(Classifier::make(value.toString(), this));
880 } else if (type == "bool") { 932 } else if (type == "bool") {
881 if (value.isNull()) value = true; 933 if (value.isNull()) value = true;
882 else if (value == "false") value = false; 934 else if (value == "false") value = false;
@@ -1035,10 +1087,14 @@ void br::Context::initialize(int &amp;argc, char *argv[], QString sdkPath, bool useG @@ -1035,10 +1087,14 @@ void br::Context::initialize(int &amp;argc, char *argv[], QString sdkPath, bool useG
1035 qRegisterMetaType<br::TemplateList>(); 1087 qRegisterMetaType<br::TemplateList>();
1036 qRegisterMetaType< br::Transform* >(); 1088 qRegisterMetaType< br::Transform* >();
1037 qRegisterMetaType< br::Distance* >(); 1089 qRegisterMetaType< br::Distance* >();
  1090 + qRegisterMetaType< br::Representation* >();
  1091 + qRegisterMetaType< br::Classifier* >();
1038 qRegisterMetaType< QList<int> >(); 1092 qRegisterMetaType< QList<int> >();
1039 qRegisterMetaType< QList<float> >(); 1093 qRegisterMetaType< QList<float> >();
1040 qRegisterMetaType< QList<br::Transform*> >(); 1094 qRegisterMetaType< QList<br::Transform*> >();
1041 qRegisterMetaType< QList<br::Distance*> >(); 1095 qRegisterMetaType< QList<br::Distance*> >();
  1096 + qRegisterMetaType< QList<br::Representation* > >();
  1097 + qRegisterMetaType< QList<br::Classifier* > >();
1042 qRegisterMetaType< QAbstractSocket::SocketState> (); 1098 qRegisterMetaType< QAbstractSocket::SocketState> ();
1043 qRegisterMetaType< QLocalSocket::LocalSocketState> (); 1099 qRegisterMetaType< QLocalSocket::LocalSocketState> ();
1044 1100
@@ -1145,6 +1201,15 @@ QStringList br::Context::objects(const char *abstractions, const char *implement @@ -1145,6 +1201,15 @@ QStringList br::Context::objects(const char *abstractions, const char *implement
1145 if (implementationsRegExp.exactMatch(name)) 1201 if (implementationsRegExp.exactMatch(name))
1146 objectList.append(name + (parameters ? "\t" + Factory<Transform>::parameters(name) : "")); 1202 objectList.append(name + (parameters ? "\t" + Factory<Transform>::parameters(name) : ""));
1147 1203
  1204 + if (abstractionsRegExp.exactMatch("Representation"))
  1205 + foreach (const QString &name, Factory<Representation>::names())
  1206 + if (implementationsRegExp.exactMatch(name))
  1207 + objectList.append(name + (parameters ? "\t" + Factory<Representation>::parameters(name) : ""));
  1208 +
  1209 + if (abstractionsRegExp.exactMatch("Classifier"))
  1210 + foreach (const QString &name, Factory<Classifier>::names())
  1211 + if (implementationsRegExp.exactMatch(name))
  1212 + objectList.append(name + (parameters ? "\t" + Factory<Classifier>::parameters(name) : ""));
1148 1213
1149 return objectList; 1214 return objectList;
1150 } 1215 }
@@ -1548,3 +1613,29 @@ Transform *br::pipeTransforms(QList&lt;Transform *&gt; &amp;transforms) @@ -1548,3 +1613,29 @@ Transform *br::pipeTransforms(QList&lt;Transform *&gt; &amp;transforms)
1548 res->setPropertyRecursive("transforms", QVariant::fromValue(transforms)); 1613 res->setPropertyRecursive("transforms", QVariant::fromValue(transforms));
1549 return res; 1614 return res;
1550 } 1615 }
  1616 +
  1617 +Representation *Representation::make(QString str, QObject *parent)
  1618 +{
  1619 + // Check for custom transforms
  1620 + if (Globals->abbreviations.contains(str))
  1621 + return make(Globals->abbreviations[str], parent);
  1622 +
  1623 + File f = "." + str;
  1624 + Representation *rep = Factory<Representation>::make(f);
  1625 +
  1626 + rep->setParent(parent);
  1627 + return rep;
  1628 +}
  1629 +
  1630 +Classifier *Classifier::make(QString str, QObject *parent)
  1631 +{
  1632 + // Check for custom transforms
  1633 + if (Globals->abbreviations.contains(str))
  1634 + return make(Globals->abbreviations[str], parent);
  1635 +
  1636 + File f = "." + str;
  1637 + Classifier *classifier = Factory<Classifier>::make(f);
  1638 +
  1639 + classifier->setParent(parent);
  1640 + return classifier;
  1641 +}
openbr/openbr_plugin.h
@@ -1383,10 +1383,12 @@ class BR_EXPORT Representation : public Object @@ -1383,10 +1383,12 @@ class BR_EXPORT Representation : public Object
1383 public: 1383 public:
1384 virtual ~Representation() {} 1384 virtual ~Representation() {}
1385 1385
  1386 + static Representation *make(QString str, QObject *parent); /*!< \brief Make a representation from a string. */
1386 virtual cv::Mat preprocess(const cv::Mat &image) const { return image; } 1387 virtual cv::Mat preprocess(const cv::Mat &image) const { return image; }
1387 // By convention, an empty indices list will result in all feature responses being calculated 1388 // By convention, an empty indices list will result in all feature responses being calculated
1388 // and returned. 1389 // and returned.
1389 virtual cv::Mat evaluate(const cv::Mat &image, const QList<int> &indices = QList<int>()) const = 0; 1390 virtual cv::Mat evaluate(const cv::Mat &image, const QList<int> &indices = QList<int>()) const = 0;
  1391 + virtual int numFeatures() const = 0;
1390 }; 1392 };
1391 1393
1392 class BR_EXPORT Classifier : public Object 1394 class BR_EXPORT Classifier : public Object
@@ -1396,6 +1398,7 @@ class BR_EXPORT Classifier : public Object @@ -1396,6 +1398,7 @@ class BR_EXPORT Classifier : public Object
1396 public: 1398 public:
1397 virtual ~Classifier() {} 1399 virtual ~Classifier() {}
1398 1400
  1401 + static Classifier *make(QString str, QObject *parent); /*!< \brief Make a classifier from a string. */
1399 virtual void train(const QList<cv::Mat> &images, const QList<float> &labels) = 0; 1402 virtual void train(const QList<cv::Mat> &images, const QList<float> &labels) = 0;
1400 virtual float classify(const cv::Mat &image) const = 0; 1403 virtual float classify(const cv::Mat &image) const = 0;
1401 }; 1404 };
@@ -1487,10 +1490,14 @@ Q_DECLARE_METATYPE(br::Template) @@ -1487,10 +1490,14 @@ Q_DECLARE_METATYPE(br::Template)
1487 Q_DECLARE_METATYPE(br::TemplateList) 1490 Q_DECLARE_METATYPE(br::TemplateList)
1488 Q_DECLARE_METATYPE(br::Transform*) 1491 Q_DECLARE_METATYPE(br::Transform*)
1489 Q_DECLARE_METATYPE(br::Distance*) 1492 Q_DECLARE_METATYPE(br::Distance*)
  1493 +Q_DECLARE_METATYPE(br::Representation*)
  1494 +Q_DECLARE_METATYPE(br::Classifier*)
1490 Q_DECLARE_METATYPE(QList<int>) 1495 Q_DECLARE_METATYPE(QList<int>)
1491 Q_DECLARE_METATYPE(QList<float>) 1496 Q_DECLARE_METATYPE(QList<float>)
1492 Q_DECLARE_METATYPE(QList<br::Transform*>) 1497 Q_DECLARE_METATYPE(QList<br::Transform*>)
1493 Q_DECLARE_METATYPE(QList<br::Distance*>) 1498 Q_DECLARE_METATYPE(QList<br::Distance*>)
  1499 +Q_DECLARE_METATYPE(QList<br::Representation*>)
  1500 +Q_DECLARE_METATYPE(QList<br::Classifier*>)
1494 1501
1495 #endif // __cplusplus 1502 #endif // __cplusplus
1496 1503