Commit f60a7d57ca655d087c3656b1e0f49b7d5b1054e3
1 parent
e9000d40
refactored Turk interface
Showing
3 changed files
with
74 additions
and
148 deletions
openbr/core/common.h
| ... | ... | @@ -116,15 +116,24 @@ T Max(const QList<T> &vals) |
| 116 | 116 | } |
| 117 | 117 | |
| 118 | 118 | /*! |
| 119 | + * \brief Returns the sum of a vector of values. | |
| 120 | + */ | |
| 121 | +template <template<class> class V, typename T> | |
| 122 | +double Sum(const V<T> &vals) | |
| 123 | +{ | |
| 124 | + double sum = 0; | |
| 125 | + foreach (T val, vals) sum += val; | |
| 126 | + return sum; | |
| 127 | +} | |
| 128 | + | |
| 129 | +/*! | |
| 119 | 130 | * \brief Returns the mean of a vector of values, or 0 if the vector is empty. |
| 120 | 131 | */ |
| 121 | 132 | template <template<class> class V, typename T> |
| 122 | 133 | double Mean(const V<T> &vals) |
| 123 | 134 | { |
| 124 | 135 | if (vals.isEmpty()) return 0; |
| 125 | - double sum = 0; | |
| 126 | - foreach (T val, vals) sum += val; | |
| 127 | - return sum / vals.size(); | |
| 136 | + return Sum(vals) / vals.size(); | |
| 128 | 137 | } |
| 129 | 138 | |
| 130 | 139 | /*! | ... | ... |
openbr/plugins/output.cpp
| ... | ... | @@ -365,15 +365,22 @@ BR_REGISTER(Output, EmptyOutput) |
| 365 | 365 | class evalOutput : public MatrixOutput |
| 366 | 366 | { |
| 367 | 367 | Q_OBJECT |
| 368 | + Q_PROPERTY(QString target READ get_target WRITE set_target RESET reset_target STORED false) | |
| 369 | + Q_PROPERTY(QString query READ get_query WRITE set_query RESET reset_query STORED false) | |
| 368 | 370 | Q_PROPERTY(bool crossValidate READ get_crossValidate WRITE set_crossValidate RESET reset_crossValidate STORED false) |
| 369 | 371 | BR_PROPERTY(bool, crossValidate, true) |
| 372 | + BR_PROPERTY(QString, target, QString()) | |
| 373 | + BR_PROPERTY(QString, query, QString()) | |
| 370 | 374 | |
| 371 | 375 | ~evalOutput() |
| 372 | 376 | { |
| 377 | + if (!target.isEmpty()) targetFiles = TemplateList::fromGallery(target).files(); | |
| 378 | + if (!query.isEmpty()) queryFiles = TemplateList::fromGallery(query).files(); | |
| 379 | + | |
| 373 | 380 | if (data.data) { |
| 374 | 381 | const QString csv = QString(file.name).replace(".eval", ".csv"); |
| 375 | 382 | if ((Globals->crossValidate == 0) || (!crossValidate)) { |
| 376 | - Evaluate(data,targetFiles, queryFiles, csv); | |
| 383 | + Evaluate(data, targetFiles, queryFiles, csv); | |
| 377 | 384 | } else { |
| 378 | 385 | QFutureSynchronizer<float> futures; |
| 379 | 386 | for (int i=0; i<Globals->crossValidate; i++) | ... | ... |
openbr/plugins/turk.cpp
| 1 | 1 | #include "openbr_internal.h" |
| 2 | +#include "openbr/core/common.h" | |
| 2 | 3 | #include "openbr/core/qtutils.h" |
| 3 | 4 | |
| 4 | 5 | namespace br |
| ... | ... | @@ -13,41 +14,61 @@ class turkGallery : public Gallery |
| 13 | 14 | { |
| 14 | 15 | Q_OBJECT |
| 15 | 16 | |
| 16 | - TemplateList readBlock(bool *done) | |
| 17 | + struct Attribute : public QStringList | |
| 17 | 18 | { |
| 18 | - *done = true; | |
| 19 | - TemplateList templates; | |
| 20 | - if (!file.exists()) return templates; | |
| 19 | + QString name; | |
| 20 | + Attribute(const QString &str = QString()) | |
| 21 | + { | |
| 22 | + const int i = str.indexOf('['); | |
| 23 | + name = str.mid(0, i); | |
| 24 | + if (i != -1) | |
| 25 | + append(str.mid(i+1, str.length()-i-2).split(",")); | |
| 26 | + } | |
| 21 | 27 | |
| 22 | - QStringList lines = QtUtils::readLines(file); | |
| 23 | - QRegExp regexp(",(?!(?:\\w+,?)+\\])"); | |
| 28 | + Attribute normalized() const | |
| 29 | + { | |
| 30 | + bool ok; | |
| 31 | + QList<float> values; | |
| 32 | + foreach (const QString &value, *this) { | |
| 33 | + values.append(value.toFloat(&ok)); | |
| 34 | + if (!ok) | |
| 35 | + qFatal("Can't normalize non-numeric vector!"); | |
| 36 | + } | |
| 24 | 37 | |
| 25 | - QStringList headers; | |
| 38 | + Attribute normal(name); | |
| 39 | + float sum = Common::Sum(values); | |
| 40 | + if (sum == 0) sum = 1; | |
| 41 | + for (int i=0; i<values.size(); i++) | |
| 42 | + normal.append(QString::number(values[i] / sum)); | |
| 43 | + return normal; | |
| 44 | + } | |
| 45 | + }; | |
| 26 | 46 | |
| 27 | - if (!lines.isEmpty()) headers = lines.takeFirst().split(regexp); | |
| 47 | + TemplateList readBlock(bool *done) | |
| 48 | + { | |
| 49 | + *done = true; | |
| 50 | + QStringList lines = QtUtils::readLines(file); | |
| 51 | + QList<Attribute> headers; | |
| 52 | + if (!lines.isEmpty()) | |
| 53 | + foreach (const QString &header, parse(lines.takeFirst())) | |
| 54 | + headers.append(header); | |
| 28 | 55 | |
| 56 | + TemplateList templates; | |
| 29 | 57 | foreach (const QString &line, lines) { |
| 30 | - QStringList words = line.split(regexp); | |
| 31 | - if (words.size() != headers.size()) continue; | |
| 58 | + QStringList words = parse(line); | |
| 59 | + if (words.size() != headers.size()) | |
| 60 | + qFatal("turkGallery invalid column count"); | |
| 61 | + | |
| 32 | 62 | File f; |
| 33 | 63 | f.name = words[0]; |
| 34 | 64 | f.set("Label", words[0].mid(0,5)); |
| 35 | 65 | |
| 36 | 66 | for (int i=1; i<words.size(); i++) { |
| 37 | - QStringList categories = headers[i].split('['); | |
| 38 | - categories.last().chop(1); // Remove trailing bracket | |
| 39 | - QStringList types = categories.last().split(','); | |
| 40 | - | |
| 41 | - QStringList ratings = words[i].split(','); | |
| 42 | - ratings.first() = ratings.first().mid(1); // Remove first bracket | |
| 43 | - ratings.last().chop(1); // Remove trailing bracket | |
| 44 | - | |
| 45 | - if (types.size() != ratings.size()) continue; | |
| 46 | - | |
| 47 | - QMap<QString,QVariant> categoryMap; | |
| 48 | - for (int j=0; j<types.size(); j++) categoryMap.insert(types[j],ratings[j]); | |
| 49 | - | |
| 50 | - f.set(categories[0], categoryMap); | |
| 67 | + Attribute ratings = Attribute(words[i]).normalized(); | |
| 68 | + if (headers[i].size() != ratings.size()) | |
| 69 | + qFatal("turkGallery invalid attribute count"); | |
| 70 | + for (int j=0; j<ratings.size(); j++) | |
| 71 | + f.set(headers[i].name + "_" + headers[i][j], ratings[j]); | |
| 51 | 72 | } |
| 52 | 73 | templates.append(f); |
| 53 | 74 | } |
| ... | ... | @@ -63,57 +84,6 @@ class turkGallery : public Gallery |
| 63 | 84 | |
| 64 | 85 | BR_REGISTER(Gallery, turkGallery) |
| 65 | 86 | |
| 66 | -static Template unmap(const Template &t, const QString& variable, const float maxVotes, const float maxRange, const float minRange, const bool classify, const bool consensusOnly) { | |
| 67 | - // Create a new template matching the one containing the votes in the map structure | |
| 68 | - // but remove the map structure | |
| 69 | - Template expandedT = t; | |
| 70 | - | |
| 71 | - QMap<QString,QVariant> map = t.file.get<QMap<QString,QVariant> >(variable); | |
| 72 | - QMapIterator<QString, QVariant> i(map); | |
| 73 | - bool ok; | |
| 74 | - | |
| 75 | - while (i.hasNext()) { | |
| 76 | - i.next(); | |
| 77 | - // Normalize to [minRange,maxRange] | |
| 78 | - float value = i.value().toFloat(&ok)*(maxRange-minRange)/maxVotes - minRange; | |
| 79 | - if (!ok) qFatal("Failed to expand Turk votes for %s", qPrintable(variable)); | |
| 80 | - if (classify) (value > maxRange-((maxRange-minRange)/2)) ? value = maxRange : value = minRange; | |
| 81 | - else if (consensusOnly && (value != maxRange && value != minRange)) continue; | |
| 82 | - expandedT.file.set(i.key(),value); | |
| 83 | - } | |
| 84 | - | |
| 85 | - return expandedT; | |
| 86 | -} | |
| 87 | - | |
| 88 | -/*! | |
| 89 | - * \ingroup transforms | |
| 90 | - * \brief Converts Amazon MTurk labels to a non-map format for use in a transform | |
| 91 | - * \author Scott Klum \cite sklum | |
| 92 | - */ | |
| 93 | -class TurkTransform : public UntrainableTransform | |
| 94 | -{ | |
| 95 | - Q_OBJECT | |
| 96 | - Q_PROPERTY(QString HIT READ get_HIT WRITE set_HIT RESET reset_HIT STORED false) | |
| 97 | - Q_PROPERTY(float maxVotes READ get_maxVotes WRITE set_maxVotes RESET reset_maxVotes STORED false) | |
| 98 | - Q_PROPERTY(float maxRange READ get_maxRange WRITE set_maxRange RESET reset_maxRange STORED false) | |
| 99 | - Q_PROPERTY(float minRange READ get_minRange WRITE set_minRange RESET reset_minRange STORED false) | |
| 100 | - Q_PROPERTY(bool classify READ get_classify WRITE set_classify RESET reset_classify STORED false) | |
| 101 | - Q_PROPERTY(bool consensusOnly READ get_consensusOnly WRITE set_consensusOnly RESET reset_consensusOnly STORED false) | |
| 102 | - BR_PROPERTY(QString, HIT, QString()) | |
| 103 | - BR_PROPERTY(float, maxVotes, 1) | |
| 104 | - BR_PROPERTY(float, maxRange, 1) | |
| 105 | - BR_PROPERTY(float, minRange, 0) | |
| 106 | - BR_PROPERTY(bool, classify, false) | |
| 107 | - BR_PROPERTY(bool, consensusOnly, false) | |
| 108 | - | |
| 109 | - void project(const Template &src, Template &dst) const | |
| 110 | - { | |
| 111 | - dst = unmap(src, HIT, maxVotes, maxRange, minRange, classify, consensusOnly); | |
| 112 | - } | |
| 113 | -}; | |
| 114 | - | |
| 115 | -BR_REGISTER(Transform, TurkTransform) | |
| 116 | - | |
| 117 | 87 | /*! |
| 118 | 88 | * \ingroup transforms |
| 119 | 89 | * \brief Convenience class for training turk attribute regressors |
| ... | ... | @@ -124,23 +94,17 @@ class TurkClassifierTransform : public Transform |
| 124 | 94 | Q_OBJECT |
| 125 | 95 | Q_PROPERTY(QString key READ get_key WRITE set_key RESET reset_key STORED false) |
| 126 | 96 | Q_PROPERTY(QStringList values READ get_values WRITE set_values RESET reset_values STORED false) |
| 127 | - Q_PROPERTY(float maxVotes READ get_maxVotes WRITE set_maxVotes RESET reset_maxVotes STORED false) | |
| 128 | 97 | BR_PROPERTY(QString, key, QString()) |
| 129 | 98 | BR_PROPERTY(QStringList, values, QStringList()) |
| 130 | - BR_PROPERTY(float, maxVotes, 1) | |
| 131 | 99 | |
| 132 | 100 | Transform *child; |
| 133 | 101 | |
| 134 | 102 | void init() |
| 135 | 103 | { |
| 136 | - QString algorithm = QString("Turk(%1, %2)+").arg(key, QString::number(maxVotes)); | |
| 137 | 104 | QStringList classifiers; |
| 138 | 105 | foreach (const QString &value, values) |
| 139 | - classifiers.append(QString("SVM(RBF,EPS_SVR,returnDFVal=true,inputVariable=%1,outputVariable=predicted_%1)").arg(value)); | |
| 140 | - algorithm += classifiers.join("/"); | |
| 141 | - if (values.size() > 1) | |
| 142 | - algorithm += "+Cat"; | |
| 143 | - child = Transform::make(algorithm); | |
| 106 | + classifiers.append(QString("SVM(RBF,EPS_SVR,returnDFVal=true,inputVariable=%1,outputVariable=predicted_%1)").arg(key + "_" + value)); | |
| 107 | + child = Transform::make(classifiers.join("/") + (classifiers.size() > 1 ? "+Cat" : "")); | |
| 144 | 108 | } |
| 145 | 109 | |
| 146 | 110 | void train(const QList<TemplateList> &data) |
| ... | ... | @@ -167,44 +131,6 @@ class TurkClassifierTransform : public Transform |
| 167 | 131 | BR_REGISTER(Transform, TurkClassifierTransform) |
| 168 | 132 | |
| 169 | 133 | /*! |
| 170 | - * \ingroup transforms | |
| 171 | - * \brief Converts metadata into a map structure | |
| 172 | - * \author Scott Klum \cite sklum | |
| 173 | - */ | |
| 174 | -class MapTransform : public UntrainableTransform | |
| 175 | -{ | |
| 176 | - Q_OBJECT | |
| 177 | - Q_PROPERTY(QStringList inputVariables READ get_inputVariables WRITE set_inputVariables RESET reset_inputVariables STORED false) | |
| 178 | - Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) | |
| 179 | - BR_PROPERTY(QStringList, inputVariables, QStringList()) | |
| 180 | - BR_PROPERTY(QString, outputVariable, QString()) | |
| 181 | - | |
| 182 | - void project(const Template &src, Template &dst) const | |
| 183 | - { | |
| 184 | - dst = map(src); | |
| 185 | - } | |
| 186 | - | |
| 187 | - Template map(const Template &t) const { | |
| 188 | - Template mappedT = t; | |
| 189 | - QMap<QString,QVariant> map; | |
| 190 | - | |
| 191 | - foreach(const QString &s, inputVariables) { | |
| 192 | - if (t.file.contains(s)) { | |
| 193 | - map.insert(s,t.file.value(s)); | |
| 194 | - mappedT.file.remove(s); | |
| 195 | - } | |
| 196 | - } | |
| 197 | - | |
| 198 | - if (!map.isEmpty()) mappedT.file.set(outputVariable,map); | |
| 199 | - | |
| 200 | - return mappedT; | |
| 201 | - } | |
| 202 | -}; | |
| 203 | - | |
| 204 | -BR_REGISTER(Transform, MapTransform) | |
| 205 | - | |
| 206 | - | |
| 207 | -/*! | |
| 208 | 134 | * \ingroup distances |
| 209 | 135 | * \brief Scores Turk attribute values stored in the target's file metadata against the query mat |
| 210 | 136 | * \author Scott Klum \cite sklum |
| ... | ... | @@ -212,33 +138,17 @@ BR_REGISTER(Transform, MapTransform) |
| 212 | 138 | class TurkDistance : public Distance |
| 213 | 139 | { |
| 214 | 140 | Q_OBJECT |
| 215 | - Q_PROPERTY(QString HIT READ get_HIT WRITE set_HIT RESET reset_HIT) | |
| 216 | - Q_PROPERTY(QStringList keys READ get_keys WRITE set_keys RESET reset_keys STORED false) | |
| 217 | - Q_PROPERTY(float maxVotes READ get_maxVotes WRITE set_maxVotes RESET reset_maxVotes STORED false) | |
| 218 | - Q_PROPERTY(float maxRange READ get_maxRange WRITE set_maxRange RESET reset_maxRange STORED false) | |
| 219 | - Q_PROPERTY(float minRange READ get_minRange WRITE set_minRange RESET reset_minRange STORED false) | |
| 220 | - Q_PROPERTY(bool classify READ get_classify WRITE set_classify RESET reset_classify STORED false) | |
| 221 | - Q_PROPERTY(bool consensusOnly READ get_consensusOnly WRITE set_consensusOnly RESET reset_consensusOnly STORED false) | |
| 222 | - BR_PROPERTY(QString, HIT, QString()) | |
| 223 | - BR_PROPERTY(QStringList, keys, QStringList()) | |
| 224 | - BR_PROPERTY(float, maxVotes, 1) | |
| 225 | - BR_PROPERTY(float, maxRange, 1) | |
| 226 | - BR_PROPERTY(float, minRange, 0) | |
| 227 | - BR_PROPERTY(bool, classify, false) | |
| 228 | - BR_PROPERTY(bool, consensusOnly, false) | |
| 141 | + Q_PROPERTY(QString key READ get_key WRITE set_key RESET reset_key) | |
| 142 | + Q_PROPERTY(QStringList values READ get_values WRITE set_values RESET reset_values STORED false) | |
| 143 | + BR_PROPERTY(QString, key, QString()) | |
| 144 | + BR_PROPERTY(QStringList, values, QStringList()) | |
| 229 | 145 | |
| 230 | 146 | float compare(const Template &target, const Template &query) const |
| 231 | 147 | { |
| 232 | - Template t = unmap(target, HIT, maxVotes, maxRange, minRange, classify, consensusOnly); | |
| 233 | - | |
| 234 | - QList<float> targetValues; | |
| 235 | - foreach(const QString &s, keys) targetValues.append(t.file.get<float>(s)); | |
| 236 | - | |
| 237 | - float stddev = .75; | |
| 238 | - | |
| 148 | + const float stddev = .75; | |
| 239 | 149 | float score = 0; |
| 240 | - for (int i=0; i<targetValues.size(); i++) score += 1/(stddev*sqrt(2*CV_PI))*exp(-0.5*pow((query.m().at<float>(0,i)-targetValues[i])/stddev, 2)); | |
| 241 | - | |
| 150 | + for (int i=0; i<values.size(); i++) | |
| 151 | + score += 1 / (stddev*sqrt(2*CV_PI)) * exp(-0.5*pow((query.m().at<float>(0,i)-target.file.get<float>(key + "_" + values[i]))/stddev, 2)); | |
| 242 | 152 | return score; |
| 243 | 153 | } |
| 244 | 154 | }; | ... | ... |