Commit 5b4706409f0abf4e2902d75397c6d0ba32221852
1 parent
d97f68bd
Updates to trainable transforms
Showing
4 changed files
with
35 additions
and
31 deletions
openbr/plugins/distance/fuse.cpp
| ... | ... | @@ -18,6 +18,8 @@ |
| 18 | 18 | |
| 19 | 19 | #include <openbr/plugins/openbr_internal.h> |
| 20 | 20 | |
| 21 | +#include <QtConcurrent> | |
| 22 | + | |
| 21 | 23 | namespace br |
| 22 | 24 | { |
| 23 | 25 | |
| ... | ... | @@ -31,26 +33,25 @@ class FuseDistance : public Distance |
| 31 | 33 | { |
| 32 | 34 | Q_OBJECT |
| 33 | 35 | Q_ENUMS(Operation) |
| 34 | - Q_PROPERTY(QStringList descriptions READ get_descriptions WRITE set_descriptions RESET reset_descriptions STORED false) | |
| 36 | + Q_PROPERTY(QList<br::Distance*> distances READ get_distances WRITE set_distances RESET reset_distances) | |
| 35 | 37 | Q_PROPERTY(Operation operation READ get_operation WRITE set_operation RESET reset_operation STORED false) |
| 36 | 38 | Q_PROPERTY(QList<float> weights READ get_weights WRITE set_weights RESET reset_weights STORED false) |
| 37 | 39 | |
| 38 | - QList<br::Distance*> distances; | |
| 39 | - | |
| 40 | 40 | public: |
| 41 | 41 | /*!< */ |
| 42 | 42 | enum Operation {Mean, Sum, Max, Min}; |
| 43 | 43 | |
| 44 | 44 | private: |
| 45 | - BR_PROPERTY(QStringList, descriptions, QStringList() << "L2") | |
| 45 | + BR_PROPERTY(QList<br::Distance*>, distances, QList<br::Distance*>()) | |
| 46 | 46 | BR_PROPERTY(Operation, operation, Mean) |
| 47 | 47 | BR_PROPERTY(QList<float>, weights, QList<float>()) |
| 48 | 48 | |
| 49 | - void init() | |
| 49 | + bool trainable() | |
| 50 | 50 | { |
| 51 | - for (int i=0; i<descriptions.size(); i++) { | |
| 52 | - distances.append(make(descriptions[i])); | |
| 53 | - } | |
| 51 | + for (int i=0; i<distances.size(); i++) | |
| 52 | + if (distances[i]->trainable()) | |
| 53 | + return true; | |
| 54 | + return false; | |
| 54 | 55 | } |
| 55 | 56 | |
| 56 | 57 | void train(const TemplateList &src) |
| ... | ... | @@ -62,10 +63,10 @@ private: |
| 62 | 63 | QList<TemplateList> partitionedSrc = src.partition(split); |
| 63 | 64 | |
| 64 | 65 | // Train on each of the partitions |
| 65 | - for (int i=0; i<descriptions.size(); i++) { | |
| 66 | - distances.append(make(descriptions[i])); | |
| 67 | - distances[i]->train(partitionedSrc[i]); | |
| 68 | - } | |
| 66 | + QFutureSynchronizer<void> futures; | |
| 67 | + for (int i=0; i<distances.size(); i++) | |
| 68 | + futures.addFuture(QtConcurrent::run(distances[i], &Distance::train, partitionedSrc[i])); | |
| 69 | + futures.waitForFinished(); | |
| 69 | 70 | } |
| 70 | 71 | |
| 71 | 72 | float compare(const Template &a, const Template &b) const |
| ... | ... | @@ -99,24 +100,6 @@ private: |
| 99 | 100 | } |
| 100 | 101 | return 0; |
| 101 | 102 | } |
| 102 | - | |
| 103 | - void store(QDataStream &stream) const | |
| 104 | - { | |
| 105 | - stream << distances.size(); | |
| 106 | - foreach (Distance *distance, distances) | |
| 107 | - distance->store(stream); | |
| 108 | - } | |
| 109 | - | |
| 110 | - void load(QDataStream &stream) | |
| 111 | - { | |
| 112 | - int numDistances; | |
| 113 | - stream >> numDistances; | |
| 114 | - for (int i=0; i<descriptions.size(); i++) { | |
| 115 | - distances.append(make(descriptions[i])); | |
| 116 | - distances[i]->load(stream); | |
| 117 | - } | |
| 118 | - | |
| 119 | - } | |
| 120 | 103 | }; |
| 121 | 104 | |
| 122 | 105 | BR_REGISTER(Distance, FuseDistance) | ... | ... |
openbr/plugins/distance/neglogplusone.cpp
| ... | ... | @@ -24,12 +24,17 @@ namespace br |
| 24 | 24 | * \brief Returns -log(distance(a,b)+1) |
| 25 | 25 | * \author Josh Klontz \cite jklontz |
| 26 | 26 | */ |
| 27 | -class NegativeLogPlusOneDistance : public UntrainableDistance | |
| 27 | +class NegativeLogPlusOneDistance : public Distance | |
| 28 | 28 | { |
| 29 | 29 | Q_OBJECT |
| 30 | 30 | Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false) |
| 31 | 31 | BR_PROPERTY(br::Distance*, distance, NULL) |
| 32 | 32 | |
| 33 | + bool trainable() | |
| 34 | + { | |
| 35 | + return distance->trainable(); | |
| 36 | + } | |
| 37 | + | |
| 33 | 38 | void train(const TemplateList &src) |
| 34 | 39 | { |
| 35 | 40 | distance->train(src); | ... | ... |
openbr/plugins/distance/pipe.cpp
| ... | ... | @@ -36,6 +36,14 @@ class PipeDistance : public Distance |
| 36 | 36 | Q_PROPERTY(QList<br::Distance*> distances READ get_distances WRITE set_distances RESET reset_distances) |
| 37 | 37 | BR_PROPERTY(QList<br::Distance*>, distances, QList<br::Distance*>()) |
| 38 | 38 | |
| 39 | + bool trainable() | |
| 40 | + { | |
| 41 | + for (int i=0; i<distances.size(); i++) | |
| 42 | + if (distances[i]->trainable()) | |
| 43 | + return true; | |
| 44 | + return false; | |
| 45 | + } | |
| 46 | + | |
| 39 | 47 | void train(const TemplateList &data) |
| 40 | 48 | { |
| 41 | 49 | QFutureSynchronizer<void> futures; | ... | ... |
openbr/plugins/distance/sum.cpp
| ... | ... | @@ -32,6 +32,14 @@ class SumDistance : public UntrainableDistance |
| 32 | 32 | Q_PROPERTY(QList<br::Distance*> distances READ get_distances WRITE set_distances RESET reset_distances) |
| 33 | 33 | BR_PROPERTY(QList<br::Distance*>, distances, QList<br::Distance*>()) |
| 34 | 34 | |
| 35 | + bool trainable() | |
| 36 | + { | |
| 37 | + for (int i=0; i<distances.size(); i++) | |
| 38 | + if (distances[i]->trainable()) | |
| 39 | + return true; | |
| 40 | + return false; | |
| 41 | + } | |
| 42 | + | |
| 35 | 43 | void train(const TemplateList &data) |
| 36 | 44 | { |
| 37 | 45 | QFutureSynchronizer<void> futures; | ... | ... |