Commit 5b4706409f0abf4e2902d75397c6d0ba32221852

Authored by Scott Klum
1 parent d97f68bd

Updates to trainable transforms

openbr/plugins/distance/fuse.cpp
@@ -18,6 +18,8 @@ @@ -18,6 +18,8 @@
18 18
19 #include <openbr/plugins/openbr_internal.h> 19 #include <openbr/plugins/openbr_internal.h>
20 20
  21 +#include <QtConcurrent>
  22 +
21 namespace br 23 namespace br
22 { 24 {
23 25
@@ -31,26 +33,25 @@ class FuseDistance : public Distance @@ -31,26 +33,25 @@ class FuseDistance : public Distance
31 { 33 {
32 Q_OBJECT 34 Q_OBJECT
33 Q_ENUMS(Operation) 35 Q_ENUMS(Operation)
34 - Q_PROPERTY(QStringList descriptions READ get_descriptions WRITE set_descriptions RESET reset_descriptions STORED false) 36 + Q_PROPERTY(QList<br::Distance*> distances READ get_distances WRITE set_distances RESET reset_distances)
35 Q_PROPERTY(Operation operation READ get_operation WRITE set_operation RESET reset_operation STORED false) 37 Q_PROPERTY(Operation operation READ get_operation WRITE set_operation RESET reset_operation STORED false)
36 Q_PROPERTY(QList<float> weights READ get_weights WRITE set_weights RESET reset_weights STORED false) 38 Q_PROPERTY(QList<float> weights READ get_weights WRITE set_weights RESET reset_weights STORED false)
37 39
38 - QList<br::Distance*> distances;  
39 -  
40 public: 40 public:
41 /*!< */ 41 /*!< */
42 enum Operation {Mean, Sum, Max, Min}; 42 enum Operation {Mean, Sum, Max, Min};
43 43
44 private: 44 private:
45 - BR_PROPERTY(QStringList, descriptions, QStringList() << "L2") 45 + BR_PROPERTY(QList<br::Distance*>, distances, QList<br::Distance*>())
46 BR_PROPERTY(Operation, operation, Mean) 46 BR_PROPERTY(Operation, operation, Mean)
47 BR_PROPERTY(QList<float>, weights, QList<float>()) 47 BR_PROPERTY(QList<float>, weights, QList<float>())
48 48
49 - void init() 49 + bool trainable()
50 { 50 {
51 - for (int i=0; i<descriptions.size(); i++) {  
52 - distances.append(make(descriptions[i]));  
53 - } 51 + for (int i=0; i<distances.size(); i++)
  52 + if (distances[i]->trainable())
  53 + return true;
  54 + return false;
54 } 55 }
55 56
56 void train(const TemplateList &src) 57 void train(const TemplateList &src)
@@ -62,10 +63,10 @@ private: @@ -62,10 +63,10 @@ private:
62 QList<TemplateList> partitionedSrc = src.partition(split); 63 QList<TemplateList> partitionedSrc = src.partition(split);
63 64
64 // Train on each of the partitions 65 // Train on each of the partitions
65 - for (int i=0; i<descriptions.size(); i++) {  
66 - distances.append(make(descriptions[i]));  
67 - distances[i]->train(partitionedSrc[i]);  
68 - } 66 + QFutureSynchronizer<void> futures;
  67 + for (int i=0; i<distances.size(); i++)
  68 + futures.addFuture(QtConcurrent::run(distances[i], &Distance::train, partitionedSrc[i]));
  69 + futures.waitForFinished();
69 } 70 }
70 71
71 float compare(const Template &a, const Template &b) const 72 float compare(const Template &a, const Template &b) const
@@ -99,24 +100,6 @@ private: @@ -99,24 +100,6 @@ private:
99 } 100 }
100 return 0; 101 return 0;
101 } 102 }
102 -  
103 - void store(QDataStream &stream) const  
104 - {  
105 - stream << distances.size();  
106 - foreach (Distance *distance, distances)  
107 - distance->store(stream);  
108 - }  
109 -  
110 - void load(QDataStream &stream)  
111 - {  
112 - int numDistances;  
113 - stream >> numDistances;  
114 - for (int i=0; i<descriptions.size(); i++) {  
115 - distances.append(make(descriptions[i]));  
116 - distances[i]->load(stream);  
117 - }  
118 -  
119 - }  
120 }; 103 };
121 104
122 BR_REGISTER(Distance, FuseDistance) 105 BR_REGISTER(Distance, FuseDistance)
openbr/plugins/distance/neglogplusone.cpp
@@ -24,12 +24,17 @@ namespace br @@ -24,12 +24,17 @@ namespace br
24 * \brief Returns -log(distance(a,b)+1) 24 * \brief Returns -log(distance(a,b)+1)
25 * \author Josh Klontz \cite jklontz 25 * \author Josh Klontz \cite jklontz
26 */ 26 */
27 -class NegativeLogPlusOneDistance : public UntrainableDistance 27 +class NegativeLogPlusOneDistance : public Distance
28 { 28 {
29 Q_OBJECT 29 Q_OBJECT
30 Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false) 30 Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false)
31 BR_PROPERTY(br::Distance*, distance, NULL) 31 BR_PROPERTY(br::Distance*, distance, NULL)
32 32
  33 + bool trainable()
  34 + {
  35 + return distance->trainable();
  36 + }
  37 +
33 void train(const TemplateList &src) 38 void train(const TemplateList &src)
34 { 39 {
35 distance->train(src); 40 distance->train(src);
openbr/plugins/distance/pipe.cpp
@@ -36,6 +36,14 @@ class PipeDistance : public Distance @@ -36,6 +36,14 @@ class PipeDistance : public Distance
36 Q_PROPERTY(QList<br::Distance*> distances READ get_distances WRITE set_distances RESET reset_distances) 36 Q_PROPERTY(QList<br::Distance*> distances READ get_distances WRITE set_distances RESET reset_distances)
37 BR_PROPERTY(QList<br::Distance*>, distances, QList<br::Distance*>()) 37 BR_PROPERTY(QList<br::Distance*>, distances, QList<br::Distance*>())
38 38
  39 + bool trainable()
  40 + {
  41 + for (int i=0; i<distances.size(); i++)
  42 + if (distances[i]->trainable())
  43 + return true;
  44 + return false;
  45 + }
  46 +
39 void train(const TemplateList &data) 47 void train(const TemplateList &data)
40 { 48 {
41 QFutureSynchronizer<void> futures; 49 QFutureSynchronizer<void> futures;
openbr/plugins/distance/sum.cpp
@@ -32,6 +32,14 @@ class SumDistance : public UntrainableDistance @@ -32,6 +32,14 @@ class SumDistance : public UntrainableDistance
32 Q_PROPERTY(QList<br::Distance*> distances READ get_distances WRITE set_distances RESET reset_distances) 32 Q_PROPERTY(QList<br::Distance*> distances READ get_distances WRITE set_distances RESET reset_distances)
33 BR_PROPERTY(QList<br::Distance*>, distances, QList<br::Distance*>()) 33 BR_PROPERTY(QList<br::Distance*>, distances, QList<br::Distance*>())
34 34
  35 + bool trainable()
  36 + {
  37 + for (int i=0; i<distances.size(); i++)
  38 + if (distances[i]->trainable())
  39 + return true;
  40 + return false;
  41 + }
  42 +
35 void train(const TemplateList &data) 43 void train(const TemplateList &data)
36 { 44 {
37 QFutureSynchronizer<void> futures; 45 QFutureSynchronizer<void> futures;