Commit f60a7d57ca655d087c3656b1e0f49b7d5b1054e3

Authored by Josh Klontz
1 parent e9000d40

refactored Turk interface

openbr/core/common.h
@@ -116,15 +116,24 @@ T Max(const QList<T> &vals) @@ -116,15 +116,24 @@ T Max(const QList<T> &vals)
116 } 116 }
117 117
118 /*! 118 /*!
  119 + * \brief Returns the sum of a vector of values.
  120 + */
  121 +template <template<class> class V, typename T>
  122 +double Sum(const V<T> &vals)
  123 +{
  124 + double sum = 0;
  125 + foreach (T val, vals) sum += val;
  126 + return sum;
  127 +}
  128 +
  129 +/*!
119 * \brief Returns the mean and standard deviation of a vector of values. 130 * \brief Returns the mean and standard deviation of a vector of values.
120 */ 131 */
121 template <template<class> class V, typename T> 132 template <template<class> class V, typename T>
122 double Mean(const V<T> &vals) 133 double Mean(const V<T> &vals)
123 { 134 {
124 if (vals.isEmpty()) return 0; 135 if (vals.isEmpty()) return 0;
125 - double sum = 0;  
126 - foreach (T val, vals) sum += val;  
127 - return sum / vals.size(); 136 + return Sum(vals) / vals.size();
128 } 137 }
129 138
130 /*! 139 /*!
openbr/plugins/output.cpp
@@ -365,15 +365,22 @@ BR_REGISTER(Output, EmptyOutput) @@ -365,15 +365,22 @@ BR_REGISTER(Output, EmptyOutput)
365 class evalOutput : public MatrixOutput 365 class evalOutput : public MatrixOutput
366 { 366 {
367 Q_OBJECT 367 Q_OBJECT
  368 + Q_PROPERTY(QString target READ get_target WRITE set_target RESET reset_target STORED false)
  369 + Q_PROPERTY(QString query READ get_query WRITE set_query RESET reset_query STORED false)
368 Q_PROPERTY(bool crossValidate READ get_crossValidate WRITE set_crossValidate RESET reset_crossValidate STORED false) 370 Q_PROPERTY(bool crossValidate READ get_crossValidate WRITE set_crossValidate RESET reset_crossValidate STORED false)
369 BR_PROPERTY(bool, crossValidate, true) 371 BR_PROPERTY(bool, crossValidate, true)
  372 + BR_PROPERTY(QString, target, QString())
  373 + BR_PROPERTY(QString, query, QString())
370 374
371 ~evalOutput() 375 ~evalOutput()
372 { 376 {
  377 + if (!target.isEmpty()) targetFiles = TemplateList::fromGallery(target).files();
  378 + if (!query.isEmpty()) queryFiles = TemplateList::fromGallery(query).files();
  379 +
373 if (data.data) { 380 if (data.data) {
374 const QString csv = QString(file.name).replace(".eval", ".csv"); 381 const QString csv = QString(file.name).replace(".eval", ".csv");
375 if ((Globals->crossValidate == 0) || (!crossValidate)) { 382 if ((Globals->crossValidate == 0) || (!crossValidate)) {
376 - Evaluate(data,targetFiles, queryFiles, csv); 383 + Evaluate(data, targetFiles, queryFiles, csv);
377 } else { 384 } else {
378 QFutureSynchronizer<float> futures; 385 QFutureSynchronizer<float> futures;
379 for (int i=0; i<Globals->crossValidate; i++) 386 for (int i=0; i<Globals->crossValidate; i++)
openbr/plugins/turk.cpp
1 #include "openbr_internal.h" 1 #include "openbr_internal.h"
  2 +#include "openbr/core/common.h"
2 #include "openbr/core/qtutils.h" 3 #include "openbr/core/qtutils.h"
3 4
4 namespace br 5 namespace br
@@ -13,41 +14,61 @@ class turkGallery : public Gallery @@ -13,41 +14,61 @@ class turkGallery : public Gallery
13 { 14 {
14 Q_OBJECT 15 Q_OBJECT
15 16
16 - TemplateList readBlock(bool *done) 17 + struct Attribute : public QStringList
17 { 18 {
18 - *done = true;  
19 - TemplateList templates;  
20 - if (!file.exists()) return templates; 19 + QString name;
  20 + Attribute(const QString &str = QString())
  21 + {
  22 + const int i = str.indexOf('[');
  23 + name = str.mid(0, i);
  24 + if (i != -1)
  25 + append(str.mid(i+1, str.length()-i-2).split(","));
  26 + }
21 27
22 - QStringList lines = QtUtils::readLines(file);  
23 - QRegExp regexp(",(?!(?:\\w+,?)+\\])"); 28 + Attribute normalized() const
  29 + {
  30 + bool ok;
  31 + QList<float> values;
  32 + foreach (const QString &value, *this) {
  33 + values.append(value.toFloat(&ok));
  34 + if (!ok)
  35 + qFatal("Can't normalize non-numeric vector!");
  36 + }
24 37
25 - QStringList headers; 38 + Attribute normal(name);
  39 + float sum = Common::Sum(values);
  40 + if (sum == 0) sum = 1;
  41 + for (int i=0; i<values.size(); i++)
  42 + normal.append(QString::number(values[i] / sum));
  43 + return normal;
  44 + }
  45 + };
26 46
27 - if (!lines.isEmpty()) headers = lines.takeFirst().split(regexp); 47 + TemplateList readBlock(bool *done)
  48 + {
  49 + *done = true;
  50 + QStringList lines = QtUtils::readLines(file);
  51 + QList<Attribute> headers;
  52 + if (!lines.isEmpty())
  53 + foreach (const QString &header, parse(lines.takeFirst()))
  54 + headers.append(header);
28 55
  56 + TemplateList templates;
29 foreach (const QString &line, lines) { 57 foreach (const QString &line, lines) {
30 - QStringList words = line.split(regexp);  
31 - if (words.size() != headers.size()) continue; 58 + QStringList words = parse(line);
  59 + if (words.size() != headers.size())
  60 + qFatal("turkGallery invalid column count");
  61 +
32 File f; 62 File f;
33 f.name = words[0]; 63 f.name = words[0];
34 f.set("Label", words[0].mid(0,5)); 64 f.set("Label", words[0].mid(0,5));
35 65
36 for (int i=1; i<words.size(); i++) { 66 for (int i=1; i<words.size(); i++) {
37 - QStringList categories = headers[i].split('[');  
38 - categories.last().chop(1); // Remove trailing bracket  
39 - QStringList types = categories.last().split(',');  
40 -  
41 - QStringList ratings = words[i].split(',');  
42 - ratings.first() = ratings.first().mid(1); // Remove first bracket  
43 - ratings.last().chop(1); // Remove trailing bracket  
44 -  
45 - if (types.size() != ratings.size()) continue;  
46 -  
47 - QMap<QString,QVariant> categoryMap;  
48 - for (int j=0; j<types.size(); j++) categoryMap.insert(types[j],ratings[j]);  
49 -  
50 - f.set(categories[0], categoryMap); 67 + Attribute ratings = Attribute(words[i]).normalized();
  68 + if (headers[i].size() != ratings.size())
  69 + qFatal("turkGallery invalid attribute count");
  70 + for (int j=0; j<ratings.size(); j++)
  71 + f.set(headers[i].name + "_" + headers[i][j], ratings[j]);
51 } 72 }
52 templates.append(f); 73 templates.append(f);
53 } 74 }
@@ -63,57 +84,6 @@ class turkGallery : public Gallery @@ -63,57 +84,6 @@ class turkGallery : public Gallery
63 84
64 BR_REGISTER(Gallery, turkGallery) 85 BR_REGISTER(Gallery, turkGallery)
65 86
66 -static Template unmap(const Template &t, const QString& variable, const float maxVotes, const float maxRange, const float minRange, const bool classify, const bool consensusOnly) {  
67 - // Create a new template matching the one containing the votes in the map structure  
68 - // but remove the map structure  
69 - Template expandedT = t;  
70 -  
71 - QMap<QString,QVariant> map = t.file.get<QMap<QString,QVariant> >(variable);  
72 - QMapIterator<QString, QVariant> i(map);  
73 - bool ok;  
74 -  
75 - while (i.hasNext()) {  
76 - i.next();  
77 - // Normalize to [minRange,maxRange]  
78 - float value = i.value().toFloat(&ok)*(maxRange-minRange)/maxVotes - minRange;  
79 - if (!ok) qFatal("Failed to expand Turk votes for %s", qPrintable(variable));  
80 - if (classify) (value > maxRange-((maxRange-minRange)/2)) ? value = maxRange : value = minRange;  
81 - else if (consensusOnly && (value != maxRange && value != minRange)) continue;  
82 - expandedT.file.set(i.key(),value);  
83 - }  
84 -  
85 - return expandedT;  
86 -}  
87 -  
88 -/*!  
89 - * \ingroup transforms  
90 - * \brief Converts Amazon MTurk labels to a non-map format for use in a transform  
91 - * \author Scott Klum \cite sklum  
92 - */  
93 -class TurkTransform : public UntrainableTransform  
94 -{  
95 - Q_OBJECT  
96 - Q_PROPERTY(QString HIT READ get_HIT WRITE set_HIT RESET reset_HIT STORED false)  
97 - Q_PROPERTY(float maxVotes READ get_maxVotes WRITE set_maxVotes RESET reset_maxVotes STORED false)  
98 - Q_PROPERTY(float maxRange READ get_maxRange WRITE set_maxRange RESET reset_maxRange STORED false)  
99 - Q_PROPERTY(float minRange READ get_minRange WRITE set_minRange RESET reset_minRange STORED false)  
100 - Q_PROPERTY(bool classify READ get_classify WRITE set_classify RESET reset_classify STORED false)  
101 - Q_PROPERTY(bool consensusOnly READ get_consensusOnly WRITE set_consensusOnly RESET reset_consensusOnly STORED false)  
102 - BR_PROPERTY(QString, HIT, QString())  
103 - BR_PROPERTY(float, maxVotes, 1)  
104 - BR_PROPERTY(float, maxRange, 1)  
105 - BR_PROPERTY(float, minRange, 0)  
106 - BR_PROPERTY(bool, classify, false)  
107 - BR_PROPERTY(bool, consensusOnly, false)  
108 -  
109 - void project(const Template &src, Template &dst) const  
110 - {  
111 - dst = unmap(src, HIT, maxVotes, maxRange, minRange, classify, consensusOnly);  
112 - }  
113 -};  
114 -  
115 -BR_REGISTER(Transform, TurkTransform)  
116 -  
117 /*! 87 /*!
118 * \ingroup transforms 88 * \ingroup transforms
119 * \brief Convenience class for training turk attribute regressors 89 * \brief Convenience class for training turk attribute regressors
@@ -124,23 +94,17 @@ class TurkClassifierTransform : public Transform @@ -124,23 +94,17 @@ class TurkClassifierTransform : public Transform
124 Q_OBJECT 94 Q_OBJECT
125 Q_PROPERTY(QString key READ get_key WRITE set_key RESET reset_key STORED false) 95 Q_PROPERTY(QString key READ get_key WRITE set_key RESET reset_key STORED false)
126 Q_PROPERTY(QStringList values READ get_values WRITE set_values RESET reset_values STORED false) 96 Q_PROPERTY(QStringList values READ get_values WRITE set_values RESET reset_values STORED false)
127 - Q_PROPERTY(float maxVotes READ get_maxVotes WRITE set_maxVotes RESET reset_maxVotes STORED false)  
128 BR_PROPERTY(QString, key, QString()) 97 BR_PROPERTY(QString, key, QString())
129 BR_PROPERTY(QStringList, values, QStringList()) 98 BR_PROPERTY(QStringList, values, QStringList())
130 - BR_PROPERTY(float, maxVotes, 1)  
131 99
132 Transform *child; 100 Transform *child;
133 101
134 void init() 102 void init()
135 { 103 {
136 - QString algorithm = QString("Turk(%1, %2)+").arg(key, QString::number(maxVotes));  
137 QStringList classifiers; 104 QStringList classifiers;
138 foreach (const QString &value, values) 105 foreach (const QString &value, values)
139 - classifiers.append(QString("SVM(RBF,EPS_SVR,returnDFVal=true,inputVariable=%1,outputVariable=predicted_%1)").arg(value));  
140 - algorithm += classifiers.join("/");  
141 - if (values.size() > 1)  
142 - algorithm += "+Cat";  
143 - child = Transform::make(algorithm); 106 + classifiers.append(QString("SVM(RBF,EPS_SVR,returnDFVal=true,inputVariable=%1,outputVariable=predicted_%1)").arg(key + "_" + value));
  107 + child = Transform::make(classifiers.join("/") + (classifiers.size() > 1 ? "+Cat" : ""));
144 } 108 }
145 109
146 void train(const QList<TemplateList> &data) 110 void train(const QList<TemplateList> &data)
@@ -167,44 +131,6 @@ class TurkClassifierTransform : public Transform @@ -167,44 +131,6 @@ class TurkClassifierTransform : public Transform
167 BR_REGISTER(Transform, TurkClassifierTransform) 131 BR_REGISTER(Transform, TurkClassifierTransform)
168 132
169 /*! 133 /*!
170 - * \ingroup transforms  
171 - * \brief Converts metadata into a map structure  
172 - * \author Scott Klum \cite sklum  
173 - */  
174 -class MapTransform : public UntrainableTransform  
175 -{  
176 - Q_OBJECT  
177 - Q_PROPERTY(QStringList inputVariables READ get_inputVariables WRITE set_inputVariables RESET reset_inputVariables STORED false)  
178 - Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false)  
179 - BR_PROPERTY(QStringList, inputVariables, QStringList())  
180 - BR_PROPERTY(QString, outputVariable, QString())  
181 -  
182 - void project(const Template &src, Template &dst) const  
183 - {  
184 - dst = map(src);  
185 - }  
186 -  
187 - Template map(const Template &t) const {  
188 - Template mappedT = t;  
189 - QMap<QString,QVariant> map;  
190 -  
191 - foreach(const QString &s, inputVariables) {  
192 - if (t.file.contains(s)) {  
193 - map.insert(s,t.file.value(s));  
194 - mappedT.file.remove(s);  
195 - }  
196 - }  
197 -  
198 - if (!map.isEmpty()) mappedT.file.set(outputVariable,map);  
199 -  
200 - return mappedT;  
201 - }  
202 -};  
203 -  
204 -BR_REGISTER(Transform, MapTransform)  
205 -  
206 -  
207 -/*!  
208 * \ingroup distances 134 * \ingroup distances
209 * \brief Unmaps Turk HITs to be compared against query mats 135 * \brief Unmaps Turk HITs to be compared against query mats
210 * \author Scott Klum \cite sklum 136 * \author Scott Klum \cite sklum
@@ -212,33 +138,17 @@ BR_REGISTER(Transform, MapTransform) @@ -212,33 +138,17 @@ BR_REGISTER(Transform, MapTransform)
212 class TurkDistance : public Distance 138 class TurkDistance : public Distance
213 { 139 {
214 Q_OBJECT 140 Q_OBJECT
215 - Q_PROPERTY(QString HIT READ get_HIT WRITE set_HIT RESET reset_HIT)  
216 - Q_PROPERTY(QStringList keys READ get_keys WRITE set_keys RESET reset_keys STORED false)  
217 - Q_PROPERTY(float maxVotes READ get_maxVotes WRITE set_maxVotes RESET reset_maxVotes STORED false)  
218 - Q_PROPERTY(float maxRange READ get_maxRange WRITE set_maxRange RESET reset_maxRange STORED false)  
219 - Q_PROPERTY(float minRange READ get_minRange WRITE set_minRange RESET reset_minRange STORED false)  
220 - Q_PROPERTY(bool classify READ get_classify WRITE set_classify RESET reset_classify STORED false)  
221 - Q_PROPERTY(bool consensusOnly READ get_consensusOnly WRITE set_consensusOnly RESET reset_consensusOnly STORED false)  
222 - BR_PROPERTY(QString, HIT, QString())  
223 - BR_PROPERTY(QStringList, keys, QStringList())  
224 - BR_PROPERTY(float, maxVotes, 1)  
225 - BR_PROPERTY(float, maxRange, 1)  
226 - BR_PROPERTY(float, minRange, 0)  
227 - BR_PROPERTY(bool, classify, false)  
228 - BR_PROPERTY(bool, consensusOnly, false) 141 + Q_PROPERTY(QString key READ get_key WRITE set_key RESET reset_key)
  142 + Q_PROPERTY(QStringList values READ get_values WRITE set_values RESET reset_values STORED false)
  143 + BR_PROPERTY(QString, key, QString())
  144 + BR_PROPERTY(QStringList, values, QStringList())
229 145
230 float compare(const Template &target, const Template &query) const 146 float compare(const Template &target, const Template &query) const
231 { 147 {
232 - Template t = unmap(target, HIT, maxVotes, maxRange, minRange, classify, consensusOnly);  
233 -  
234 - QList<float> targetValues;  
235 - foreach(const QString &s, keys) targetValues.append(t.file.get<float>(s));  
236 -  
237 - float stddev = .75;  
238 - 148 + const float stddev = .75;
239 float score = 0; 149 float score = 0;
240 - for (int i=0; i<targetValues.size(); i++) score += 1/(stddev*sqrt(2*CV_PI))*exp(-0.5*pow((query.m().at<float>(0,i)-targetValues[i])/stddev, 2));  
241 - 150 + for (int i=0; i<values.size(); i++)
  151 + score += 1 / (stddev*sqrt(2*CV_PI)) * exp(-0.5*pow((query.m().at<float>(0,i)-target.file.get<float>(key + "_" + values[i]))/stddev, 2));
242 return score; 152 return score;
243 } 153 }
244 }; 154 };