Commit f60a7d57ca655d087c3656b1e0f49b7d5b1054e3

Authored by Josh Klontz
1 parent e9000d40

refactored Turk interface

openbr/core/common.h
... ... @@ -116,15 +116,24 @@ T Max(const QList<T> &vals)
116 116 }
117 117  
118 118 /*!
  119 + * \brief Returns the sum of a vector of values.
  120 + */
  121 +template <template<class> class V, typename T>
  122 +double Sum(const V<T> &vals)
  123 +{
  124 + double sum = 0;
  125 + foreach (T val, vals) sum += val;
  126 + return sum;
  127 +}
  128 +
  129 +/*!
119 130 * \brief Returns the mean and standard deviation of a vector of values.
120 131 */
121 132 template <template<class> class V, typename T>
122 133 double Mean(const V<T> &vals)
123 134 {
124 135 if (vals.isEmpty()) return 0;
125   - double sum = 0;
126   - foreach (T val, vals) sum += val;
127   - return sum / vals.size();
  136 + return Sum(vals) / vals.size();
128 137 }
129 138  
130 139 /*!
... ...
openbr/plugins/output.cpp
... ... @@ -365,15 +365,22 @@ BR_REGISTER(Output, EmptyOutput)
365 365 class evalOutput : public MatrixOutput
366 366 {
367 367 Q_OBJECT
  368 + Q_PROPERTY(QString target READ get_target WRITE set_target RESET reset_target STORED false)
  369 + Q_PROPERTY(QString query READ get_query WRITE set_query RESET reset_query STORED false)
368 370 Q_PROPERTY(bool crossValidate READ get_crossValidate WRITE set_crossValidate RESET reset_crossValidate STORED false)
369 371 BR_PROPERTY(bool, crossValidate, true)
  372 + BR_PROPERTY(QString, target, QString())
  373 + BR_PROPERTY(QString, query, QString())
370 374  
371 375 ~evalOutput()
372 376 {
  377 + if (!target.isEmpty()) targetFiles = TemplateList::fromGallery(target).files();
  378 + if (!query.isEmpty()) queryFiles = TemplateList::fromGallery(query).files();
  379 +
373 380 if (data.data) {
374 381 const QString csv = QString(file.name).replace(".eval", ".csv");
375 382 if ((Globals->crossValidate == 0) || (!crossValidate)) {
376   - Evaluate(data,targetFiles, queryFiles, csv);
  383 + Evaluate(data, targetFiles, queryFiles, csv);
377 384 } else {
378 385 QFutureSynchronizer<float> futures;
379 386 for (int i=0; i<Globals->crossValidate; i++)
... ...
openbr/plugins/turk.cpp
1 1 #include "openbr_internal.h"
  2 +#include "openbr/core/common.h"
2 3 #include "openbr/core/qtutils.h"
3 4  
4 5 namespace br
... ... @@ -13,41 +14,61 @@ class turkGallery : public Gallery
13 14 {
14 15 Q_OBJECT
15 16  
16   - TemplateList readBlock(bool *done)
  17 + struct Attribute : public QStringList
17 18 {
18   - *done = true;
19   - TemplateList templates;
20   - if (!file.exists()) return templates;
  19 + QString name;
  20 + Attribute(const QString &str = QString())
  21 + {
  22 + const int i = str.indexOf('[');
  23 + name = str.mid(0, i);
  24 + if (i != -1)
  25 + append(str.mid(i+1, str.length()-i-2).split(","));
  26 + }
21 27  
22   - QStringList lines = QtUtils::readLines(file);
23   - QRegExp regexp(",(?!(?:\\w+,?)+\\])");
  28 + Attribute normalized() const
  29 + {
  30 + bool ok;
  31 + QList<float> values;
  32 + foreach (const QString &value, *this) {
  33 + values.append(value.toFloat(&ok));
  34 + if (!ok)
  35 + qFatal("Can't normalize non-numeric vector!");
  36 + }
24 37  
25   - QStringList headers;
  38 + Attribute normal(name);
  39 + float sum = Common::Sum(values);
  40 + if (sum == 0) sum = 1;
  41 + for (int i=0; i<values.size(); i++)
  42 + normal.append(QString::number(values[i] / sum));
  43 + return normal;
  44 + }
  45 + };
26 46  
27   - if (!lines.isEmpty()) headers = lines.takeFirst().split(regexp);
  47 + TemplateList readBlock(bool *done)
  48 + {
  49 + *done = true;
  50 + QStringList lines = QtUtils::readLines(file);
  51 + QList<Attribute> headers;
  52 + if (!lines.isEmpty())
  53 + foreach (const QString &header, parse(lines.takeFirst()))
  54 + headers.append(header);
28 55  
  56 + TemplateList templates;
29 57 foreach (const QString &line, lines) {
30   - QStringList words = line.split(regexp);
31   - if (words.size() != headers.size()) continue;
  58 + QStringList words = parse(line);
  59 + if (words.size() != headers.size())
  60 + qFatal("turkGallery invalid column count");
  61 +
32 62 File f;
33 63 f.name = words[0];
34 64 f.set("Label", words[0].mid(0,5));
35 65  
36 66 for (int i=1; i<words.size(); i++) {
37   - QStringList categories = headers[i].split('[');
38   - categories.last().chop(1); // Remove trailing bracket
39   - QStringList types = categories.last().split(',');
40   -
41   - QStringList ratings = words[i].split(',');
42   - ratings.first() = ratings.first().mid(1); // Remove first bracket
43   - ratings.last().chop(1); // Remove trailing bracket
44   -
45   - if (types.size() != ratings.size()) continue;
46   -
47   - QMap<QString,QVariant> categoryMap;
48   - for (int j=0; j<types.size(); j++) categoryMap.insert(types[j],ratings[j]);
49   -
50   - f.set(categories[0], categoryMap);
  67 + Attribute ratings = Attribute(words[i]).normalized();
  68 + if (headers[i].size() != ratings.size())
  69 + qFatal("turkGallery invalid attribute count");
  70 + for (int j=0; j<ratings.size(); j++)
  71 + f.set(headers[i].name + "_" + headers[i][j], ratings[j]);
51 72 }
52 73 templates.append(f);
53 74 }
... ... @@ -63,57 +84,6 @@ class turkGallery : public Gallery
63 84  
64 85 BR_REGISTER(Gallery, turkGallery)
65 86  
66   -static Template unmap(const Template &t, const QString& variable, const float maxVotes, const float maxRange, const float minRange, const bool classify, const bool consensusOnly) {
67   - // Create a new template matching the one containing the votes in the map structure
68   - // but remove the map structure
69   - Template expandedT = t;
70   -
71   - QMap<QString,QVariant> map = t.file.get<QMap<QString,QVariant> >(variable);
72   - QMapIterator<QString, QVariant> i(map);
73   - bool ok;
74   -
75   - while (i.hasNext()) {
76   - i.next();
77   - // Normalize to [minRange,maxRange]
78   - float value = i.value().toFloat(&ok)*(maxRange-minRange)/maxVotes - minRange;
79   - if (!ok) qFatal("Failed to expand Turk votes for %s", qPrintable(variable));
80   - if (classify) (value > maxRange-((maxRange-minRange)/2)) ? value = maxRange : value = minRange;
81   - else if (consensusOnly && (value != maxRange && value != minRange)) continue;
82   - expandedT.file.set(i.key(),value);
83   - }
84   -
85   - return expandedT;
86   -}
87   -
88   -/*!
89   - * \ingroup transforms
90   - * \brief Converts Amazon MTurk labels to a non-map format for use in a transform
91   - * \author Scott Klum \cite sklum
92   - */
93   -class TurkTransform : public UntrainableTransform
94   -{
95   - Q_OBJECT
96   - Q_PROPERTY(QString HIT READ get_HIT WRITE set_HIT RESET reset_HIT STORED false)
97   - Q_PROPERTY(float maxVotes READ get_maxVotes WRITE set_maxVotes RESET reset_maxVotes STORED false)
98   - Q_PROPERTY(float maxRange READ get_maxRange WRITE set_maxRange RESET reset_maxRange STORED false)
99   - Q_PROPERTY(float minRange READ get_minRange WRITE set_minRange RESET reset_minRange STORED false)
100   - Q_PROPERTY(bool classify READ get_classify WRITE set_classify RESET reset_classify STORED false)
101   - Q_PROPERTY(bool consensusOnly READ get_consensusOnly WRITE set_consensusOnly RESET reset_consensusOnly STORED false)
102   - BR_PROPERTY(QString, HIT, QString())
103   - BR_PROPERTY(float, maxVotes, 1)
104   - BR_PROPERTY(float, maxRange, 1)
105   - BR_PROPERTY(float, minRange, 0)
106   - BR_PROPERTY(bool, classify, false)
107   - BR_PROPERTY(bool, consensusOnly, false)
108   -
109   - void project(const Template &src, Template &dst) const
110   - {
111   - dst = unmap(src, HIT, maxVotes, maxRange, minRange, classify, consensusOnly);
112   - }
113   -};
114   -
115   -BR_REGISTER(Transform, TurkTransform)
116   -
117 87 /*!
118 88 * \ingroup transforms
119 89 * \brief Convenience class for training turk attribute regressors
... ... @@ -124,23 +94,17 @@ class TurkClassifierTransform : public Transform
124 94 Q_OBJECT
125 95 Q_PROPERTY(QString key READ get_key WRITE set_key RESET reset_key STORED false)
126 96 Q_PROPERTY(QStringList values READ get_values WRITE set_values RESET reset_values STORED false)
127   - Q_PROPERTY(float maxVotes READ get_maxVotes WRITE set_maxVotes RESET reset_maxVotes STORED false)
128 97 BR_PROPERTY(QString, key, QString())
129 98 BR_PROPERTY(QStringList, values, QStringList())
130   - BR_PROPERTY(float, maxVotes, 1)
131 99  
132 100 Transform *child;
133 101  
134 102 void init()
135 103 {
136   - QString algorithm = QString("Turk(%1, %2)+").arg(key, QString::number(maxVotes));
137 104 QStringList classifiers;
138 105 foreach (const QString &value, values)
139   - classifiers.append(QString("SVM(RBF,EPS_SVR,returnDFVal=true,inputVariable=%1,outputVariable=predicted_%1)").arg(value));
140   - algorithm += classifiers.join("/");
141   - if (values.size() > 1)
142   - algorithm += "+Cat";
143   - child = Transform::make(algorithm);
  106 + classifiers.append(QString("SVM(RBF,EPS_SVR,returnDFVal=true,inputVariable=%1,outputVariable=predicted_%1)").arg(key + "_" + value));
  107 + child = Transform::make(classifiers.join("/") + (classifiers.size() > 1 ? "+Cat" : ""));
144 108 }
145 109  
146 110 void train(const QList<TemplateList> &data)
... ... @@ -167,44 +131,6 @@ class TurkClassifierTransform : public Transform
167 131 BR_REGISTER(Transform, TurkClassifierTransform)
168 132  
169 133 /*!
170   - * \ingroup transforms
171   - * \brief Converts metadata into a map structure
172   - * \author Scott Klum \cite sklum
173   - */
174   -class MapTransform : public UntrainableTransform
175   -{
176   - Q_OBJECT
177   - Q_PROPERTY(QStringList inputVariables READ get_inputVariables WRITE set_inputVariables RESET reset_inputVariables STORED false)
178   - Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)
179   - BR_PROPERTY(QStringList, inputVariables, QStringList())
180   - BR_PROPERTY(QString, outputVariable, QString())
181   -
182   - void project(const Template &src, Template &dst) const
183   - {
184   - dst = map(src);
185   - }
186   -
187   - Template map(const Template &t) const {
188   - Template mappedT = t;
189   - QMap<QString,QVariant> map;
190   -
191   - foreach(const QString &s, inputVariables) {
192   - if (t.file.contains(s)) {
193   - map.insert(s,t.file.value(s));
194   - mappedT.file.remove(s);
195   - }
196   - }
197   -
198   - if (!map.isEmpty()) mappedT.file.set(outputVariable,map);
199   -
200   - return mappedT;
201   - }
202   -};
203   -
204   -BR_REGISTER(Transform, MapTransform)
205   -
206   -
207   -/*!
208 134 * \ingroup distances
209 135 * \brief Unmaps Turk HITs to be compared against query mats
210 136 * \author Scott Klum \cite sklum
... ... @@ -212,33 +138,17 @@ BR_REGISTER(Transform, MapTransform)
212 138 class TurkDistance : public Distance
213 139 {
214 140 Q_OBJECT
215   - Q_PROPERTY(QString HIT READ get_HIT WRITE set_HIT RESET reset_HIT)
216   - Q_PROPERTY(QStringList keys READ get_keys WRITE set_keys RESET reset_keys STORED false)
217   - Q_PROPERTY(float maxVotes READ get_maxVotes WRITE set_maxVotes RESET reset_maxVotes STORED false)
218   - Q_PROPERTY(float maxRange READ get_maxRange WRITE set_maxRange RESET reset_maxRange STORED false)
219   - Q_PROPERTY(float minRange READ get_minRange WRITE set_minRange RESET reset_minRange STORED false)
220   - Q_PROPERTY(bool classify READ get_classify WRITE set_classify RESET reset_classify STORED false)
221   - Q_PROPERTY(bool consensusOnly READ get_consensusOnly WRITE set_consensusOnly RESET reset_consensusOnly STORED false)
222   - BR_PROPERTY(QString, HIT, QString())
223   - BR_PROPERTY(QStringList, keys, QStringList())
224   - BR_PROPERTY(float, maxVotes, 1)
225   - BR_PROPERTY(float, maxRange, 1)
226   - BR_PROPERTY(float, minRange, 0)
227   - BR_PROPERTY(bool, classify, false)
228   - BR_PROPERTY(bool, consensusOnly, false)
  141 + Q_PROPERTY(QString key READ get_key WRITE set_key RESET reset_key)
  142 + Q_PROPERTY(QStringList values READ get_values WRITE set_values RESET reset_values STORED false)
  143 + BR_PROPERTY(QString, key, QString())
  144 + BR_PROPERTY(QStringList, values, QStringList())
229 145  
230 146 float compare(const Template &target, const Template &query) const
231 147 {
232   - Template t = unmap(target, HIT, maxVotes, maxRange, minRange, classify, consensusOnly);
233   -
234   - QList<float> targetValues;
235   - foreach(const QString &s, keys) targetValues.append(t.file.get<float>(s));
236   -
237   - float stddev = .75;
238   -
  148 + const float stddev = .75;
239 149 float score = 0;
240   - for (int i=0; i<targetValues.size(); i++) score += 1/(stddev*sqrt(2*CV_PI))*exp(-0.5*pow((query.m().at<float>(0,i)-targetValues[i])/stddev, 2));
241   -
  150 + for (int i=0; i<values.size(); i++)
  151 + score += 1 / (stddev*sqrt(2*CV_PI)) * exp(-0.5*pow((query.m().at<float>(0,i)-target.file.get<float>(key + "_" + values[i]))/stddev, 2));
242 152 return score;
243 153 }
244 154 };
... ...