Commit 5b4706409f0abf4e2902d75397c6d0ba32221852

Authored by Scott Klum
1 parent d97f68bd

Updates to trainable transforms

openbr/plugins/distance/fuse.cpp
... ... @@ -18,6 +18,8 @@
18 18  
19 19 #include <openbr/plugins/openbr_internal.h>
20 20  
  21 +#include <QtConcurrent>
  22 +
21 23 namespace br
22 24 {
23 25  
... ... @@ -31,26 +33,25 @@ class FuseDistance : public Distance
31 33 {
32 34 Q_OBJECT
33 35 Q_ENUMS(Operation)
34   - Q_PROPERTY(QStringList descriptions READ get_descriptions WRITE set_descriptions RESET reset_descriptions STORED false)
  36 + Q_PROPERTY(QList<br::Distance*> distances READ get_distances WRITE set_distances RESET reset_distances)
35 37 Q_PROPERTY(Operation operation READ get_operation WRITE set_operation RESET reset_operation STORED false)
36 38 Q_PROPERTY(QList<float> weights READ get_weights WRITE set_weights RESET reset_weights STORED false)
37 39  
38   - QList<br::Distance*> distances;
39   -
40 40 public:
41 41 /*!< */
42 42 enum Operation {Mean, Sum, Max, Min};
43 43  
44 44 private:
45   - BR_PROPERTY(QStringList, descriptions, QStringList() << "L2")
  45 + BR_PROPERTY(QList<br::Distance*>, distances, QList<br::Distance*>())
46 46 BR_PROPERTY(Operation, operation, Mean)
47 47 BR_PROPERTY(QList<float>, weights, QList<float>())
48 48  
49   - void init()
  49 + bool trainable()
50 50 {
51   - for (int i=0; i<descriptions.size(); i++) {
52   - distances.append(make(descriptions[i]));
53   - }
  51 + for (int i=0; i<distances.size(); i++)
  52 + if (distances[i]->trainable())
  53 + return true;
  54 + return false;
54 55 }
55 56  
56 57 void train(const TemplateList &src)
... ... @@ -62,10 +63,10 @@ private:
62 63 QList<TemplateList> partitionedSrc = src.partition(split);
63 64  
64 65 // Train on each of the partitions
65   - for (int i=0; i<descriptions.size(); i++) {
66   - distances.append(make(descriptions[i]));
67   - distances[i]->train(partitionedSrc[i]);
68   - }
  66 + QFutureSynchronizer<void> futures;
  67 + for (int i=0; i<distances.size(); i++)
  68 + futures.addFuture(QtConcurrent::run(distances[i], &Distance::train, partitionedSrc[i]));
  69 + futures.waitForFinished();
69 70 }
70 71  
71 72 float compare(const Template &a, const Template &b) const
... ... @@ -99,24 +100,6 @@ private:
99 100 }
100 101 return 0;
101 102 }
102   -
103   - void store(QDataStream &stream) const
104   - {
105   - stream << distances.size();
106   - foreach (Distance *distance, distances)
107   - distance->store(stream);
108   - }
109   -
110   - void load(QDataStream &stream)
111   - {
112   - int numDistances;
113   - stream >> numDistances;
114   - for (int i=0; i<descriptions.size(); i++) {
115   - distances.append(make(descriptions[i]));
116   - distances[i]->load(stream);
117   - }
118   -
119   - }
120 103 };
121 104  
122 105 BR_REGISTER(Distance, FuseDistance)
... ...
openbr/plugins/distance/neglogplusone.cpp
... ... @@ -24,12 +24,17 @@ namespace br
24 24 * \brief Returns -log(distance(a,b)+1)
25 25 * \author Josh Klontz \cite jklontz
26 26 */
27   -class NegativeLogPlusOneDistance : public UntrainableDistance
  27 +class NegativeLogPlusOneDistance : public Distance
28 28 {
29 29 Q_OBJECT
30 30 Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false)
31 31 BR_PROPERTY(br::Distance*, distance, NULL)
32 32  
  33 + bool trainable()
  34 + {
  35 + return distance->trainable();
  36 + }
  37 +
33 38 void train(const TemplateList &src)
34 39 {
35 40 distance->train(src);
... ...
openbr/plugins/distance/pipe.cpp
... ... @@ -36,6 +36,14 @@ class PipeDistance : public Distance
36 36 Q_PROPERTY(QList<br::Distance*> distances READ get_distances WRITE set_distances RESET reset_distances)
37 37 BR_PROPERTY(QList<br::Distance*>, distances, QList<br::Distance*>())
38 38  
  39 + bool trainable()
  40 + {
  41 + for (int i=0; i<distances.size(); i++)
  42 + if (distances[i]->trainable())
  43 + return true;
  44 + return false;
  45 + }
  46 +
39 47 void train(const TemplateList &data)
40 48 {
41 49 QFutureSynchronizer<void> futures;
... ...
openbr/plugins/distance/sum.cpp
... ... @@ -32,6 +32,14 @@ class SumDistance : public UntrainableDistance
32 32 Q_PROPERTY(QList<br::Distance*> distances READ get_distances WRITE set_distances RESET reset_distances)
33 33 BR_PROPERTY(QList<br::Distance*>, distances, QList<br::Distance*>())
34 34  
  35 + bool trainable()
  36 + {
  37 + for (int i=0; i<distances.size(); i++)
  38 + if (distances[i]->trainable())
  39 + return true;
  40 + return false;
  41 + }
  42 +
35 43 void train(const TemplateList &data)
36 44 {
37 45 QFutureSynchronizer<void> futures;
... ...