Commit 22de0ca461d6cb90cfda0781d5d797e87504d486
1 parent
aebd450c
Preliminar changes towards refering to class labels etc. more specifically
Change default label name from Subjet to Label (since label is a more general term). Use different default variable names for classification (label), regression (regressor/regressand), and clustering (ClusterID) Update some (far from all) transforms to accept arguments specifying their input/output variables. Update eval classification to optionally take target variable names as arguments
Showing
22 changed files
with
104 additions
and
55 deletions
app/br/br.cpp
| @@ -130,7 +130,7 @@ public: | @@ -130,7 +130,7 @@ public: | ||
| 130 | br_convert(parv[0], parv[1], parv[2]); | 130 | br_convert(parv[0], parv[1], parv[2]); |
| 131 | } else if (!strcmp(fun, "evalClassification")) { | 131 | } else if (!strcmp(fun, "evalClassification")) { |
| 132 | check(parc == 2, "Incorrect parameter count for 'evalClassification'."); | 132 | check(parc == 2, "Incorrect parameter count for 'evalClassification'."); |
| 133 | - br_eval_classification(parv[0], parv[1]); | 133 | + br_eval_classification(parv[0], parv[1], parc >= 3 ? parv[2] : NULL, parc >= 4 ? parv[3] : NULL); |
| 134 | } else if (!strcmp(fun, "evalRegression")) { | 134 | } else if (!strcmp(fun, "evalRegression")) { |
| 135 | check(parc == 2, "Incorrect parameter count for 'evalRegression'."); | 135 | check(parc == 2, "Incorrect parameter count for 'evalRegression'."); |
| 136 | br_eval_regression(parv[0], parv[1]); | 136 | br_eval_regression(parv[0], parv[1]); |
app/examples/age_estimation.cpp
| @@ -29,7 +29,8 @@ | @@ -29,7 +29,8 @@ | ||
| 29 | 29 | ||
| 30 | static void printTemplate(const br::Template &t) | 30 | static void printTemplate(const br::Template &t) |
| 31 | { | 31 | { |
| 32 | - printf("%s age: %d\n", qPrintable(t.file.fileName()), int(t.file.get<float>("Subject"))); | 32 | + // may use age directly -cao |
| 33 | + printf("%s age: %d\n", qPrintable(t.file.fileName()), int(t.file.get<float>("Regressand"))); | ||
| 33 | } | 34 | } |
| 34 | 35 | ||
| 35 | int main(int argc, char *argv[]) | 36 | int main(int argc, char *argv[]) |
app/examples/gender_estimation.cpp
| @@ -29,7 +29,8 @@ | @@ -29,7 +29,8 @@ | ||
| 29 | 29 | ||
| 30 | static void printTemplate(const br::Template &t) | 30 | static void printTemplate(const br::Template &t) |
| 31 | { | 31 | { |
| 32 | - printf("%s gender: %s\n", qPrintable(t.file.fileName()), qPrintable(t.file.get<QString>("Subject"))); | 32 | + // may use gender directly -cao |
| 33 | + printf("%s gender: %s\n", qPrintable(t.file.fileName()), qPrintable(t.file.get<QString>("Label"))); | ||
| 33 | } | 34 | } |
| 34 | 35 | ||
| 35 | int main(int argc, char *argv[]) | 36 | int main(int argc, char *argv[]) |
openbr/core/bee.cpp
| @@ -93,10 +93,10 @@ void BEE::writeSigset(const QString &sigset, const br::FileList &files, bool ign | @@ -93,10 +93,10 @@ void BEE::writeSigset(const QString &sigset, const br::FileList &files, bool ign | ||
| 93 | QStringList metadata; | 93 | QStringList metadata; |
| 94 | if (!ignoreMetadata) | 94 | if (!ignoreMetadata) |
| 95 | foreach (const QString &key, file.localKeys()) { | 95 | foreach (const QString &key, file.localKeys()) { |
| 96 | - if ((key == "Index") || (key == "Subject")) continue; | 96 | + if ((key == "Index") || (key == "Label")) continue; |
| 97 | metadata.append(key+"=\""+QtUtils::toString(file.value(key))+"\""); | 97 | metadata.append(key+"=\""+QtUtils::toString(file.value(key))+"\""); |
| 98 | } | 98 | } |
| 99 | - lines.append("\t<biometric-signature name=\"" + file.get<QString>("Subject") +"\">"); | 99 | + lines.append("\t<biometric-signature name=\"" + file.get<QString>("Label") +"\">"); |
| 100 | lines.append("\t\t<presentation file-name=\"" + file.name + "\" " + metadata.join(" ") + "/>"); | 100 | lines.append("\t\t<presentation file-name=\"" + file.name + "\" " + metadata.join(" ") + "/>"); |
| 101 | lines.append("\t</biometric-signature>"); | 101 | lines.append("\t</biometric-signature>"); |
| 102 | } | 102 | } |
| @@ -260,10 +260,10 @@ void BEE::makeMask(const QString &targetInput, const QString &queryInput, const | @@ -260,10 +260,10 @@ void BEE::makeMask(const QString &targetInput, const QString &queryInput, const | ||
| 260 | 260 | ||
| 261 | cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries, int partition) | 261 | cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries, int partition) |
| 262 | { | 262 | { |
| 263 | - // Would like to use indexProperty for this, but didn't make a version of that for Filelist yet | ||
| 264 | - // -cao | ||
| 265 | - QList<QString> targetLabels = targets.get<QString>("Subject", "-1"); | ||
| 266 | - QList<QString> queryLabels = queries.get<QString>("Subject", "-1"); | 263 | + // Direct use of "Label" isn't general, also would prefer to use indexProperty, rather than |
| 264 | + // doing string comparisons (but that isn't implemented yet for FileList) -cao | ||
| 265 | + QList<QString> targetLabels = targets.get<QString>("Label", "-1"); | ||
| 266 | + QList<QString> queryLabels = queries.get<QString>("Label", "-1"); | ||
| 267 | QList<int> targetPartitions = targets.crossValidationPartitions(); | 267 | QList<int> targetPartitions = targets.crossValidationPartitions(); |
| 268 | QList<int> queryPartitions = queries.crossValidationPartitions(); | 268 | QList<int> queryPartitions = queries.crossValidationPartitions(); |
| 269 | 269 |
openbr/core/classify.cpp
| @@ -31,7 +31,7 @@ struct Counter | @@ -31,7 +31,7 @@ struct Counter | ||
| 31 | } | 31 | } |
| 32 | }; | 32 | }; |
| 33 | 33 | ||
| 34 | -void br::EvalClassification(const QString &predictedInput, const QString &truthInput) | 34 | +void br::EvalClassification(const QString &predictedInput, const QString &truthInput, const QString & predictedProperty, const QString & truthProperty) |
| 35 | { | 35 | { |
| 36 | qDebug("Evaluating classification of %s against %s", qPrintable(predictedInput), qPrintable(truthInput)); | 36 | qDebug("Evaluating classification of %s against %s", qPrintable(predictedInput), qPrintable(truthInput)); |
| 37 | 37 | ||
| @@ -44,9 +44,8 @@ void br::EvalClassification(const QString &predictedInput, const QString &truthI | @@ -44,9 +44,8 @@ void br::EvalClassification(const QString &predictedInput, const QString &truthI | ||
| 44 | if (predicted[i].file.name != truth[i].file.name) | 44 | if (predicted[i].file.name != truth[i].file.name) |
| 45 | qFatal("Input order mismatch."); | 45 | qFatal("Input order mismatch."); |
| 46 | 46 | ||
| 47 | - // Typically these lists will be of length one, but this generalization allows measuring multi-class labeling accuracy. | ||
| 48 | - QString predictedSubject = predicted[i].file.get<QString>("Subject"); | ||
| 49 | - QString trueSubject = truth[i].file.get<QString>("Subject"); | 47 | + QString predictedSubject = predicted[i].file.get<QString>(predictedProperty); |
| 48 | + QString trueSubject = truth[i].file.get<QString>(truthProperty); | ||
| 50 | 49 | ||
| 51 | QStringList predictedSubjects(predictedSubject); | 50 | QStringList predictedSubjects(predictedSubject); |
| 52 | QStringList trueSubjects(trueSubject); | 51 | QStringList trueSubjects(trueSubject); |
| @@ -99,13 +98,19 @@ void br::EvalRegression(const QString &predictedInput, const QString &truthInput | @@ -99,13 +98,19 @@ void br::EvalRegression(const QString &predictedInput, const QString &truthInput | ||
| 99 | if (predicted.size() != truth.size()) qFatal("Input size mismatch."); | 98 | if (predicted.size() != truth.size()) qFatal("Input size mismatch."); |
| 100 | 99 | ||
| 101 | float rmsError = 0; | 100 | float rmsError = 0; |
| 101 | + float maeError = 0; | ||
| 102 | + // Direct use of Regressor/Regressand is not general -cao | ||
| 102 | QStringList truthValues, predictedValues; | 103 | QStringList truthValues, predictedValues; |
| 103 | for (int i=0; i<predicted.size(); i++) { | 104 | for (int i=0; i<predicted.size(); i++) { |
| 104 | if (predicted[i].file.name != truth[i].file.name) | 105 | if (predicted[i].file.name != truth[i].file.name) |
| 105 | qFatal("Input order mismatch."); | 106 | qFatal("Input order mismatch."); |
| 106 | - rmsError += pow(predicted[i].file.get<float>("Subject")-truth[i].file.get<float>("Subject"), 2.f); | ||
| 107 | - truthValues.append(QString::number(truth[i].file.get<float>("Subject"))); | ||
| 108 | - predictedValues.append(QString::number(predicted[i].file.get<float>("Subject"))); | 107 | + |
| 108 | + float difference = predicted[i].file.get<float>("Regressand") - truth[i].file.get<float>("Regressor"); | ||
| 109 | + | ||
| 110 | + rmsError += pow(difference, 2.f); | ||
| 111 | + maeError += fabsf(difference); | ||
| 112 | + truthValues.append(QString::number(truth[i].file.get<float>("Regressor"))); | ||
| 113 | + predictedValues.append(QString::number(predicted[i].file.get<float>("Regressand"))); | ||
| 109 | } | 114 | } |
| 110 | 115 | ||
| 111 | QStringList rSource; | 116 | QStringList rSource; |
| @@ -125,4 +130,6 @@ void br::EvalRegression(const QString &predictedInput, const QString &truthInput | @@ -125,4 +130,6 @@ void br::EvalRegression(const QString &predictedInput, const QString &truthInput | ||
| 125 | if (success) QtUtils::showFile("EvalRegression.pdf"); | 130 | if (success) QtUtils::showFile("EvalRegression.pdf"); |
| 126 | 131 | ||
| 127 | qDebug("RMS Error = %f", sqrt(rmsError/predicted.size())); | 132 | qDebug("RMS Error = %f", sqrt(rmsError/predicted.size())); |
| 133 | + qDebug("MAE = %f", maeError/predicted.size()); | ||
| 134 | + | ||
| 128 | } | 135 | } |
openbr/core/classify.h
| @@ -22,7 +22,7 @@ | @@ -22,7 +22,7 @@ | ||
| 22 | 22 | ||
| 23 | namespace br | 23 | namespace br |
| 24 | { | 24 | { |
| 25 | - void EvalClassification(const QString &predictedInput, const QString &truthInput); | 25 | + void EvalClassification(const QString &predictedInput, const QString &truthInput, const QString & predictedProperty="Label", const QString & truthProperty="Label"); |
| 26 | void EvalRegression(const QString &predictedInput, const QString &truthInput); | 26 | void EvalRegression(const QString &predictedInput, const QString &truthInput); |
| 27 | } | 27 | } |
| 28 | 28 |
openbr/core/cluster.cpp
| @@ -279,8 +279,8 @@ void br::EvalClustering(const QString &csv, const QString &input) | @@ -279,8 +279,8 @@ void br::EvalClustering(const QString &csv, const QString &input) | ||
| 279 | qDebug("Evaluating %s against %s", qPrintable(csv), qPrintable(input)); | 279 | qDebug("Evaluating %s against %s", qPrintable(csv), qPrintable(input)); |
| 280 | 280 | ||
| 281 | // We assume clustering algorithms store assigned cluster labels as integers (since the clusters are | 281 | // We assume clustering algorithms store assigned cluster labels as integers (since the clusters are |
| 282 | - // not named). | ||
| 283 | - QList<int> labels = TemplateList::fromGallery(input).files().get<int>("Subject"); | 282 | + // not named). Direct use of ClusterID is not general -cao |
| 283 | + QList<int> labels = TemplateList::fromGallery(input).files().get<int>("ClusterID"); | ||
| 284 | 284 | ||
| 285 | QHash<int, int> labelToIndex; | 285 | QHash<int, int> labelToIndex; |
| 286 | int nClusters = 0; | 286 | int nClusters = 0; |
openbr/frvt2012.cpp
| @@ -132,7 +132,8 @@ int32_t SdkEstimator::estimate_age(const ONEFACE &input_face, int32_t &age) | @@ -132,7 +132,8 @@ int32_t SdkEstimator::estimate_age(const ONEFACE &input_face, int32_t &age) | ||
| 132 | TemplateList templates; | 132 | TemplateList templates; |
| 133 | templates.append(templateFromONEFACE(input_face)); | 133 | templates.append(templateFromONEFACE(input_face)); |
| 134 | templates >> *frvt2012_age_transform.data(); | 134 | templates >> *frvt2012_age_transform.data(); |
| 135 | - age = templates.first().file.get<float>("Subject"); | 135 | + // should maybe use "Age" directly -cao |
| 136 | + age = templates.first().file.get<float>("Regressand"); | ||
| 136 | return templates.first().file.failed() ? 4 : 0; | 137 | return templates.first().file.failed() ? 4 : 0; |
| 137 | } | 138 | } |
| 138 | 139 | ||
| @@ -141,6 +142,7 @@ int32_t SdkEstimator::estimate_gender(const ONEFACE &input_face, int8_t &gender, | @@ -141,6 +142,7 @@ int32_t SdkEstimator::estimate_gender(const ONEFACE &input_face, int8_t &gender, | ||
| 141 | TemplateList templates; | 142 | TemplateList templates; |
| 142 | templates.append(templateFromONEFACE(input_face)); | 143 | templates.append(templateFromONEFACE(input_face)); |
| 143 | templates >> *frvt2012_gender_transform.data(); | 144 | templates >> *frvt2012_gender_transform.data(); |
| 144 | - mf = gender = templates.first().file.get<QString>("Subject") == "Male" ? 0 : 1; | 145 | + // Should maybe use "Gender" directly -cao |
| 146 | + mf = gender = templates.first().file.get<QString>("Label") == "Male" ? 0 : 1; | ||
| 145 | return templates.first().file.failed() ? 4 : 0; | 147 | return templates.first().file.failed() ? 4 : 0; |
| 146 | } | 148 | } |
openbr/gui/classifier.cpp
| @@ -43,14 +43,16 @@ void Classifier::_classify(File file) | @@ -43,14 +43,16 @@ void Classifier::_classify(File file) | ||
| 43 | continue; | 43 | continue; |
| 44 | 44 | ||
| 45 | if (algorithm == "GenderClassification") { | 45 | if (algorithm == "GenderClassification") { |
| 46 | + // Should maybe use gender directly -cao | ||
| 46 | key = "Gender"; | 47 | key = "Gender"; |
| 47 | - value = (f.get<QString>("Subject")); | 48 | + value = f.get<QString>("Label"); |
| 48 | } else if (algorithm == "AgeRegression") { | 49 | } else if (algorithm == "AgeRegression") { |
| 49 | key = "Age"; | 50 | key = "Age"; |
| 50 | - value = QString::number(int(f.get<float>("Subject")+0.5)) + " Years"; | 51 | + // similarly, age -cao |
| 52 | + value = QString::number(int(f.get<float>("Regressand")+0.5)) + " Years"; | ||
| 51 | } else { | 53 | } else { |
| 52 | key = algorithm; | 54 | key = algorithm; |
| 53 | - value = f.get<QString>("Subject"); | 55 | + value = f.get<QString>("Label"); |
| 54 | } | 56 | } |
| 55 | break; | 57 | break; |
| 56 | } | 58 | } |
openbr/openbr.cpp
| @@ -77,9 +77,14 @@ float br_eval(const char *simmat, const char *mask, const char *csv) | @@ -77,9 +77,14 @@ float br_eval(const char *simmat, const char *mask, const char *csv) | ||
| 77 | return Evaluate(simmat, mask, csv); | 77 | return Evaluate(simmat, mask, csv); |
| 78 | } | 78 | } |
| 79 | 79 | ||
| 80 | -void br_eval_classification(const char *predicted_input, const char *truth_input) | ||
| 81 | -{ | ||
| 82 | - EvalClassification(predicted_input, truth_input); | 80 | +void br_eval_classification(const char *predicted_input, const char *truth_input, const char *predicted_property, const char * truth_property) |
| 81 | +{ | ||
| 82 | + if (predicted_property && truth_property) | ||
| 83 | + EvalClassification(predicted_input, truth_input, predicted_property, truth_property); | ||
| 84 | + else if (predicted_property) | ||
| 85 | + EvalClassification(predicted_input, truth_input, predicted_property); | ||
| 86 | + else | ||
| 87 | + EvalClassification(predicted_input, truth_input); | ||
| 83 | } | 88 | } |
| 84 | 89 | ||
| 85 | void br_eval_clustering(const char *csv, const char *input) | 90 | void br_eval_clustering(const char *csv, const char *input) |
openbr/openbr.h
| @@ -164,7 +164,7 @@ BR_EXPORT float br_eval(const char *simmat, const char *mask, const char *csv = | @@ -164,7 +164,7 @@ BR_EXPORT float br_eval(const char *simmat, const char *mask, const char *csv = | ||
| 164 | * \param truth_input The ground truth br::Input. | 164 | * \param truth_input The ground truth br::Input. |
| 165 | * \see br_enroll | 165 | * \see br_enroll |
| 166 | */ | 166 | */ |
| 167 | -BR_EXPORT void br_eval_classification(const char *predicted_input, const char *truth_input); | 167 | +BR_EXPORT void br_eval_classification(const char *predicted_input, const char *truth_input, const char * predicted_property, const char * truth_property); |
| 168 | 168 | ||
| 169 | /*! | 169 | /*! |
| 170 | * \brief Evaluates and prints clustering accuracy to the terminal. | 170 | * \brief Evaluates and prints clustering accuracy to the terminal. |
openbr/openbr_plugin.cpp
| @@ -412,6 +412,7 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) | @@ -412,6 +412,7 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) | ||
| 412 | // of target images to every partition | 412 | // of target images to every partition |
| 413 | newTemplates[i].file.set("Partition", -1); | 413 | newTemplates[i].file.set("Partition", -1); |
| 414 | } else { | 414 | } else { |
| 415 | + // Direct use of "Subject" is not general -cao | ||
| 415 | const QByteArray md5 = QCryptographicHash::hash(newTemplates[i].file.get<QString>("Subject").toLatin1(), QCryptographicHash::Md5); | 416 | const QByteArray md5 = QCryptographicHash::hash(newTemplates[i].file.get<QString>("Subject").toLatin1(), QCryptographicHash::Md5); |
| 416 | // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow | 417 | // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow |
| 417 | newTemplates[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); | 418 | newTemplates[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); |
openbr/plugins/cluster.cpp
| @@ -89,10 +89,12 @@ class KNNTransform : public Transform | @@ -89,10 +89,12 @@ class KNNTransform : public Transform | ||
| 89 | Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance STORED false) | 89 | Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance STORED false) |
| 90 | Q_PROPERTY(bool weighted READ get_weighted WRITE set_weighted RESET reset_weighted STORED false) | 90 | Q_PROPERTY(bool weighted READ get_weighted WRITE set_weighted RESET reset_weighted STORED false) |
| 91 | Q_PROPERTY(int numSubjects READ get_numSubjects WRITE set_numSubjects RESET reset_numSubjects STORED false) | 91 | Q_PROPERTY(int numSubjects READ get_numSubjects WRITE set_numSubjects RESET reset_numSubjects STORED false) |
| 92 | + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | ||
| 92 | BR_PROPERTY(int, k, 1) | 93 | BR_PROPERTY(int, k, 1) |
| 93 | BR_PROPERTY(br::Distance*, distance, NULL) | 94 | BR_PROPERTY(br::Distance*, distance, NULL) |
| 94 | BR_PROPERTY(bool, weighted, false) | 95 | BR_PROPERTY(bool, weighted, false) |
| 95 | BR_PROPERTY(int, numSubjects, 1) | 96 | BR_PROPERTY(int, numSubjects, 1) |
| 97 | + BR_PROPERTY(QString, inputVariable, "Label") | ||
| 96 | 98 | ||
| 97 | TemplateList gallery; | 99 | TemplateList gallery; |
| 98 | 100 | ||
| @@ -111,13 +113,13 @@ class KNNTransform : public Transform | @@ -111,13 +113,13 @@ class KNNTransform : public Transform | ||
| 111 | QHash<QString, float> votes; | 113 | QHash<QString, float> votes; |
| 112 | const int max = (k < 1) ? sortedScores.size() : std::min(k, sortedScores.size()); | 114 | const int max = (k < 1) ? sortedScores.size() : std::min(k, sortedScores.size()); |
| 113 | for (int j=0; j<max; j++) | 115 | for (int j=0; j<max; j++) |
| 114 | - votes[gallery[sortedScores[j].second].file.get<QString>("Subject")] += (weighted ? sortedScores[j].first : 1); | 116 | + votes[gallery[sortedScores[j].second].file.get<QString>(inputVariable)] += (weighted ? sortedScores[j].first : 1); |
| 115 | subjects.append(votes.keys()[votes.values().indexOf(Common::Max(votes.values()))]); | 117 | subjects.append(votes.keys()[votes.values().indexOf(Common::Max(votes.values()))]); |
| 116 | 118 | ||
| 117 | // Remove subject from consideration | 119 | // Remove subject from consideration |
| 118 | if (subjects.size() < numSubjects) | 120 | if (subjects.size() < numSubjects) |
| 119 | for (int j=sortedScores.size()-1; j>=0; j--) | 121 | for (int j=sortedScores.size()-1; j>=0; j--) |
| 120 | - if (gallery[sortedScores[j].second].file.get<QString>("Subject") == subjects.last()) | 122 | + if (gallery[sortedScores[j].second].file.get<QString>(inputVariable) == subjects.last()) |
| 121 | sortedScores.removeAt(j); | 123 | sortedScores.removeAt(j); |
| 122 | } | 124 | } |
| 123 | 125 |
openbr/plugins/eigen3.cpp
| @@ -318,10 +318,12 @@ class LDATransform : public Transform | @@ -318,10 +318,12 @@ class LDATransform : public Transform | ||
| 318 | Q_PROPERTY(bool pcaWhiten READ get_pcaWhiten WRITE set_pcaWhiten RESET reset_pcaWhiten STORED false) | 318 | Q_PROPERTY(bool pcaWhiten READ get_pcaWhiten WRITE set_pcaWhiten RESET reset_pcaWhiten STORED false) |
| 319 | Q_PROPERTY(int directLDA READ get_directLDA WRITE set_directLDA RESET reset_directLDA STORED false) | 319 | Q_PROPERTY(int directLDA READ get_directLDA WRITE set_directLDA RESET reset_directLDA STORED false) |
| 320 | Q_PROPERTY(float directDrop READ get_directDrop WRITE set_directDrop RESET reset_directDrop STORED false) | 320 | Q_PROPERTY(float directDrop READ get_directDrop WRITE set_directDrop RESET reset_directDrop STORED false) |
| 321 | + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | ||
| 321 | BR_PROPERTY(float, pcaKeep, 0.98) | 322 | BR_PROPERTY(float, pcaKeep, 0.98) |
| 322 | BR_PROPERTY(bool, pcaWhiten, false) | 323 | BR_PROPERTY(bool, pcaWhiten, false) |
| 323 | BR_PROPERTY(int, directLDA, 0) | 324 | BR_PROPERTY(int, directLDA, 0) |
| 324 | BR_PROPERTY(float, directDrop, 0.1) | 325 | BR_PROPERTY(float, directDrop, 0.1) |
| 326 | + BR_PROPERTY(QString, inputVariable, "Label") | ||
| 325 | 327 | ||
| 326 | int dimsOut; | 328 | int dimsOut; |
| 327 | Eigen::VectorXf mean; | 329 | Eigen::VectorXf mean; |
| @@ -330,7 +332,7 @@ class LDATransform : public Transform | @@ -330,7 +332,7 @@ class LDATransform : public Transform | ||
| 330 | void train(const TemplateList &_trainingSet) | 332 | void train(const TemplateList &_trainingSet) |
| 331 | { | 333 | { |
| 332 | // creates "Label" | 334 | // creates "Label" |
| 333 | - TemplateList trainingSet = TemplateList::relabel(_trainingSet, "Subject"); | 335 | + TemplateList trainingSet = TemplateList::relabel(_trainingSet, inputVariable); |
| 334 | 336 | ||
| 335 | int instances = trainingSet.size(); | 337 | int instances = trainingSet.size(); |
| 336 | 338 |
openbr/plugins/gallery.cpp
| @@ -71,7 +71,7 @@ class arffGallery : public Gallery | @@ -71,7 +71,7 @@ class arffGallery : public Gallery | ||
| 71 | } | 71 | } |
| 72 | 72 | ||
| 73 | arffFile.write(qPrintable(OpenCVUtils::matrixToStringList(t).join(','))); | 73 | arffFile.write(qPrintable(OpenCVUtils::matrixToStringList(t).join(','))); |
| 74 | - arffFile.write(qPrintable(",'" + t.file.get<QString>("Subject") + "'\n")); | 74 | + arffFile.write(qPrintable(",'" + t.file.get<QString>("Label") + "'\n")); |
| 75 | } | 75 | } |
| 76 | }; | 76 | }; |
| 77 | 77 | ||
| @@ -874,7 +874,7 @@ class statGallery : public Gallery | @@ -874,7 +874,7 @@ class statGallery : public Gallery | ||
| 874 | 874 | ||
| 875 | void write(const Template &t) | 875 | void write(const Template &t) |
| 876 | { | 876 | { |
| 877 | - subjects.insert(t.file.get<QString>("Subject")); | 877 | + subjects.insert(t.file.get<QString>("Label")); |
| 878 | bytes.append(t.bytes()); | 878 | bytes.append(t.bytes()); |
| 879 | } | 879 | } |
| 880 | }; | 880 | }; |
openbr/plugins/independent.cpp
| @@ -20,11 +20,11 @@ static TemplateList Downsample(const TemplateList &templates, const Transform *t | @@ -20,11 +20,11 @@ static TemplateList Downsample(const TemplateList &templates, const Transform *t | ||
| 20 | const bool atLeast = transform->instances < 0; | 20 | const bool atLeast = transform->instances < 0; |
| 21 | const int instances = abs(transform->instances); | 21 | const int instances = abs(transform->instances); |
| 22 | 22 | ||
| 23 | - QList<QString> allLabels = templates.get<QString>("Subject"); | 23 | + QList<QString> allLabels = templates.get<QString>("Label"); |
| 24 | QList<QString> uniqueLabels = allLabels.toSet().toList(); | 24 | QList<QString> uniqueLabels = allLabels.toSet().toList(); |
| 25 | qSort(uniqueLabels); | 25 | qSort(uniqueLabels); |
| 26 | 26 | ||
| 27 | - QMap<QString,int> counts = templates.countValues<QString>("Subject", instances != std::numeric_limits<int>::max()); | 27 | + QMap<QString,int> counts = templates.countValues<QString>("Label", instances != std::numeric_limits<int>::max()); |
| 28 | 28 | ||
| 29 | if ((instances != std::numeric_limits<int>::max()) && (transform->classes != std::numeric_limits<int>::max())) | 29 | if ((instances != std::numeric_limits<int>::max()) && (transform->classes != std::numeric_limits<int>::max())) |
| 30 | foreach (const QString & label, counts.keys()) | 30 | foreach (const QString & label, counts.keys()) |
openbr/plugins/normalize.cpp
| @@ -127,7 +127,7 @@ private: | @@ -127,7 +127,7 @@ private: | ||
| 127 | Mat m; | 127 | Mat m; |
| 128 | OpenCVUtils::toMat(data.data()).convertTo(m, CV_64F); | 128 | OpenCVUtils::toMat(data.data()).convertTo(m, CV_64F); |
| 129 | 129 | ||
| 130 | - const QList<int> labels = data.indexProperty("Subject"); | 130 | + const QList<int> labels = data.indexProperty("Label"); |
| 131 | const int dims = m.cols; | 131 | const int dims = m.cols; |
| 132 | 132 | ||
| 133 | vector<Mat> mv, av, bv; | 133 | vector<Mat> mv, av, bv; |
openbr/plugins/output.cpp
| @@ -146,8 +146,8 @@ class meltOutput : public MatrixOutput | @@ -146,8 +146,8 @@ class meltOutput : public MatrixOutput | ||
| 146 | QStringList lines; | 146 | QStringList lines; |
| 147 | if (file.baseName() != "terminal") lines.append(QString("Query,Target,Mask,Similarity%1").arg(keys)); | 147 | if (file.baseName() != "terminal") lines.append(QString("Query,Target,Mask,Similarity%1").arg(keys)); |
| 148 | 148 | ||
| 149 | - QList<QString> queryLabels = queryFiles.get<QString>("Subject"); | ||
| 150 | - QList<QString> targetLabels = targetFiles.get<QString>("Subject"); | 149 | + QList<QString> queryLabels = queryFiles.get<QString>("Label"); |
| 150 | + QList<QString> targetLabels = targetFiles.get<QString>("Label"); | ||
| 151 | 151 | ||
| 152 | for (int i=0; i<queryFiles.size(); i++) { | 152 | for (int i=0; i<queryFiles.size(); i++) { |
| 153 | for (int j=(selfSimilar ? i+1 : 0); j<targetFiles.size(); j++) { | 153 | for (int j=(selfSimilar ? i+1 : 0); j<targetFiles.size(); j++) { |
| @@ -298,7 +298,7 @@ class txtOutput : public MatrixOutput | @@ -298,7 +298,7 @@ class txtOutput : public MatrixOutput | ||
| 298 | if (file.isNull() || targetFiles.isEmpty() || queryFiles.isEmpty()) return; | 298 | if (file.isNull() || targetFiles.isEmpty() || queryFiles.isEmpty()) return; |
| 299 | QStringList lines; | 299 | QStringList lines; |
| 300 | foreach (const File &file, queryFiles) | 300 | foreach (const File &file, queryFiles) |
| 301 | - lines.append(file.name + " " + file.get<QString>("Subject")); | 301 | + lines.append(file.name + " " + file.get<QString>("Label")); |
| 302 | QtUtils::writeFile(file, lines); | 302 | QtUtils::writeFile(file, lines); |
| 303 | } | 303 | } |
| 304 | }; | 304 | }; |
| @@ -426,7 +426,7 @@ class rankOutput : public MatrixOutput | @@ -426,7 +426,7 @@ class rankOutput : public MatrixOutput | ||
| 426 | typedef QPair<float,int> Pair; | 426 | typedef QPair<float,int> Pair; |
| 427 | int rank = 1; | 427 | int rank = 1; |
| 428 | foreach (const Pair &pair, Common::Sort(OpenCVUtils::matrixToVector<float>(data.row(i)), true)) { | 428 | foreach (const Pair &pair, Common::Sort(OpenCVUtils::matrixToVector<float>(data.row(i)), true)) { |
| 429 | - if (targetFiles[pair.second].get<QString>("Subject") == queryFiles[i].get<QString>("Subject")) { | 429 | + if (targetFiles[pair.second].get<QString>("Label") == queryFiles[i].get<QString>("Label")) { |
| 430 | ranks.append(rank); | 430 | ranks.append(rank); |
| 431 | positions.append(pair.second); | 431 | positions.append(pair.second); |
| 432 | scores.append(pair.first); | 432 | scores.append(pair.first); |
openbr/plugins/quality.cpp
| @@ -26,10 +26,10 @@ class ImpostorUniquenessMeasureTransform : public Transform | @@ -26,10 +26,10 @@ class ImpostorUniquenessMeasureTransform : public Transform | ||
| 26 | 26 | ||
| 27 | float calculateIUM(const Template &probe, const TemplateList &gallery) const | 27 | float calculateIUM(const Template &probe, const TemplateList &gallery) const |
| 28 | { | 28 | { |
| 29 | - const QString probeLabel = probe.file.get<QString>("Subject"); | 29 | + const QString probeLabel = probe.file.get<QString>("Label"); |
| 30 | TemplateList subset = gallery; | 30 | TemplateList subset = gallery; |
| 31 | for (int j=subset.size()-1; j>=0; j--) | 31 | for (int j=subset.size()-1; j>=0; j--) |
| 32 | - if (subset[j].file.get<QString>("Subject") == probeLabel) | 32 | + if (subset[j].file.get<QString>("Label") == probeLabel) |
| 33 | subset.removeAt(j); | 33 | subset.removeAt(j); |
| 34 | 34 | ||
| 35 | QList<float> scores = distance->compare(subset, probe); | 35 | QList<float> scores = distance->compare(subset, probe); |
| @@ -158,7 +158,7 @@ class MatchProbabilityDistance : public Distance | @@ -158,7 +158,7 @@ class MatchProbabilityDistance : public Distance | ||
| 158 | { | 158 | { |
| 159 | distance->train(src); | 159 | distance->train(src); |
| 160 | 160 | ||
| 161 | - const QList<int> labels = src.indexProperty("Subject"); | 161 | + const QList<int> labels = src.indexProperty("Label"); |
| 162 | QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(src.size()), FileList(src.size()))); | 162 | QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(src.size()), FileList(src.size()))); |
| 163 | distance->compare(src, src, matrixOutput.data()); | 163 | distance->compare(src, src, matrixOutput.data()); |
| 164 | 164 | ||
| @@ -228,7 +228,7 @@ class HeatMapDistance : public Distance | @@ -228,7 +228,7 @@ class HeatMapDistance : public Distance | ||
| 228 | { | 228 | { |
| 229 | distance->train(src); | 229 | distance->train(src); |
| 230 | 230 | ||
| 231 | - const QList<int> labels = src.indexProperty("Subject"); | 231 | + const QList<int> labels = src.indexProperty("Label"); |
| 232 | 232 | ||
| 233 | QList<TemplateList> patches; | 233 | QList<TemplateList> patches; |
| 234 | 234 | ||
| @@ -316,7 +316,7 @@ class UnitDistance : public Distance | @@ -316,7 +316,7 @@ class UnitDistance : public Distance | ||
| 316 | void train(const TemplateList &templates) | 316 | void train(const TemplateList &templates) |
| 317 | { | 317 | { |
| 318 | const TemplateList samples = templates.mid(0, 2000); | 318 | const TemplateList samples = templates.mid(0, 2000); |
| 319 | - const QList<int> sampleLabels = samples.indexProperty("Subject"); | 319 | + const QList<int> sampleLabels = samples.indexProperty("Label"); |
| 320 | QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(samples.size()), FileList(samples.size()))); | 320 | QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(samples.size()), FileList(samples.size()))); |
| 321 | Distance::compare(samples, samples, matrixOutput.data()); | 321 | Distance::compare(samples, samples, matrixOutput.data()); |
| 322 | 322 |
openbr/plugins/quantize.cpp
| @@ -150,7 +150,7 @@ class BayesianQuantizationDistance : public Distance | @@ -150,7 +150,7 @@ class BayesianQuantizationDistance : public Distance | ||
| 150 | qFatal("Expected sigle matrix templates of type CV_8UC1!"); | 150 | qFatal("Expected sigle matrix templates of type CV_8UC1!"); |
| 151 | 151 | ||
| 152 | const Mat data = OpenCVUtils::toMat(src.data()); | 152 | const Mat data = OpenCVUtils::toMat(src.data()); |
| 153 | - const QList<int> templateLabels = src.indexProperty("Subject"); | 153 | + const QList<int> templateLabels = src.indexProperty("Label"); |
| 154 | loglikelihoods = QVector<float>(data.cols*256, 0); | 154 | loglikelihoods = QVector<float>(data.cols*256, 0); |
| 155 | 155 | ||
| 156 | QFutureSynchronizer<void> futures; | 156 | QFutureSynchronizer<void> futures; |
| @@ -474,7 +474,7 @@ private: | @@ -474,7 +474,7 @@ private: | ||
| 474 | Mat data = OpenCVUtils::toMat(src.data()); | 474 | Mat data = OpenCVUtils::toMat(src.data()); |
| 475 | const int step = getStep(data.cols); | 475 | const int step = getStep(data.cols); |
| 476 | 476 | ||
| 477 | - const QList<int> labels = src.indexProperty("Subject"); | 477 | + const QList<int> labels = src.indexProperty("Label"); |
| 478 | 478 | ||
| 479 | Mat &lut = ProductQuantizationLUTs[index]; | 479 | Mat &lut = ProductQuantizationLUTs[index]; |
| 480 | lut = Mat(getDims(data.cols), 256*(256+1)/2, CV_32FC1); | 480 | lut = Mat(getDims(data.cols), 256*(256+1)/2, CV_32FC1); |
openbr/plugins/quantize2.cpp
| @@ -77,7 +77,7 @@ class BayesianQuantizationTransform : public Transform | @@ -77,7 +77,7 @@ class BayesianQuantizationTransform : public Transform | ||
| 77 | void train(const TemplateList &src) | 77 | void train(const TemplateList &src) |
| 78 | { | 78 | { |
| 79 | const Mat data = OpenCVUtils::toMat(src.data()); | 79 | const Mat data = OpenCVUtils::toMat(src.data()); |
| 80 | - const QList<int> labels = src.indexProperty("Subject"); | 80 | + const QList<int> labels = src.indexProperty("Label"); |
| 81 | 81 | ||
| 82 | thresholds = QVector<float>(256*data.cols); | 82 | thresholds = QVector<float>(256*data.cols); |
| 83 | 83 |
openbr/plugins/svm.cpp
| @@ -101,6 +101,8 @@ class SVMTransform : public Transform | @@ -101,6 +101,8 @@ class SVMTransform : public Transform | ||
| 101 | Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false) | 101 | Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false) |
| 102 | Q_PROPERTY(float C READ get_C WRITE set_C RESET reset_C STORED false) | 102 | Q_PROPERTY(float C READ get_C WRITE set_C RESET reset_C STORED false) |
| 103 | Q_PROPERTY(float gamma READ get_gamma WRITE set_gamma RESET reset_gamma STORED false) | 103 | Q_PROPERTY(float gamma READ get_gamma WRITE set_gamma RESET reset_gamma STORED false) |
| 104 | + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | ||
| 105 | + Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) | ||
| 104 | 106 | ||
| 105 | public: | 107 | public: |
| 106 | enum Kernel { Linear = CvSVM::LINEAR, | 108 | enum Kernel { Linear = CvSVM::LINEAR, |
| @@ -119,6 +121,9 @@ private: | @@ -119,6 +121,9 @@ private: | ||
| 119 | BR_PROPERTY(Type, type, C_SVC) | 121 | BR_PROPERTY(Type, type, C_SVC) |
| 120 | BR_PROPERTY(float, C, -1) | 122 | BR_PROPERTY(float, C, -1) |
| 121 | BR_PROPERTY(float, gamma, -1) | 123 | BR_PROPERTY(float, gamma, -1) |
| 124 | + BR_PROPERTY(QString, inputVariable, "") | ||
| 125 | + BR_PROPERTY(QString, outputVariable, "") | ||
| 126 | + | ||
| 122 | 127 | ||
| 123 | SVM svm; | 128 | SVM svm; |
| 124 | QHash<QString, int> labelMap; | 129 | QHash<QString, int> labelMap; |
| @@ -128,14 +133,14 @@ private: | @@ -128,14 +133,14 @@ private: | ||
| 128 | { | 133 | { |
| 129 | Mat data = OpenCVUtils::toMat(_data.data()); | 134 | Mat data = OpenCVUtils::toMat(_data.data()); |
| 130 | Mat lab; | 135 | Mat lab; |
| 131 | - // If we are doing regression, assume subject has float values | 136 | + // If we are doing regression, the input variable should have float values |
| 132 | if (type == EPS_SVR || type == NU_SVR) { | 137 | if (type == EPS_SVR || type == NU_SVR) { |
| 133 | - lab = OpenCVUtils::toMat(_data.get<float>("Subject")); | 138 | + lab = OpenCVUtils::toMat(_data.get<float>(inputVariable)); |
| 134 | } | 139 | } |
| 135 | - // If we are doing classification, assume subject has discrete values, map them | 140 | + // If we are doing classification, we should be dealing with discrete values. Map them |
| 136 | // and store the mapping data | 141 | // and store the mapping data |
| 137 | else { | 142 | else { |
| 138 | - QList<int> dataLabels = _data.indexProperty("Subject", labelMap, reverseLookup); | 143 | + QList<int> dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup); |
| 139 | lab = OpenCVUtils::toMat(dataLabels); | 144 | lab = OpenCVUtils::toMat(dataLabels); |
| 140 | } | 145 | } |
| 141 | trainSVM(svm, data, lab, kernel, type, C, gamma); | 146 | trainSVM(svm, data, lab, kernel, type, C, gamma); |
| @@ -146,9 +151,9 @@ private: | @@ -146,9 +151,9 @@ private: | ||
| 146 | dst = src; | 151 | dst = src; |
| 147 | float prediction = svm.predict(src.m().reshape(1, 1)); | 152 | float prediction = svm.predict(src.m().reshape(1, 1)); |
| 148 | if (type == EPS_SVR || type == NU_SVR) | 153 | if (type == EPS_SVR || type == NU_SVR) |
| 149 | - dst.file.set("Subject", prediction); | 154 | + dst.file.set(outputVariable, prediction); |
| 150 | else | 155 | else |
| 151 | - dst.file.set("Subject", reverseLookup[prediction]); | 156 | + dst.file.set(outputVariable, reverseLookup[prediction]); |
| 152 | } | 157 | } |
| 153 | 158 | ||
| 154 | void store(QDataStream &stream) const | 159 | void store(QDataStream &stream) const |
| @@ -162,6 +167,24 @@ private: | @@ -162,6 +167,24 @@ private: | ||
| 162 | loadSVM(svm, stream); | 167 | loadSVM(svm, stream); |
| 163 | stream >> labelMap >> reverseLookup; | 168 | stream >> labelMap >> reverseLookup; |
| 164 | } | 169 | } |
| 170 | + | ||
| 171 | + void init() | ||
| 172 | + { | ||
| 173 | + // Since SVM can do regression or classification, we have to check the problem type before | ||
| 174 | + // specifying target variable names | ||
| 175 | + if (inputVariable.isEmpty()) | ||
| 176 | + { | ||
| 177 | + if (type == EPS_SVR || type == NU_SVR) { | ||
| 178 | + inputVariable = "Regressor"; | ||
| 179 | + if (outputVariable.isEmpty()) | ||
| 180 | + outputVariable = "Regressand"; | ||
| 181 | + } | ||
| 182 | + else | ||
| 183 | + inputVariable = "Label"; | ||
| 184 | + } | ||
| 185 | + if (outputVariable.isEmpty()) | ||
| 186 | + outputVariable = inputVariable; | ||
| 187 | + } | ||
| 165 | }; | 188 | }; |
| 166 | 189 | ||
| 167 | BR_REGISTER(Transform, SVMTransform) | 190 | BR_REGISTER(Transform, SVMTransform) |
| @@ -178,6 +201,8 @@ class SVMDistance : public Distance | @@ -178,6 +201,8 @@ class SVMDistance : public Distance | ||
| 178 | Q_ENUMS(Type) | 201 | Q_ENUMS(Type) |
| 179 | Q_PROPERTY(Kernel kernel READ get_kernel WRITE set_kernel RESET reset_kernel STORED false) | 202 | Q_PROPERTY(Kernel kernel READ get_kernel WRITE set_kernel RESET reset_kernel STORED false) |
| 180 | Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false) | 203 | Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false) |
| 204 | + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | ||
| 205 | + | ||
| 181 | 206 | ||
| 182 | public: | 207 | public: |
| 183 | enum Kernel { Linear = CvSVM::LINEAR, | 208 | enum Kernel { Linear = CvSVM::LINEAR, |
| @@ -194,13 +219,14 @@ public: | @@ -194,13 +219,14 @@ public: | ||
| 194 | private: | 219 | private: |
| 195 | BR_PROPERTY(Kernel, kernel, Linear) | 220 | BR_PROPERTY(Kernel, kernel, Linear) |
| 196 | BR_PROPERTY(Type, type, EPS_SVR) | 221 | BR_PROPERTY(Type, type, EPS_SVR) |
| 222 | + BR_PROPERTY(QString, inputVariable, "Label") | ||
| 197 | 223 | ||
| 198 | SVM svm; | 224 | SVM svm; |
| 199 | 225 | ||
| 200 | void train(const TemplateList &src) | 226 | void train(const TemplateList &src) |
| 201 | { | 227 | { |
| 202 | const Mat data = OpenCVUtils::toMat(src.data()); | 228 | const Mat data = OpenCVUtils::toMat(src.data()); |
| 203 | - const QList<int> lab = src.indexProperty("Subject"); | 229 | + const QList<int> lab = src.indexProperty(inputVariable); |
| 204 | 230 | ||
| 205 | const int instances = data.rows * (data.rows+1) / 2; | 231 | const int instances = data.rows * (data.rows+1) / 2; |
| 206 | Mat deltaData(instances, data.cols, data.type()); | 232 | Mat deltaData(instances, data.cols, data.type()); |