Commit f60a7d57ca655d087c3656b1e0f49b7d5b1054e3
1 parent
e9000d40
refactored Turk interface
Showing
3 changed files
with
74 additions
and
148 deletions
openbr/core/common.h
| @@ -116,15 +116,24 @@ T Max(const QList<T> &vals) | @@ -116,15 +116,24 @@ T Max(const QList<T> &vals) | ||
| 116 | } | 116 | } |
| 117 | 117 | ||
| 118 | /*! | 118 | /*! |
| 119 | + * \brief Returns the sum of a vector of values. | ||
| 120 | + */ | ||
| 121 | +template <template<class> class V, typename T> | ||
| 122 | +double Sum(const V<T> &vals) | ||
| 123 | +{ | ||
| 124 | + double sum = 0; | ||
| 125 | + foreach (T val, vals) sum += val; | ||
| 126 | + return sum; | ||
| 127 | +} | ||
| 128 | + | ||
| 129 | +/*! | ||
| 119 | * \brief Returns the mean of a vector of values (0 if the vector is empty). | 130 | * \brief Returns the mean of a vector of values (0 if the vector is empty). |
| 120 | */ | 131 | */ |
| 121 | template <template<class> class V, typename T> | 132 | template <template<class> class V, typename T> |
| 122 | double Mean(const V<T> &vals) | 133 | double Mean(const V<T> &vals) |
| 123 | { | 134 | { |
| 124 | if (vals.isEmpty()) return 0; | 135 | if (vals.isEmpty()) return 0; |
| 125 | - double sum = 0; | ||
| 126 | - foreach (T val, vals) sum += val; | ||
| 127 | - return sum / vals.size(); | 136 | + return Sum(vals) / vals.size(); |
| 128 | } | 137 | } |
| 129 | 138 | ||
| 130 | /*! | 139 | /*! |
openbr/plugins/output.cpp
| @@ -365,15 +365,22 @@ BR_REGISTER(Output, EmptyOutput) | @@ -365,15 +365,22 @@ BR_REGISTER(Output, EmptyOutput) | ||
| 365 | class evalOutput : public MatrixOutput | 365 | class evalOutput : public MatrixOutput |
| 366 | { | 366 | { |
| 367 | Q_OBJECT | 367 | Q_OBJECT |
| 368 | + Q_PROPERTY(QString target READ get_target WRITE set_target RESET reset_target STORED false) | ||
| 369 | + Q_PROPERTY(QString query READ get_query WRITE set_query RESET reset_query STORED false) | ||
| 368 | Q_PROPERTY(bool crossValidate READ get_crossValidate WRITE set_crossValidate RESET reset_crossValidate STORED false) | 370 | Q_PROPERTY(bool crossValidate READ get_crossValidate WRITE set_crossValidate RESET reset_crossValidate STORED false) |
| 369 | BR_PROPERTY(bool, crossValidate, true) | 371 | BR_PROPERTY(bool, crossValidate, true) |
| 372 | + BR_PROPERTY(QString, target, QString()) | ||
| 373 | + BR_PROPERTY(QString, query, QString()) | ||
| 370 | 374 | ||
| 371 | ~evalOutput() | 375 | ~evalOutput() |
| 372 | { | 376 | { |
| 377 | + if (!target.isEmpty()) targetFiles = TemplateList::fromGallery(target).files(); | ||
| 378 | + if (!query.isEmpty()) queryFiles = TemplateList::fromGallery(query).files(); | ||
| 379 | + | ||
| 373 | if (data.data) { | 380 | if (data.data) { |
| 374 | const QString csv = QString(file.name).replace(".eval", ".csv"); | 381 | const QString csv = QString(file.name).replace(".eval", ".csv"); |
| 375 | if ((Globals->crossValidate == 0) || (!crossValidate)) { | 382 | if ((Globals->crossValidate == 0) || (!crossValidate)) { |
| 376 | - Evaluate(data,targetFiles, queryFiles, csv); | 383 | + Evaluate(data, targetFiles, queryFiles, csv); |
| 377 | } else { | 384 | } else { |
| 378 | QFutureSynchronizer<float> futures; | 385 | QFutureSynchronizer<float> futures; |
| 379 | for (int i=0; i<Globals->crossValidate; i++) | 386 | for (int i=0; i<Globals->crossValidate; i++) |
openbr/plugins/turk.cpp
| 1 | #include "openbr_internal.h" | 1 | #include "openbr_internal.h" |
| 2 | +#include "openbr/core/common.h" | ||
| 2 | #include "openbr/core/qtutils.h" | 3 | #include "openbr/core/qtutils.h" |
| 3 | 4 | ||
| 4 | namespace br | 5 | namespace br |
| @@ -13,41 +14,61 @@ class turkGallery : public Gallery | @@ -13,41 +14,61 @@ class turkGallery : public Gallery | ||
| 13 | { | 14 | { |
| 14 | Q_OBJECT | 15 | Q_OBJECT |
| 15 | 16 | ||
| 16 | - TemplateList readBlock(bool *done) | 17 | + struct Attribute : public QStringList |
| 17 | { | 18 | { |
| 18 | - *done = true; | ||
| 19 | - TemplateList templates; | ||
| 20 | - if (!file.exists()) return templates; | 19 | + QString name; |
| 20 | + Attribute(const QString &str = QString()) | ||
| 21 | + { | ||
| 22 | + const int i = str.indexOf('['); | ||
| 23 | + name = str.mid(0, i); | ||
| 24 | + if (i != -1) | ||
| 25 | + append(str.mid(i+1, str.length()-i-2).split(",")); | ||
| 26 | + } | ||
| 21 | 27 | ||
| 22 | - QStringList lines = QtUtils::readLines(file); | ||
| 23 | - QRegExp regexp(",(?!(?:\\w+,?)+\\])"); | 28 | + Attribute normalized() const |
| 29 | + { | ||
| 30 | + bool ok; | ||
| 31 | + QList<float> values; | ||
| 32 | + foreach (const QString &value, *this) { | ||
| 33 | + values.append(value.toFloat(&ok)); | ||
| 34 | + if (!ok) | ||
| 35 | + qFatal("Can't normalize non-numeric vector!"); | ||
| 36 | + } | ||
| 24 | 37 | ||
| 25 | - QStringList headers; | 38 | + Attribute normal(name); |
| 39 | + float sum = Common::Sum(values); | ||
| 40 | + if (sum == 0) sum = 1; | ||
| 41 | + for (int i=0; i<values.size(); i++) | ||
| 42 | + normal.append(QString::number(values[i] / sum)); | ||
| 43 | + return normal; | ||
| 44 | + } | ||
| 45 | + }; | ||
| 26 | 46 | ||
| 27 | - if (!lines.isEmpty()) headers = lines.takeFirst().split(regexp); | 47 | + TemplateList readBlock(bool *done) |
| 48 | + { | ||
| 49 | + *done = true; | ||
| 50 | + QStringList lines = QtUtils::readLines(file); | ||
| 51 | + QList<Attribute> headers; | ||
| 52 | + if (!lines.isEmpty()) | ||
| 53 | + foreach (const QString &header, parse(lines.takeFirst())) | ||
| 54 | + headers.append(header); | ||
| 28 | 55 | ||
| 56 | + TemplateList templates; | ||
| 29 | foreach (const QString &line, lines) { | 57 | foreach (const QString &line, lines) { |
| 30 | - QStringList words = line.split(regexp); | ||
| 31 | - if (words.size() != headers.size()) continue; | 58 | + QStringList words = parse(line); |
| 59 | + if (words.size() != headers.size()) | ||
| 60 | + qFatal("turkGallery invalid column count"); | ||
| 61 | + | ||
| 32 | File f; | 62 | File f; |
| 33 | f.name = words[0]; | 63 | f.name = words[0]; |
| 34 | f.set("Label", words[0].mid(0,5)); | 64 | f.set("Label", words[0].mid(0,5)); |
| 35 | 65 | ||
| 36 | for (int i=1; i<words.size(); i++) { | 66 | for (int i=1; i<words.size(); i++) { |
| 37 | - QStringList categories = headers[i].split('['); | ||
| 38 | - categories.last().chop(1); // Remove trailing bracket | ||
| 39 | - QStringList types = categories.last().split(','); | ||
| 40 | - | ||
| 41 | - QStringList ratings = words[i].split(','); | ||
| 42 | - ratings.first() = ratings.first().mid(1); // Remove first bracket | ||
| 43 | - ratings.last().chop(1); // Remove trailing bracket | ||
| 44 | - | ||
| 45 | - if (types.size() != ratings.size()) continue; | ||
| 46 | - | ||
| 47 | - QMap<QString,QVariant> categoryMap; | ||
| 48 | - for (int j=0; j<types.size(); j++) categoryMap.insert(types[j],ratings[j]); | ||
| 49 | - | ||
| 50 | - f.set(categories[0], categoryMap); | 67 | + Attribute ratings = Attribute(words[i]).normalized(); |
| 68 | + if (headers[i].size() != ratings.size()) | ||
| 69 | + qFatal("turkGallery invalid attribute count"); | ||
| 70 | + for (int j=0; j<ratings.size(); j++) | ||
| 71 | + f.set(headers[i].name + "_" + headers[i][j], ratings[j]); | ||
| 51 | } | 72 | } |
| 52 | templates.append(f); | 73 | templates.append(f); |
| 53 | } | 74 | } |
| @@ -63,57 +84,6 @@ class turkGallery : public Gallery | @@ -63,57 +84,6 @@ class turkGallery : public Gallery | ||
| 63 | 84 | ||
| 64 | BR_REGISTER(Gallery, turkGallery) | 85 | BR_REGISTER(Gallery, turkGallery) |
| 65 | 86 | ||
| 66 | -static Template unmap(const Template &t, const QString& variable, const float maxVotes, const float maxRange, const float minRange, const bool classify, const bool consensusOnly) { | ||
| 67 | - // Create a new template matching the one containing the votes in the map structure | ||
| 68 | - // but remove the map structure | ||
| 69 | - Template expandedT = t; | ||
| 70 | - | ||
| 71 | - QMap<QString,QVariant> map = t.file.get<QMap<QString,QVariant> >(variable); | ||
| 72 | - QMapIterator<QString, QVariant> i(map); | ||
| 73 | - bool ok; | ||
| 74 | - | ||
| 75 | - while (i.hasNext()) { | ||
| 76 | - i.next(); | ||
| 77 | - // Normalize to [minRange,maxRange] | ||
| 78 | - float value = i.value().toFloat(&ok)*(maxRange-minRange)/maxVotes - minRange; | ||
| 79 | - if (!ok) qFatal("Failed to expand Turk votes for %s", qPrintable(variable)); | ||
| 80 | - if (classify) (value > maxRange-((maxRange-minRange)/2)) ? value = maxRange : value = minRange; | ||
| 81 | - else if (consensusOnly && (value != maxRange && value != minRange)) continue; | ||
| 82 | - expandedT.file.set(i.key(),value); | ||
| 83 | - } | ||
| 84 | - | ||
| 85 | - return expandedT; | ||
| 86 | -} | ||
| 87 | - | ||
| 88 | -/*! | ||
| 89 | - * \ingroup transforms | ||
| 90 | - * \brief Converts Amazon MTurk labels to a non-map format for use in a transform | ||
| 91 | - * \author Scott Klum \cite sklum | ||
| 92 | - */ | ||
| 93 | -class TurkTransform : public UntrainableTransform | ||
| 94 | -{ | ||
| 95 | - Q_OBJECT | ||
| 96 | - Q_PROPERTY(QString HIT READ get_HIT WRITE set_HIT RESET reset_HIT STORED false) | ||
| 97 | - Q_PROPERTY(float maxVotes READ get_maxVotes WRITE set_maxVotes RESET reset_maxVotes STORED false) | ||
| 98 | - Q_PROPERTY(float maxRange READ get_maxRange WRITE set_maxRange RESET reset_maxRange STORED false) | ||
| 99 | - Q_PROPERTY(float minRange READ get_minRange WRITE set_minRange RESET reset_minRange STORED false) | ||
| 100 | - Q_PROPERTY(bool classify READ get_classify WRITE set_classify RESET reset_classify STORED false) | ||
| 101 | - Q_PROPERTY(bool consensusOnly READ get_consensusOnly WRITE set_consensusOnly RESET reset_consensusOnly STORED false) | ||
| 102 | - BR_PROPERTY(QString, HIT, QString()) | ||
| 103 | - BR_PROPERTY(float, maxVotes, 1) | ||
| 104 | - BR_PROPERTY(float, maxRange, 1) | ||
| 105 | - BR_PROPERTY(float, minRange, 0) | ||
| 106 | - BR_PROPERTY(bool, classify, false) | ||
| 107 | - BR_PROPERTY(bool, consensusOnly, false) | ||
| 108 | - | ||
| 109 | - void project(const Template &src, Template &dst) const | ||
| 110 | - { | ||
| 111 | - dst = unmap(src, HIT, maxVotes, maxRange, minRange, classify, consensusOnly); | ||
| 112 | - } | ||
| 113 | -}; | ||
| 114 | - | ||
| 115 | -BR_REGISTER(Transform, TurkTransform) | ||
| 116 | - | ||
| 117 | /*! | 87 | /*! |
| 118 | * \ingroup transforms | 88 | * \ingroup transforms |
| 119 | * \brief Convenience class for training turk attribute regressors | 89 | * \brief Convenience class for training turk attribute regressors |
| @@ -124,23 +94,17 @@ class TurkClassifierTransform : public Transform | @@ -124,23 +94,17 @@ class TurkClassifierTransform : public Transform | ||
| 124 | Q_OBJECT | 94 | Q_OBJECT |
| 125 | Q_PROPERTY(QString key READ get_key WRITE set_key RESET reset_key STORED false) | 95 | Q_PROPERTY(QString key READ get_key WRITE set_key RESET reset_key STORED false) |
| 126 | Q_PROPERTY(QStringList values READ get_values WRITE set_values RESET reset_values STORED false) | 96 | Q_PROPERTY(QStringList values READ get_values WRITE set_values RESET reset_values STORED false) |
| 127 | - Q_PROPERTY(float maxVotes READ get_maxVotes WRITE set_maxVotes RESET reset_maxVotes STORED false) | ||
| 128 | BR_PROPERTY(QString, key, QString()) | 97 | BR_PROPERTY(QString, key, QString()) |
| 129 | BR_PROPERTY(QStringList, values, QStringList()) | 98 | BR_PROPERTY(QStringList, values, QStringList()) |
| 130 | - BR_PROPERTY(float, maxVotes, 1) | ||
| 131 | 99 | ||
| 132 | Transform *child; | 100 | Transform *child; |
| 133 | 101 | ||
| 134 | void init() | 102 | void init() |
| 135 | { | 103 | { |
| 136 | - QString algorithm = QString("Turk(%1, %2)+").arg(key, QString::number(maxVotes)); | ||
| 137 | QStringList classifiers; | 104 | QStringList classifiers; |
| 138 | foreach (const QString &value, values) | 105 | foreach (const QString &value, values) |
| 139 | - classifiers.append(QString("SVM(RBF,EPS_SVR,returnDFVal=true,inputVariable=%1,outputVariable=predicted_%1)").arg(value)); | ||
| 140 | - algorithm += classifiers.join("/"); | ||
| 141 | - if (values.size() > 1) | ||
| 142 | - algorithm += "+Cat"; | ||
| 143 | - child = Transform::make(algorithm); | 106 | + classifiers.append(QString("SVM(RBF,EPS_SVR,returnDFVal=true,inputVariable=%1,outputVariable=predicted_%1)").arg(key + "_" + value)); |
| 107 | + child = Transform::make(classifiers.join("/") + (classifiers.size() > 1 ? "+Cat" : "")); | ||
| 144 | } | 108 | } |
| 145 | 109 | ||
| 146 | void train(const QList<TemplateList> &data) | 110 | void train(const QList<TemplateList> &data) |
| @@ -167,44 +131,6 @@ class TurkClassifierTransform : public Transform | @@ -167,44 +131,6 @@ class TurkClassifierTransform : public Transform | ||
| 167 | BR_REGISTER(Transform, TurkClassifierTransform) | 131 | BR_REGISTER(Transform, TurkClassifierTransform) |
| 168 | 132 | ||
| 169 | /*! | 133 | /*! |
| 170 | - * \ingroup transforms | ||
| 171 | - * \brief Converts metadata into a map structure | ||
| 172 | - * \author Scott Klum \cite sklum | ||
| 173 | - */ | ||
| 174 | -class MapTransform : public UntrainableTransform | ||
| 175 | -{ | ||
| 176 | - Q_OBJECT | ||
| 177 | - Q_PROPERTY(QStringList inputVariables READ get_inputVariables WRITE set_inputVariables RESET reset_inputVariables STORED false) | ||
| 178 | - Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) | ||
| 179 | - BR_PROPERTY(QStringList, inputVariables, QStringList()) | ||
| 180 | - BR_PROPERTY(QString, outputVariable, QString()) | ||
| 181 | - | ||
| 182 | - void project(const Template &src, Template &dst) const | ||
| 183 | - { | ||
| 184 | - dst = map(src); | ||
| 185 | - } | ||
| 186 | - | ||
| 187 | - Template map(const Template &t) const { | ||
| 188 | - Template mappedT = t; | ||
| 189 | - QMap<QString,QVariant> map; | ||
| 190 | - | ||
| 191 | - foreach(const QString &s, inputVariables) { | ||
| 192 | - if (t.file.contains(s)) { | ||
| 193 | - map.insert(s,t.file.value(s)); | ||
| 194 | - mappedT.file.remove(s); | ||
| 195 | - } | ||
| 196 | - } | ||
| 197 | - | ||
| 198 | - if (!map.isEmpty()) mappedT.file.set(outputVariable,map); | ||
| 199 | - | ||
| 200 | - return mappedT; | ||
| 201 | - } | ||
| 202 | -}; | ||
| 203 | - | ||
| 204 | -BR_REGISTER(Transform, MapTransform) | ||
| 205 | - | ||
| 206 | - | ||
| 207 | -/*! | ||
| 208 | * \ingroup distances | 134 | * \ingroup distances |
| 209 | * \brief Scores a target's Turk attribute values against the values in a query matrix | 135 | * \brief Scores a target's Turk attribute values against the values in a query matrix |
| 210 | * \author Scott Klum \cite sklum | 136 | * \author Scott Klum \cite sklum |
| @@ -212,33 +138,17 @@ BR_REGISTER(Transform, MapTransform) | @@ -212,33 +138,17 @@ BR_REGISTER(Transform, MapTransform) | ||
| 212 | class TurkDistance : public Distance | 138 | class TurkDistance : public Distance |
| 213 | { | 139 | { |
| 214 | Q_OBJECT | 140 | Q_OBJECT |
| 215 | - Q_PROPERTY(QString HIT READ get_HIT WRITE set_HIT RESET reset_HIT) | ||
| 216 | - Q_PROPERTY(QStringList keys READ get_keys WRITE set_keys RESET reset_keys STORED false) | ||
| 217 | - Q_PROPERTY(float maxVotes READ get_maxVotes WRITE set_maxVotes RESET reset_maxVotes STORED false) | ||
| 218 | - Q_PROPERTY(float maxRange READ get_maxRange WRITE set_maxRange RESET reset_maxRange STORED false) | ||
| 219 | - Q_PROPERTY(float minRange READ get_minRange WRITE set_minRange RESET reset_minRange STORED false) | ||
| 220 | - Q_PROPERTY(bool classify READ get_classify WRITE set_classify RESET reset_classify STORED false) | ||
| 221 | - Q_PROPERTY(bool consensusOnly READ get_consensusOnly WRITE set_consensusOnly RESET reset_consensusOnly STORED false) | ||
| 222 | - BR_PROPERTY(QString, HIT, QString()) | ||
| 223 | - BR_PROPERTY(QStringList, keys, QStringList()) | ||
| 224 | - BR_PROPERTY(float, maxVotes, 1) | ||
| 225 | - BR_PROPERTY(float, maxRange, 1) | ||
| 226 | - BR_PROPERTY(float, minRange, 0) | ||
| 227 | - BR_PROPERTY(bool, classify, false) | ||
| 228 | - BR_PROPERTY(bool, consensusOnly, false) | 141 | + Q_PROPERTY(QString key READ get_key WRITE set_key RESET reset_key) |
| 142 | + Q_PROPERTY(QStringList values READ get_values WRITE set_values RESET reset_values STORED false) | ||
| 143 | + BR_PROPERTY(QString, key, QString()) | ||
| 144 | + BR_PROPERTY(QStringList, values, QStringList()) | ||
| 229 | 145 | ||
| 230 | float compare(const Template &target, const Template &query) const | 146 | float compare(const Template &target, const Template &query) const |
| 231 | { | 147 | { |
| 232 | - Template t = unmap(target, HIT, maxVotes, maxRange, minRange, classify, consensusOnly); | ||
| 233 | - | ||
| 234 | - QList<float> targetValues; | ||
| 235 | - foreach(const QString &s, keys) targetValues.append(t.file.get<float>(s)); | ||
| 236 | - | ||
| 237 | - float stddev = .75; | ||
| 238 | - | 148 | + const float stddev = .75; |
| 239 | float score = 0; | 149 | float score = 0; |
| 240 | - for (int i=0; i<targetValues.size(); i++) score += 1/(stddev*sqrt(2*CV_PI))*exp(-0.5*pow((query.m().at<float>(0,i)-targetValues[i])/stddev, 2)); | ||
| 241 | - | 150 | + for (int i=0; i<values.size(); i++) |
| 151 | + score += 1 / (stddev*sqrt(2*CV_PI)) * exp(-0.5*pow((query.m().at<float>(0,i)-target.file.get<float>(key + "_" + values[i]))/stddev, 2)); | ||
| 242 | return score; | 152 | return score; |
| 243 | } | 153 | } |
| 244 | }; | 154 | }; |