Commit 5b4706409f0abf4e2902d75397c6d0ba32221852
1 parent
d97f68bd
Updates to trainable transforms
Showing
4 changed files
with
35 additions
and
31 deletions
openbr/plugins/distance/fuse.cpp
| @@ -18,6 +18,8 @@ | @@ -18,6 +18,8 @@ | ||
| 18 | 18 | ||
| 19 | #include <openbr/plugins/openbr_internal.h> | 19 | #include <openbr/plugins/openbr_internal.h> |
| 20 | 20 | ||
| 21 | +#include <QtConcurrent> | ||
| 22 | + | ||
| 21 | namespace br | 23 | namespace br |
| 22 | { | 24 | { |
| 23 | 25 | ||
| @@ -31,26 +33,25 @@ class FuseDistance : public Distance | @@ -31,26 +33,25 @@ class FuseDistance : public Distance | ||
| 31 | { | 33 | { |
| 32 | Q_OBJECT | 34 | Q_OBJECT |
| 33 | Q_ENUMS(Operation) | 35 | Q_ENUMS(Operation) |
| 34 | - Q_PROPERTY(QStringList descriptions READ get_descriptions WRITE set_descriptions RESET reset_descriptions STORED false) | 36 | + Q_PROPERTY(QList<br::Distance*> distances READ get_distances WRITE set_distances RESET reset_distances) |
| 35 | Q_PROPERTY(Operation operation READ get_operation WRITE set_operation RESET reset_operation STORED false) | 37 | Q_PROPERTY(Operation operation READ get_operation WRITE set_operation RESET reset_operation STORED false) |
| 36 | Q_PROPERTY(QList<float> weights READ get_weights WRITE set_weights RESET reset_weights STORED false) | 38 | Q_PROPERTY(QList<float> weights READ get_weights WRITE set_weights RESET reset_weights STORED false) |
| 37 | 39 | ||
| 38 | - QList<br::Distance*> distances; | ||
| 39 | - | ||
| 40 | public: | 40 | public: |
| 41 | /*!< */ | 41 | /*!< */ |
| 42 | enum Operation {Mean, Sum, Max, Min}; | 42 | enum Operation {Mean, Sum, Max, Min}; |
| 43 | 43 | ||
| 44 | private: | 44 | private: |
| 45 | - BR_PROPERTY(QStringList, descriptions, QStringList() << "L2") | 45 | + BR_PROPERTY(QList<br::Distance*>, distances, QList<br::Distance*>()) |
| 46 | BR_PROPERTY(Operation, operation, Mean) | 46 | BR_PROPERTY(Operation, operation, Mean) |
| 47 | BR_PROPERTY(QList<float>, weights, QList<float>()) | 47 | BR_PROPERTY(QList<float>, weights, QList<float>()) |
| 48 | 48 | ||
| 49 | - void init() | 49 | + bool trainable() |
| 50 | { | 50 | { |
| 51 | - for (int i=0; i<descriptions.size(); i++) { | ||
| 52 | - distances.append(make(descriptions[i])); | ||
| 53 | - } | 51 | + for (int i=0; i<distances.size(); i++) |
| 52 | + if (distances[i]->trainable()) | ||
| 53 | + return true; | ||
| 54 | + return false; | ||
| 54 | } | 55 | } |
| 55 | 56 | ||
| 56 | void train(const TemplateList &src) | 57 | void train(const TemplateList &src) |
| @@ -62,10 +63,10 @@ private: | @@ -62,10 +63,10 @@ private: | ||
| 62 | QList<TemplateList> partitionedSrc = src.partition(split); | 63 | QList<TemplateList> partitionedSrc = src.partition(split); |
| 63 | 64 | ||
| 64 | // Train on each of the partitions | 65 | // Train on each of the partitions |
| 65 | - for (int i=0; i<descriptions.size(); i++) { | ||
| 66 | - distances.append(make(descriptions[i])); | ||
| 67 | - distances[i]->train(partitionedSrc[i]); | ||
| 68 | - } | 66 | + QFutureSynchronizer<void> futures; |
| 67 | + for (int i=0; i<distances.size(); i++) | ||
| 68 | + futures.addFuture(QtConcurrent::run(distances[i], &Distance::train, partitionedSrc[i])); | ||
| 69 | + futures.waitForFinished(); | ||
| 69 | } | 70 | } |
| 70 | 71 | ||
| 71 | float compare(const Template &a, const Template &b) const | 72 | float compare(const Template &a, const Template &b) const |
| @@ -99,24 +100,6 @@ private: | @@ -99,24 +100,6 @@ private: | ||
| 99 | } | 100 | } |
| 100 | return 0; | 101 | return 0; |
| 101 | } | 102 | } |
| 102 | - | ||
| 103 | - void store(QDataStream &stream) const | ||
| 104 | - { | ||
| 105 | - stream << distances.size(); | ||
| 106 | - foreach (Distance *distance, distances) | ||
| 107 | - distance->store(stream); | ||
| 108 | - } | ||
| 109 | - | ||
| 110 | - void load(QDataStream &stream) | ||
| 111 | - { | ||
| 112 | - int numDistances; | ||
| 113 | - stream >> numDistances; | ||
| 114 | - for (int i=0; i<descriptions.size(); i++) { | ||
| 115 | - distances.append(make(descriptions[i])); | ||
| 116 | - distances[i]->load(stream); | ||
| 117 | - } | ||
| 118 | - | ||
| 119 | - } | ||
| 120 | }; | 103 | }; |
| 121 | 104 | ||
| 122 | BR_REGISTER(Distance, FuseDistance) | 105 | BR_REGISTER(Distance, FuseDistance) |
openbr/plugins/distance/neglogplusone.cpp
| @@ -24,12 +24,17 @@ namespace br | @@ -24,12 +24,17 @@ namespace br | ||
| 24 | * \brief Returns -log(distance(a,b)+1) | 24 | * \brief Returns -log(distance(a,b)+1) |
| 25 | * \author Josh Klontz \cite jklontz | 25 | * \author Josh Klontz \cite jklontz |
| 26 | */ | 26 | */ |
| 27 | -class NegativeLogPlusOneDistance : public UntrainableDistance | 27 | +class NegativeLogPlusOneDistance : public Distance |
| 28 | { | 28 | { |
| 29 | Q_OBJECT | 29 | Q_OBJECT |
| 30 | Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false) | 30 | Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false) |
| 31 | BR_PROPERTY(br::Distance*, distance, NULL) | 31 | BR_PROPERTY(br::Distance*, distance, NULL) |
| 32 | 32 | ||
| 33 | + bool trainable() | ||
| 34 | + { | ||
| 35 | + return distance->trainable(); | ||
| 36 | + } | ||
| 37 | + | ||
| 33 | void train(const TemplateList &src) | 38 | void train(const TemplateList &src) |
| 34 | { | 39 | { |
| 35 | distance->train(src); | 40 | distance->train(src); |
openbr/plugins/distance/pipe.cpp
| @@ -36,6 +36,14 @@ class PipeDistance : public Distance | @@ -36,6 +36,14 @@ class PipeDistance : public Distance | ||
| 36 | Q_PROPERTY(QList<br::Distance*> distances READ get_distances WRITE set_distances RESET reset_distances) | 36 | Q_PROPERTY(QList<br::Distance*> distances READ get_distances WRITE set_distances RESET reset_distances) |
| 37 | BR_PROPERTY(QList<br::Distance*>, distances, QList<br::Distance*>()) | 37 | BR_PROPERTY(QList<br::Distance*>, distances, QList<br::Distance*>()) |
| 38 | 38 | ||
| 39 | + bool trainable() | ||
| 40 | + { | ||
| 41 | + for (int i=0; i<distances.size(); i++) | ||
| 42 | + if (distances[i]->trainable()) | ||
| 43 | + return true; | ||
| 44 | + return false; | ||
| 45 | + } | ||
| 46 | + | ||
| 39 | void train(const TemplateList &data) | 47 | void train(const TemplateList &data) |
| 40 | { | 48 | { |
| 41 | QFutureSynchronizer<void> futures; | 49 | QFutureSynchronizer<void> futures; |
openbr/plugins/distance/sum.cpp
| @@ -32,6 +32,14 @@ class SumDistance : public UntrainableDistance | @@ -32,6 +32,14 @@ class SumDistance : public UntrainableDistance | ||
| 32 | Q_PROPERTY(QList<br::Distance*> distances READ get_distances WRITE set_distances RESET reset_distances) | 32 | Q_PROPERTY(QList<br::Distance*> distances READ get_distances WRITE set_distances RESET reset_distances) |
| 33 | BR_PROPERTY(QList<br::Distance*>, distances, QList<br::Distance*>()) | 33 | BR_PROPERTY(QList<br::Distance*>, distances, QList<br::Distance*>()) |
| 34 | 34 | ||
| 35 | + bool trainable() | ||
| 36 | + { | ||
| 37 | + for (int i=0; i<distances.size(); i++) | ||
| 38 | + if (distances[i]->trainable()) | ||
| 39 | + return true; | ||
| 40 | + return false; | ||
| 41 | + } | ||
| 42 | + | ||
| 35 | void train(const TemplateList &data) | 43 | void train(const TemplateList &data) |
| 36 | { | 44 | { |
| 37 | QFutureSynchronizer<void> futures; | 45 | QFutureSynchronizer<void> futures; |