diff --git a/app/br/br.cpp b/app/br/br.cpp index 8615274..475e81f 100644 --- a/app/br/br.cpp +++ b/app/br/br.cpp @@ -129,8 +129,8 @@ public: check(parc == 3, "Incorrect parameter count for 'convert'."); br_convert(parv[0], parv[1], parv[2]); } else if (!strcmp(fun, "evalClassification")) { - check(parc == 2, "Incorrect parameter count for 'evalClassification'."); - br_eval_classification(parv[0], parv[1]); + check(parc >= 2 && parc <= 4, "Incorrect parameter count for 'evalClassification'."); + br_eval_classification(parv[0], parv[1], parc >= 3 ? parv[2] : "", parc >= 4 ? parv[3] : ""); } else if (!strcmp(fun, "evalClustering")) { check(parc == 2, "Incorrect parameter count for 'evalClustering'."); br_eval_clustering(parv[0], parv[1]); @@ -138,8 +138,8 @@ public: check((parc >= 2) && (parc <= 3), "Incorrect parameter count for 'evalDetection'."); br_eval_detection(parv[0], parv[1], parc == 3 ? parv[2] : ""); } else if (!strcmp(fun, "evalRegression")) { - check(parc == 2, "Incorrect parameter count for 'evalRegression'."); - br_eval_regression(parv[0], parv[1]); + check(parc >= 2 && parc <= 4, "Incorrect parameter count for 'evalRegression'."); + br_eval_regression(parv[0], parv[1], parc >= 3 ? parv[2] : "", parc >= 4 ? parv[3] : ""); } else if (!strcmp(fun, "plotDetection")) { check(parc >= 2, "Incorrect parameter count for 'plotDetection'."); br_plot_detection(parc-1, parv, parv[parc-1], true); @@ -215,10 +215,10 @@ private: "-combineMasks ... {mask} (And|Or)\n" "-cat ... {gallery}\n" "-convert (Format|Gallery|Output) {output_file}\n" - "-evalClassification \n" + "-evalClassification \n" "-evalClustering \n" "-evalDetection [{csv}]\n" - "-evalRegression \n" + "-evalRegression \n" "-plotDetection ... {destination}\n" "-plotMetadata ... \n" "-getHeader \n" diff --git a/app/examples/age_estimation.cpp b/app/examples/age_estimation.cpp index c2a90e5..2958b71 100644 --- a/app/examples/age_estimation.cpp +++ b/app/examples/age_estimation.cpp @@ -29,7 +29,7 @@ static void printTemplate(const br::Template &t) { - printf("%s age: %d\n", qPrintable(t.file.fileName()), int(t.file.get("Subject"))); + printf("%s age: %d\n", qPrintable(t.file.fileName()), int(t.file.get("Age"))); } int main(int argc, char *argv[]) diff --git a/app/examples/gender_estimation.cpp b/app/examples/gender_estimation.cpp index dd4ebd2..d2a824d 100644 --- a/app/examples/gender_estimation.cpp +++ b/app/examples/gender_estimation.cpp @@ -29,7 +29,7 @@ static void printTemplate(const br::Template &t) { - printf("%s gender: %s\n", qPrintable(t.file.fileName()), qPrintable(t.file.get("Subject"))); + printf("%s gender: %s\n", qPrintable(t.file.fileName()), qPrintable(t.file.get("Gender"))); } int main(int argc, char *argv[]) diff --git a/openbr/core/bee.cpp b/openbr/core/bee.cpp index c6cb254..9639657 100644 --- a/openbr/core/bee.cpp +++ b/openbr/core/bee.cpp @@ -99,10 +99,10 @@ void BEE::writeSigset(const QString &sigset, const br::FileList &files, bool ign QStringList metadata; if (!ignoreMetadata) foreach (const QString &key, file.localKeys()) { - if ((key == "Index") || (key == "Subject")) continue; + if ((key == "Index") || (key == "Label")) continue; metadata.append(key+"=\""+QtUtils::toString(file.value(key))+"\""); } - lines.append("\t("Subject",file.fileName()) +"\">"); + lines.append("\t("Label",file.fileName()) +"\">"); lines.append("\t\t"); lines.append("\t"); } @@ -266,10 +266,11 @@ void BEE::makeMask(const QString &targetInput, const QString &queryInput, const cv::Mat BEE::makeMask(const br::FileList &targets, const br::FileList &queries, int partition) { - // Would like to use indexProperty for this, but didn't make a version of that for Filelist yet - // -cao - QList targetLabels = File::get(targets, "Subject", "-1"); - QList queryLabels = File::get(queries, "Subject", "-1"); + // Direct use of "Label" isn't general, also would prefer to use indexProperty, rather than + // doing string comparisons (but that isn't implemented yet for FileList) -cao + QList targetLabels = File::get(targets, "Label", "-1"); + QList queryLabels = File::get(queries, "Label", "-1"); + QList targetPartitions = targets.crossValidationPartitions(); QList queryPartitions = queries.crossValidationPartitions(); diff --git a/openbr/core/cluster.cpp b/openbr/core/cluster.cpp index 6a6b954..8c880d9 100644 --- a/openbr/core/cluster.cpp +++ b/openbr/core/cluster.cpp @@ -279,8 +279,8 @@ void br::EvalClustering(const QString &csv, const QString &input) qDebug("Evaluating %s against %s", qPrintable(csv), qPrintable(input)); // We assume clustering algorithms store assigned cluster labels as integers (since the clusters are - // not named). - QList labels = File::get(TemplateList::fromGallery(input), "Subject"); + // not named). Direct use of ClusterID is not general -cao + QList labels = File::get(TemplateList::fromGallery(input), "ClusterID"); QHash labelToIndex; int nClusters = 0; diff --git a/openbr/core/common.cpp b/openbr/core/common.cpp index 17584ec..55cbd3e 100644 --- a/openbr/core/common.cpp +++ b/openbr/core/common.cpp @@ -15,11 +15,15 @@ * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * */ #include "common.h" +#include using namespace std; /**** GLOBAL ****/ void Common::seedRNG() { + static QMutex seedControl; + QMutexLocker lock(&seedControl); + static bool seeded = false; if (!seeded) { srand(0); // We seed with 0 instead of time(NULL) to have reproducible randomness @@ -29,8 +33,6 @@ void Common::seedRNG() { QList Common::RandSample(int n, int max, int min, bool unique) { - seedRNG(); - QList samples; samples.reserve(n); int range = max-min; if (range <= 0) qFatal("Non-positive range."); @@ -50,8 +52,6 @@ QList Common::RandSample(int n, int max, int min, bool unique) QList Common::RandSample(int n, const QSet &values, bool unique) { - seedRNG(); - QList valueList = values.toList(); if (unique && (values.size() <= n)) return valueList; diff --git a/openbr/core/eval.cpp b/openbr/core/eval.cpp index 9e813d8..a87807e 100644 --- a/openbr/core/eval.cpp +++ b/openbr/core/eval.cpp @@ -255,9 +255,20 @@ struct Counter } }; -void EvalClassification(const QString &predictedInput, const QString &truthInput) +void EvalClassification(const QString &predictedInput, const QString &truthInput, QString predictedProperty, QString truthProperty) { qDebug("Evaluating classification of %s against %s", qPrintable(predictedInput), qPrintable(truthInput)); + + if (predictedProperty.isEmpty()) + predictedProperty = "Label"; + // If predictedProperty is specified, but truthProperty isn't, copy over the value from + // predicted property + else if (truthProperty.isEmpty()) + truthProperty = predictedProperty; + + if (truthProperty.isEmpty()) + truthProperty = "Label"; + TemplateList predicted(TemplateList::fromGallery(predictedInput)); TemplateList truth(TemplateList::fromGallery(truthInput)); if (predicted.size() != truth.size()) qFatal("Input size mismatch."); @@ -267,9 +278,8 @@ void EvalClassification(const QString &predictedInput, const QString &truthInput if (predicted[i].file.name != truth[i].file.name) qFatal("Input order mismatch."); - // Typically these lists will be of length one, but this generalization allows measuring multi-class labeling accuracy. - QString predictedSubject = predicted[i].file.get("Subject"); - QString trueSubject = truth[i].file.get("Subject"); + QString predictedSubject = predicted[i].file.get(predictedProperty); + QString trueSubject = truth[i].file.get(truthProperty); QStringList predictedSubjects(predictedSubject); QStringList trueSubjects(trueSubject); @@ -466,21 +476,37 @@ float EvalDetection(const QString &predictedInput, const QString &truthInput, co return averageOverlap; } -void EvalRegression(const QString &predictedInput, const QString &truthInput) +void EvalRegression(const QString &predictedInput, const QString &truthInput, QString predictedProperty, QString truthProperty) { qDebug("Evaluating regression of %s against %s", qPrintable(predictedInput), qPrintable(truthInput)); + + if (predictedProperty.isEmpty()) + predictedProperty = "Regressor"; + // If predictedProperty is specified, but truthProperty isn't, copy the value over + // rather than using the default for truthProperty + else if (truthProperty.isEmpty()) + truthProperty = predictedProperty; + + if (truthProperty.isEmpty()) + predictedProperty = "Regressand"; + const TemplateList predicted(TemplateList::fromGallery(predictedInput)); const TemplateList truth(TemplateList::fromGallery(truthInput)); if (predicted.size() != truth.size()) qFatal("Input size mismatch."); float rmsError = 0; + float maeError = 0; QStringList truthValues, predictedValues; for (int i=0; i("Subject")-truth[i].file.get("Subject"), 2.f); - truthValues.append(QString::number(truth[i].file.get("Subject"))); - predictedValues.append(QString::number(predicted[i].file.get("Subject"))); + + float difference = predicted[i].file.get(predictedProperty) - truth[i].file.get(truthProperty); + + rmsError += pow(difference, 2.f); + maeError += fabsf(difference); + truthValues.append(QString::number(truth[i].file.get(truthProperty))); + predictedValues.append(QString::number(predicted[i].file.get(predictedProperty))); } QStringList rSource; @@ -500,6 +526,7 @@ void EvalRegression(const QString &predictedInput, const QString &truthInput) if (success) QtUtils::showFile("EvalRegression.pdf"); qDebug("RMS Error = %f", sqrt(rmsError/predicted.size())); + qDebug("MAE = %f", maeError/predicted.size()); } } // namespace br diff --git a/openbr/core/eval.h b/openbr/core/eval.h index d90d1a5..b4f199d 100644 --- a/openbr/core/eval.h +++ b/openbr/core/eval.h @@ -26,9 +26,9 @@ namespace br float Evaluate(const QString &simmat, const QString &mask = "", const QString &csv = ""); // Returns TAR @ FAR = 0.001 float Evaluate(const cv::Mat &scores, const FileList &target, const FileList &query, const QString &csv = "", int parition = 0); float Evaluate(const cv::Mat &scores, const cv::Mat &masks, const QString &csv = ""); - void EvalClassification(const QString &predictedInput, const QString &truthInput); + void EvalClassification(const QString &predictedInput, const QString &truthInput, QString predictedProperty="", QString truthProperty=""); float EvalDetection(const QString &predictedInput, const QString &truthInput, const QString &csv = ""); // Return average overlap - void EvalRegression(const QString &predictedInput, const QString &truthInput); + void EvalRegression(const QString &predictedInput, const QString &truthInput, QString predictedProperty="", QString truthProperty=""); } #endif // __EVAL_H diff --git a/openbr/frvt2012.cpp b/openbr/frvt2012.cpp index ac47ac8..b0c2d29 100644 --- a/openbr/frvt2012.cpp +++ b/openbr/frvt2012.cpp @@ -132,7 +132,7 @@ int32_t SdkEstimator::estimate_age(const ONEFACE &input_face, int32_t &age) TemplateList templates; templates.append(templateFromONEFACE(input_face)); templates >> *frvt2012_age_transform.data(); - age = templates.first().file.get("Subject"); + age = templates.first().file.get("Age"); return templates.first().file.failed() ? 4 : 0; } @@ -141,6 +141,6 @@ int32_t SdkEstimator::estimate_gender(const ONEFACE &input_face, int8_t &gender, TemplateList templates; templates.append(templateFromONEFACE(input_face)); templates >> *frvt2012_gender_transform.data(); - mf = gender = templates.first().file.get("Subject") == "Male" ? 0 : 1; + mf = gender = templates.first().file.get("Gender") == "Male" ? 0 : 1; return templates.first().file.failed() ? 4 : 0; } diff --git a/openbr/gui/classifier.cpp b/openbr/gui/classifier.cpp index 2df9f5e..71a3102 100644 --- a/openbr/gui/classifier.cpp +++ b/openbr/gui/classifier.cpp @@ -39,19 +39,23 @@ void Classifier::_classify(File file) { QString key, value; foreach (const File &f, Enroll(file.flat(), File("[algorithm=" + algorithm + "]"))) { - if (!f.contains("Label")) - continue; if (algorithm == "GenderClassification") { key = "Gender"; - value = (f.get("Subject")); } else if (algorithm == "AgeRegression") { key = "Age"; - value = QString::number(int(f.get("Subject")+0.5)) + " Years"; } else { key = algorithm; - value = f.get("Subject"); } + + if (!f.contains(key)) + continue; + + if (algorithm == "AgeRegression") + value = QString::number(int(f.get(key)+0.5)) + " Years"; + else + value = f.get(key); + break; } diff --git a/openbr/openbr.cpp b/openbr/openbr.cpp index c269366..09ca043 100644 --- a/openbr/openbr.cpp +++ b/openbr/openbr.cpp @@ -72,9 +72,9 @@ float br_eval(const char *simmat, const char *mask, const char *csv) return Evaluate(simmat, mask, csv); } -void br_eval_classification(const char *predicted_gallery, const char *truth_gallery) +void br_eval_classification(const char *predicted_gallery, const char *truth_gallery, const char *predicted_property, const char * truth_property) { - EvalClassification(predicted_gallery, truth_gallery); + EvalClassification(predicted_gallery, truth_gallery, predicted_property, truth_property); } void br_eval_clustering(const char *csv, const char *gallery) @@ -87,9 +87,9 @@ float br_eval_detection(const char *predicted_gallery, const char *truth_gallery return EvalDetection(predicted_gallery, truth_gallery, csv); } -void br_eval_regression(const char *predicted_gallery, const char *truth_gallery) +void br_eval_regression(const char *predicted_gallery, const char *truth_gallery, const char * predicted_property, const char * truth_property) { - EvalRegression(predicted_gallery, truth_gallery); + EvalRegression(predicted_gallery, truth_gallery, predicted_property, truth_property); } void br_finalize() diff --git a/openbr/openbr.h b/openbr/openbr.h index 6d96a28..6a26a77 100644 --- a/openbr/openbr.h +++ b/openbr/openbr.h @@ -149,7 +149,7 @@ BR_EXPORT float br_eval(const char *simmat, const char *mask, const char *csv = * \param predicted_gallery The predicted br::Gallery. * \param truth_gallery The ground truth br::Gallery. */ -BR_EXPORT void br_eval_classification(const char *predicted_gallery, const char *truth_gallery); +BR_EXPORT void br_eval_classification(const char *predicted_gallery, const char *truth_gallery, const char * predicted_property="", const char * truth_property=""); /*! * \brief Evaluates and prints clustering accuracy to the terminal. @@ -173,7 +173,7 @@ BR_EXPORT float br_eval_detection(const char *predicted_gallery, const char *tru * \param predicted_gallery The predicted br::Gallery. * \param truth_gallery The ground truth br::Gallery. */ -BR_EXPORT void br_eval_regression(const char *predicted_gallery, const char *truth_gallery); +BR_EXPORT void br_eval_regression(const char *predicted_gallery, const char *truth_gallery, const char * predicted_property="", const char * truth_property=""); /*! * \brief Wraps br::Context::finalize() diff --git a/openbr/openbr_plugin.cpp b/openbr/openbr_plugin.cpp index 1becbbf..38aeb80 100644 --- a/openbr/openbr_plugin.cpp +++ b/openbr/openbr_plugin.cpp @@ -412,7 +412,8 @@ TemplateList TemplateList::fromGallery(const br::File &gallery) // of target images to every partition newTemplates[i].file.set("Partition", -1); } else { - const QByteArray md5 = QCryptographicHash::hash(newTemplates[i].file.get("Subject").toLatin1(), QCryptographicHash::Md5); + // Direct use of "Label" is not general -cao + const QByteArray md5 = QCryptographicHash::hash(newTemplates[i].file.get("Label").toLatin1(), QCryptographicHash::Md5); // Select the right 8 hex characters so that it can be represented as a 64 bit integer without overflow newTemplates[i].file.set("Partition", md5.toHex().right(8).toULongLong(0, 16) % crossValidate); } @@ -890,6 +891,8 @@ void br::Context::initialize(int &argc, char *argv[], QString sdkPath, bool use_ qInstallMessageHandler(messageHandler); + Common::seedRNG(); + // Search for SDK if (sdkPath.isEmpty()) { QStringList checkPaths; checkPaths << QDir::currentPath() << QCoreApplication::applicationDirPath(); @@ -1082,9 +1085,6 @@ Transform::Transform(bool _independent, bool _trainable) { independent = _independent; trainable = _trainable; - classes = std::numeric_limits::max(); - instances = std::numeric_limits::max(); - fraction = 1; } Transform *Transform::make(QString str, QObject *parent) @@ -1140,9 +1140,6 @@ Transform *Transform::make(QString str, QObject *parent) Transform *Transform::clone() const { Transform *clone = Factory::make(file.flat()); - clone->classes = classes; - clone->instances = instances; - clone->fraction = fraction; return clone; } diff --git a/openbr/openbr_plugin.h b/openbr/openbr_plugin.h index a5171b2..0f0d800 100644 --- a/openbr/openbr_plugin.h +++ b/openbr/openbr_plugin.h @@ -130,13 +130,6 @@ void reset_##NAME() { NAME = DEFAULT; } * -# If the value is convertable to a floating point number then it is represented with \c float. * -# Otherwise, it is represented with \c QString. * - * The metadata keys \c Subject and \c Label have special significance in the system. - * \c Subject is a string specifying a unique identifier used to determine ground truth match/non-match. - * \c Label is a floating point value used for supervised learning. - * When the system needs labels for training, but only subjects are provided in the file metadata, the rule for generating labels is as follows. - * If the subject value can be converted to a float then do so and consider that the label. - * Otherwise, generate a unique integer ID for the string starting from zero and incrementing by one everytime another ID is needed. - * * Metadata keys fall into one of two categories: * - \c camelCaseKeys are inputs that specify how to process the file. * - \c Capitalized_Underscored_Keys are outputs computed from processing the file. @@ -147,8 +140,6 @@ void reset_##NAME() { NAME = DEFAULT; } * --- | ---- | ----------- * separator | QString | Seperate #name into multiple files * Index | int | Index of a template in a template list - * Subject | QString | Class name - * Label | float | Class value * Confidence | float | Classification/Regression quality * FTE | bool | Failure to enroll * FTO | bool | Failure to open @@ -157,13 +148,15 @@ void reset_##NAME() { NAME = DEFAULT; } * *_Width | float | Size * *_Height | float | Size * *_Radius | float | Size + * Label | QString | Class label * Theta | float | Pose * Roll | float | Pose * Pitch | float | Pose * Yaw | float | Pose * Points | QList | List of unnamed points * Rects | QList | List of unnamed rects - * Age | QString | Age used for demographic filtering + * Age | float | Age used for demographic filtering + * Gender | QString | Subject gender * _* | * | Reserved for internal use */ struct BR_EXPORT File @@ -172,7 +165,7 @@ struct BR_EXPORT File File() {} File(const QString &file) { init(file); } /*!< \brief Construct a file from a string. */ - File(const QString &file, const QVariant &subject) { init(file); set("Subject", subject); } /*!< \brief Construct a file from a string and assign a label. */ + File(const QString &file, const QVariant &label) { init(file); set("Label", label); } /*!< \brief Construct a file from a string and assign a label. */ File(const char *file) { init(file); } /*!< \brief Construct a file from a c-style string. */ inline operator QString() const { return name; } /*!< \brief Returns #name. */ QString flat() const; /*!< \brief A stringified version of the file with metadata. */ @@ -1058,12 +1051,6 @@ class BR_EXPORT Transform : public Object Q_OBJECT public: - Q_PROPERTY(int classes READ get_classes WRITE set_classes RESET reset_classes STORED false) - Q_PROPERTY(int instances READ get_instances WRITE set_instances RESET reset_instances STORED false) - Q_PROPERTY(float fraction READ get_fraction WRITE set_fraction RESET reset_fraction STORED false) - BR_PROPERTY(int, classes, std::numeric_limits::max()) - BR_PROPERTY(int, instances, std::numeric_limits::max()) - BR_PROPERTY(float, fraction, 1) bool independent, trainable; virtual ~Transform() {} diff --git a/openbr/plugins/algorithms.cpp b/openbr/plugins/algorithms.cpp index 809a112..48154af 100644 --- a/openbr/plugins/algorithms.cpp +++ b/openbr/plugins/algorithms.cpp @@ -48,7 +48,7 @@ class AlgorithmsInitializer : public Initializer // Video Globals->abbreviations.insert("DisplayVideo", "Stream([FPSLimit(30)+Show(false,[FrameNumber])+Discard])"); Globals->abbreviations.insert("PerFrameDetection", "Stream([SaveMat(original)+Cvt(Gray)+Cascade(FrontalFace)+ASEFEyes+RestoreMat(original)+Draw(inPlace=true),Show(false,[FrameNumber])+Discard])"); - Globals->abbreviations.insert("AgeGenderDemo", "Stream([SaveMat(original)+Cvt(Gray)+Cascade(FrontalFace)+Expand+++(+Rename(Subject,Age)+Discard)/(+Rename(Subject,Gender)+Discard)+RestoreMat(original)+Draw(inPlace=true)+DrawPropertiesPoint([Age,Gender],Affine_0,inPlace=true)+SaveMat(original)+Discard+Contract,RestoreMat(original)+FPSCalc+Show(false,[AvgFPS,Age,Gender])+Discard])"); + Globals->abbreviations.insert("AgeGenderDemo", "Stream([SaveMat(original)+Cvt(Gray)+Cascade(FrontalFace)+Expand+++/+Discard+RestoreMat(original)+Draw(inPlace=true)+DrawPropertiesPoint([Age,Gender],Affine_0,inPlace=true)+SaveMat(original)+Discard+Contract,RestoreMat(original)+FPSCalc+Show(false,[AvgFPS,Age,Gender])+Discard])"); Globals->abbreviations.insert("BoVW", "Flatten+CatRows+KMeans(500)+Hist(500)"); Globals->abbreviations.insert("HOF", "Stream([KeyPointDetector(SIFT),AggregateFrames(2)+OpticalFlow,ROI,HoGDescriptor])+BoVW"); Globals->abbreviations.insert("HoG", "Stream([KeyPointDetector(SIFT),ROI,HoGDescriptor])+BoVW"); @@ -78,14 +78,14 @@ class AlgorithmsInitializer : public Initializer Globals->abbreviations.insert("FaceDetection", "(Open+Cvt(Gray)+Cascade(FrontalFace))"); Globals->abbreviations.insert("DenseLBP", "(Blur(1.1)+Gamma(0.2)+DoG(1,2)+ContrastEq(0.1,10)+LBP(1,2)+RectRegions(8,8,6,6)+Hist(59))"); Globals->abbreviations.insert("DenseSIFT", "(Grid(10,10)+SIFTDescriptor(12)+ByRow)"); - Globals->abbreviations.insert("FaceRecognitionRegistration", "(ASEFEyes+Affine(88,88,0.25,0.35)+FTE(DFFS,instances=1))"); - Globals->abbreviations.insert("FaceRecognitionExtraction", "(Mask+DenseSIFT/DenseLBP+PCA(0.95,instances=1)+Normalize(L2)+Cat)"); - Globals->abbreviations.insert("FaceRecognitionEmbedding", "(Dup(12)+RndSubspace(0.05,1)+LDA(0.98,instances=-2)+Cat+PCA(768,instances=1))"); + Globals->abbreviations.insert("FaceRecognitionRegistration", "(ASEFEyes+Affine(88,88,0.25,0.35)+DownsampleTraining(FTE(DFFS),instances=1))"); + Globals->abbreviations.insert("FaceRecognitionExtraction", "(Mask+DenseSIFT/DenseLBP+DownsampleTraining(PCA(0.95),instances=1)+Normalize(L2)+Cat)"); + Globals->abbreviations.insert("FaceRecognitionEmbedding", "(Dup(12)+RndSubspace(0.05,1)+DownsampleTraining(LDA(0.98),instances=-2)+Cat+DownsampleTraining(PCA(768),instances=1))"); Globals->abbreviations.insert("FaceRecognitionQuantization", "(Normalize(L1)+Quantize)"); Globals->abbreviations.insert("FaceClassificationRegistration", "(ASEFEyes+Affine(56,72,0.33,0.45)+FTE(DFFS))"); - Globals->abbreviations.insert("FaceClassificationExtraction", "((Grid(7,7)+SIFTDescriptor(8)+ByRow)/DenseLBP+PCA(0.95,instances=-1)+Cat)"); - Globals->abbreviations.insert("AgeRegressor", "Center(Range,instances=-1)+SVM(RBF,EPS_SVR,instances=100)"); - Globals->abbreviations.insert("GenderClassifier", "Center(Range,instances=-1)+SVM(RBF,C_SVC,instances=4000)"); + Globals->abbreviations.insert("FaceClassificationExtraction", "((Grid(7,7)+SIFTDescriptor(8)+ByRow)/DenseLBP+DownsampleTraining(PCA(0.95),instances=-1, inputVariable=Gender)+Cat)"); + Globals->abbreviations.insert("AgeRegressor", "DownsampleTraining(Center(Range),instances=-1, inputVariable=Age)+DownsampleTraining(SVM(RBF,EPS_SVR,inputVariable=Age),instances=100, inputVariable=Age)"); + Globals->abbreviations.insert("GenderClassifier", "DownsampleTraining(Center(Range),instances=-1, inputVariable=Gender)+DownsampleTraining(SVM(RBF,C_SVC,inputVariable=Gender),instances=4000, inputVariable=Gender)"); Globals->abbreviations.insert("UCharL1", "Unit(ByteL1)"); } }; diff --git a/openbr/plugins/cluster.cpp b/openbr/plugins/cluster.cpp index e6ff87c..3a9b577 100644 --- a/openbr/plugins/cluster.cpp +++ b/openbr/plugins/cluster.cpp @@ -89,10 +89,14 @@ class KNNTransform : public Transform Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance STORED false) Q_PROPERTY(bool weighted READ get_weighted WRITE set_weighted RESET reset_weighted STORED false) Q_PROPERTY(int numSubjects READ get_numSubjects WRITE set_numSubjects RESET reset_numSubjects STORED false) + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) + Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) BR_PROPERTY(int, k, 1) BR_PROPERTY(br::Distance*, distance, NULL) BR_PROPERTY(bool, weighted, false) BR_PROPERTY(int, numSubjects, 1) + BR_PROPERTY(QString, inputVariable, "Label") + BR_PROPERTY(QString, outputVariable, "KNN") TemplateList gallery; @@ -111,17 +115,17 @@ class KNNTransform : public Transform QHash votes; const int max = (k < 1) ? sortedScores.size() : std::min(k, sortedScores.size()); for (int j=0; j("Subject")] += (weighted ? sortedScores[j].first : 1); + votes[gallery[sortedScores[j].second].file.get(inputVariable)] += (weighted ? sortedScores[j].first : 1); subjects.append(votes.keys()[votes.values().indexOf(Common::Max(votes.values()))]); // Remove subject from consideration if (subjects.size() < numSubjects) for (int j=sortedScores.size()-1; j>=0; j--) - if (gallery[sortedScores[j].second].file.get("Subject") == subjects.last()) + if (gallery[sortedScores[j].second].file.get(inputVariable) == subjects.last()) sortedScores.removeAt(j); } - dst.file.set("KNN", subjects.size() > 1 ? "[" + subjects.join(",") + "]" : subjects.first()); + dst.file.set(outputVariable, subjects.size() > 1 ? "[" + subjects.join(",") + "]" : subjects.first()); } void store(QDataStream &stream) const diff --git a/openbr/plugins/eigen3.cpp b/openbr/plugins/eigen3.cpp index 89765db..4a904c2 100644 --- a/openbr/plugins/eigen3.cpp +++ b/openbr/plugins/eigen3.cpp @@ -303,10 +303,12 @@ class LDATransform : public Transform Q_PROPERTY(bool pcaWhiten READ get_pcaWhiten WRITE set_pcaWhiten RESET reset_pcaWhiten STORED false) Q_PROPERTY(int directLDA READ get_directLDA WRITE set_directLDA RESET reset_directLDA STORED false) Q_PROPERTY(float directDrop READ get_directDrop WRITE set_directDrop RESET reset_directDrop STORED false) + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) BR_PROPERTY(float, pcaKeep, 0.98) BR_PROPERTY(bool, pcaWhiten, false) BR_PROPERTY(int, directLDA, 0) BR_PROPERTY(float, directDrop, 0.1) + BR_PROPERTY(QString, inputVariable, "Label") int dimsOut; Eigen::VectorXf mean; @@ -315,7 +317,7 @@ class LDATransform : public Transform void train(const TemplateList &_trainingSet) { // creates "Label" - TemplateList trainingSet = TemplateList::relabel(_trainingSet, "Subject"); + TemplateList trainingSet = TemplateList::relabel(_trainingSet, inputVariable); int instances = trainingSet.size(); diff --git a/openbr/plugins/gallery.cpp b/openbr/plugins/gallery.cpp index 18f04e0..85a307b 100644 --- a/openbr/plugins/gallery.cpp +++ b/openbr/plugins/gallery.cpp @@ -71,7 +71,7 @@ class arffGallery : public Gallery } arffFile.write(qPrintable(OpenCVUtils::matrixToStringList(t).join(','))); - arffFile.write(qPrintable(",'" + t.file.get("Subject") + "'\n")); + arffFile.write(qPrintable(",'" + t.file.get("Label") + "'\n")); } }; @@ -643,11 +643,16 @@ class dbGallery : public Gallery query = query.mid(1, query.size()-2); if (!q.exec(query)) qFatal("%s.", qPrintable(q.lastError().text())); + if ((q.record().count() == 0) || (q.record().count() > 3)) qFatal("Query record expected one to three fields, got %d.", q.record().count()); const bool hasMetadata = (q.record().count() >= 2); const bool hasFilter = (q.record().count() >= 3); + QString labelName = "Label"; + if (q.record().count() >= 2) + labelName = q.record().fieldName(1); + // subset = seed:subjectMaxSize:numSubjects:subjectMinSize or // subset = seed:{Metadata,...,Metadata}:numSubjects int seed = 0, subjectMaxSize = std::numeric_limits::max(), numSubjects = std::numeric_limits::max(), subjectMinSize = 0; @@ -673,6 +678,7 @@ class dbGallery : public Gallery QHash > entries; // QHash > while (q.next()) { if (hasFilter && (seed >= 0) && (qHash(q.value(2).toString()) % 2 != (uint)seed % 2)) continue; // Ensures training and testing filters don't overlap + if (metadataFields.isEmpty()) entries[hasMetadata ? q.value(1).toString() : ""].append(QPair(q.value(0).toString(), hasFilter ? q.value(2).toString() : "")); else @@ -707,8 +713,10 @@ class dbGallery : public Gallery if (entryList.size() > subjectMaxSize) std::random_shuffle(entryList.begin(), entryList.end()); - foreach (const Entry &entry, entryList.mid(0, subjectMaxSize)) - templates.append(File(entry.first, label)); + foreach (const Entry &entry, entryList.mid(0, subjectMaxSize)) { + templates.append(File(entry.first)); + templates.last().file.set(labelName, label); + } numSubjects--; } } @@ -816,7 +824,7 @@ class statGallery : public Gallery void write(const Template &t) { - subjects.insert(t.file.get("Subject")); + subjects.insert(t.file.get("Label")); bytes.append(t.bytes()); } }; diff --git a/openbr/plugins/independent.cpp b/openbr/plugins/independent.cpp index 123c979..4a1ec03 100644 --- a/openbr/plugins/independent.cpp +++ b/openbr/plugins/independent.cpp @@ -9,37 +9,36 @@ using namespace cv; namespace br { -static TemplateList Downsample(const TemplateList &templates, const Transform *transform) +static TemplateList Downsample(const TemplateList &templates, int classes, int instances, float fraction, const QString & inputVariable) { // Return early when no downsampling is required - if ((transform->classes == std::numeric_limits::max()) && - (transform->instances == std::numeric_limits::max()) && - (transform->fraction >= 1)) + if ((classes == std::numeric_limits::max()) && + (instances == std::numeric_limits::max()) && + (fraction >= 1)) return templates; - const bool atLeast = transform->instances < 0; - const int instances = abs(transform->instances); + const bool atLeast = instances < 0; + instances = abs(instances); - QList allLabels = File::get(templates, "Subject"); + QList allLabels = File::get(templates, inputVariable); QList uniqueLabels = allLabels.toSet().toList(); qSort(uniqueLabels); - QMap counts = templates.countValues("Subject", instances != std::numeric_limits::max()); + QMap counts = templates.countValues(inputVariable, instances != std::numeric_limits::max()); - if ((instances != std::numeric_limits::max()) && (transform->classes != std::numeric_limits::max())) + if ((instances != std::numeric_limits::max()) && (classes != std::numeric_limits::max())) foreach (const QString & label, counts.keys()) if (counts[label] < instances) counts.remove(label); uniqueLabels = counts.keys(); - if ((transform->classes != std::numeric_limits::max()) && (uniqueLabels.size() < transform->classes)) - qWarning("Downsample requested %d classes but only %d are available.", transform->classes, uniqueLabels.size()); + if ((classes != std::numeric_limits::max()) && (uniqueLabels.size() < classes)) + qWarning("Downsample requested %d classes but only %d are available.", classes, uniqueLabels.size()); - Common::seedRNG(); QList selectedLabels = uniqueLabels; - if (transform->classes < uniqueLabels.size()) { + if (classes < uniqueLabels.size()) { std::random_shuffle(selectedLabels.begin(), selectedLabels.end()); - selectedLabels = selectedLabels.mid(0, transform->classes); + selectedLabels = selectedLabels.mid(0, classes); } TemplateList downsample; @@ -56,14 +55,45 @@ static TemplateList Downsample(const TemplateList &templates, const Transform *t downsample.append(templates.value(indices[j])); } - if (transform->fraction < 1) { + if (fraction < 1) { std::random_shuffle(downsample.begin(), downsample.end()); - downsample = downsample.mid(0, downsample.size()*transform->fraction); + downsample = downsample.mid(0, downsample.size()*fraction); } return downsample; } +class DownsampleTrainingTransform : public Transform +{ + Q_OBJECT + Q_PROPERTY(br::Transform* transform READ get_transform WRITE set_transform RESET reset_transform STORED true) + Q_PROPERTY(int classes READ get_classes WRITE set_classes RESET reset_classes STORED false) + Q_PROPERTY(int instances READ get_instances WRITE set_instances RESET reset_instances STORED false) + Q_PROPERTY(float fraction READ get_fraction WRITE set_fraction RESET reset_fraction STORED false) + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) + BR_PROPERTY(br::Transform*, transform, NULL) + BR_PROPERTY(int, classes, std::numeric_limits::max()) + BR_PROPERTY(int, instances, std::numeric_limits::max()) + BR_PROPERTY(float, fraction, 1) + BR_PROPERTY(QString, inputVariable, "Label") + + void project(const Template & src, Template & dst) const + { + transform->project(src,dst); + } + + + void train(const TemplateList &data) + { + if (!transform || !transform->trainable) + return; + + TemplateList downsampled = Downsample(data, classes, instances, fraction, inputVariable); + transform->train(downsampled); + } +}; +BR_REGISTER(Transform, DownsampleTrainingTransform) + /*! * \ingroup transforms * \brief Clones the transform so that it can be applied independently. @@ -124,13 +154,10 @@ class IndependentTransform : public MetaTransform while (transforms.size() < templatesList.size()) transforms.append(transform->clone()); - for (int i=0; i futures; for (int i=0; i > contours; @@ -171,7 +174,7 @@ class LargestConvexAreaTransform : public UntrainableTransform if (area / hullArea > 0.98) maxArea = std::max(maxArea, area); } - dst.file.set("Label", maxArea); + dst.file.set(outputVariable, maxArea); } }; diff --git a/openbr/plugins/normalize.cpp b/openbr/plugins/normalize.cpp index 5c85415..842e952 100644 --- a/openbr/plugins/normalize.cpp +++ b/openbr/plugins/normalize.cpp @@ -97,6 +97,7 @@ class CenterTransform : public Transform Q_OBJECT Q_ENUMS(Method) Q_PROPERTY(Method method READ get_method WRITE set_method RESET reset_method STORED false) + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) public: /*!< */ @@ -107,6 +108,7 @@ public: private: BR_PROPERTY(Method, method, Mean) + BR_PROPERTY(QString, inputVariable, "Label") Mat a, b; // dst = (src - b) / a @@ -127,7 +129,7 @@ private: Mat m; OpenCVUtils::toMat(data.data()).convertTo(m, CV_64F); - const QList labels = data.indexProperty("Subject"); + const QList labels = data.indexProperty(inputVariable); const int dims = m.cols; vector mv, av, bv; diff --git a/openbr/plugins/openbr_internal.h b/openbr/plugins/openbr_internal.h index 2f42e50..27136da 100644 --- a/openbr/plugins/openbr_internal.h +++ b/openbr/plugins/openbr_internal.h @@ -219,10 +219,6 @@ public: } output->file = this->file; - output->classes = classes; - output->instances = instances; - output->fraction = fraction; - output->init(); return output; diff --git a/openbr/plugins/output.cpp b/openbr/plugins/output.cpp index 2b75582..e71adc1 100644 --- a/openbr/plugins/output.cpp +++ b/openbr/plugins/output.cpp @@ -146,8 +146,8 @@ class meltOutput : public MatrixOutput QStringList lines; if (file.baseName() != "terminal") lines.append(QString("Query,Target,Mask,Similarity%1").arg(keys)); - QList queryLabels = File::get(queryFiles, "Subject"); - QList targetLabels = File::get(targetFiles, "Subject"); + QList queryLabels = File::get(queryFiles, "Label"); + QList targetLabels = File::get(targetFiles, "Label"); for (int i=0; i("Subject")); + lines.append(file.name + " " + file.get("Label")); QtUtils::writeFile(file, lines); } }; @@ -428,7 +428,7 @@ class rankOutput : public MatrixOutput foreach (const Pair &pair, Common::Sort(OpenCVUtils::matrixToVector(data.row(i)), true)) { if (Globals->crossValidate > 0 ? (targetFiles[pair.second].get("Partition",-1) == queryFiles[i].get("Partition",-1)) : true) { if (QString(targetFiles[pair.second]) != QString(queryFiles[i])) { - if (targetFiles[pair.second].get("Subject") == queryFiles[i].get("Subject")) { + if (targetFiles[pair.second].get("Label") == queryFiles[i].get("Label")) { ranks.append(rank); positions.append(pair.second); scores.append(pair.first); diff --git a/openbr/plugins/quality.cpp b/openbr/plugins/quality.cpp index fd3209c..e372460 100644 --- a/openbr/plugins/quality.cpp +++ b/openbr/plugins/quality.cpp @@ -19,17 +19,20 @@ class ImpostorUniquenessMeasureTransform : public Transform Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false) Q_PROPERTY(double mean READ get_mean WRITE set_mean RESET reset_mean) Q_PROPERTY(double stddev READ get_stddev WRITE set_stddev RESET reset_stddev) + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) BR_PROPERTY(br::Distance*, distance, Distance::make("Dist(L2)", this)) BR_PROPERTY(double, mean, 0) BR_PROPERTY(double, stddev, 1) + BR_PROPERTY(QString, inputVariable, "Label") + TemplateList impostors; float calculateIUM(const Template &probe, const TemplateList &gallery) const { - const QString probeLabel = probe.file.get("Subject"); + const QString probeLabel = probe.file.get(inputVariable); TemplateList subset = gallery; for (int j=subset.size()-1; j>=0; j--) - if (subset[j].file.get("Subject") == probeLabel) + if (subset[j].file.get(inputVariable) == probeLabel) subset.removeAt(j); QList scores = distance->compare(subset, probe); @@ -151,6 +154,7 @@ class MatchProbabilityDistance : public Distance Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false) Q_PROPERTY(bool gaussian READ get_gaussian WRITE set_gaussian RESET reset_gaussian STORED false) Q_PROPERTY(bool crossModality READ get_crossModality WRITE set_crossModality RESET reset_crossModality STORED false) + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) MP mp; @@ -158,7 +162,7 @@ class MatchProbabilityDistance : public Distance { distance->train(src); - const QList labels = src.indexProperty("Subject"); + const QList labels = src.indexProperty(inputVariable); QScopedPointer matrixOutput(MatrixOutput::make(FileList(src.size()), FileList(src.size()))); distance->compare(src, src, matrixOutput.data()); @@ -201,6 +205,7 @@ protected: BR_PROPERTY(br::Distance*, distance, make("Dist(L2)")) BR_PROPERTY(bool, gaussian, true) BR_PROPERTY(bool, crossModality, false) + BR_PROPERTY(QString, inputVariable, "Label") }; BR_REGISTER(Distance, MatchProbabilityDistance) @@ -217,10 +222,12 @@ class HeatMapDistance : public Distance Q_PROPERTY(bool gaussian READ get_gaussian WRITE set_gaussian RESET reset_gaussian STORED false) Q_PROPERTY(bool crossModality READ get_crossModality WRITE set_crossModality RESET reset_crossModality STORED false) Q_PROPERTY(int step READ get_step WRITE set_step RESET reset_step STORED false) + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) BR_PROPERTY(br::Distance*, distance, make("Dist(L2)")) BR_PROPERTY(bool, gaussian, true) BR_PROPERTY(bool, crossModality, false) BR_PROPERTY(int, step, 1) + BR_PROPERTY(QString, inputVariable, "Label") QList mp; @@ -228,7 +235,7 @@ class HeatMapDistance : public Distance { distance->train(src); - const QList labels = src.indexProperty("Subject"); + const QList labels = src.indexProperty(inputVariable); QList patches; @@ -307,14 +314,16 @@ class UnitDistance : public Distance Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance) Q_PROPERTY(float a READ get_a WRITE set_a RESET reset_a) Q_PROPERTY(float b READ get_b WRITE set_b RESET reset_b) + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) BR_PROPERTY(br::Distance*, distance, make("Dist(L2)")) BR_PROPERTY(float, a, 1) BR_PROPERTY(float, b, 0) + BR_PROPERTY(QString, inputVariable, "Label") void train(const TemplateList &templates) { const TemplateList samples = templates.mid(0, 2000); - const QList sampleLabels = samples.indexProperty("Subject"); + const QList sampleLabels = samples.indexProperty(inputVariable); QScopedPointer matrixOutput(MatrixOutput::make(FileList(samples.size()), FileList(samples.size()))); Distance::compare(samples, samples, matrixOutput.data()); diff --git a/openbr/plugins/quantize.cpp b/openbr/plugins/quantize.cpp index 94840f6..e74513f 100644 --- a/openbr/plugins/quantize.cpp +++ b/openbr/plugins/quantize.cpp @@ -120,6 +120,10 @@ BR_REGISTER(Transform, HistEqQuantizationTransform) class BayesianQuantizationDistance : public Distance { Q_OBJECT + + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) + BR_PROPERTY(QString, inputVariable, "Label") + QVector loglikelihoods; static void computeLogLikelihood(const Mat &data, const QList &labels, float *loglikelihood) @@ -150,7 +154,7 @@ class BayesianQuantizationDistance : public Distance qFatal("Expected sigle matrix templates of type CV_8UC1!"); const Mat data = OpenCVUtils::toMat(src.data()); - const QList templateLabels = src.indexProperty("Subject"); + const QList templateLabels = src.indexProperty(inputVariable); loglikelihoods = QVector(data.cols*256, 0); QFutureSynchronizer futures; @@ -343,9 +347,11 @@ class ProductQuantizationTransform : public Transform Q_PROPERTY(int n READ get_n WRITE set_n RESET reset_n STORED false) Q_PROPERTY(br::Distance *distance READ get_distance WRITE set_distance RESET reset_distance STORED false) Q_PROPERTY(bool bayesian READ get_bayesian WRITE set_bayesian RESET reset_bayesian STORED false) + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) BR_PROPERTY(int, n, 2) BR_PROPERTY(br::Distance*, distance, Distance::make("L2", this)) BR_PROPERTY(bool, bayesian, false) + BR_PROPERTY(QString, inputVariable, "Label") quint16 index; QList centers; @@ -474,7 +480,7 @@ private: Mat data = OpenCVUtils::toMat(src.data()); const int step = getStep(data.cols); - const QList labels = src.indexProperty("Subject"); + const QList labels = src.indexProperty(inputVariable); Mat &lut = ProductQuantizationLUTs[index]; lut = Mat(getDims(data.cols), 256*(256+1)/2, CV_32FC1); diff --git a/openbr/plugins/quantize2.cpp b/openbr/plugins/quantize2.cpp index 63ea769..8d05523 100644 --- a/openbr/plugins/quantize2.cpp +++ b/openbr/plugins/quantize2.cpp @@ -19,6 +19,10 @@ namespace br class BayesianQuantizationTransform : public Transform { Q_OBJECT + + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) + BR_PROPERTY(QString, inputVariable, "Label") + QVector thresholds; static void computeThresholdsRecursive(const QVector &cumulativeGenuines, const QVector &cumulativeImpostors, @@ -77,7 +81,7 @@ class BayesianQuantizationTransform : public Transform void train(const TemplateList &src) { const Mat data = OpenCVUtils::toMat(src.data()); - const QList labels = src.indexProperty("Subject"); + const QList labels = src.indexProperty(inputVariable); thresholds = QVector(256*data.cols); diff --git a/openbr/plugins/svm.cpp b/openbr/plugins/svm.cpp index 39efa4d..d4e2aad 100644 --- a/openbr/plugins/svm.cpp +++ b/openbr/plugins/svm.cpp @@ -101,6 +101,8 @@ class SVMTransform : public Transform Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false) Q_PROPERTY(float C READ get_C WRITE set_C RESET reset_C STORED false) Q_PROPERTY(float gamma READ get_gamma WRITE set_gamma RESET reset_gamma STORED false) + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) + Q_PROPERTY(QString outputVariable READ get_outputVariable WRITE set_outputVariable RESET reset_outputVariable STORED false) public: enum Kernel { Linear = CvSVM::LINEAR, @@ -119,6 +121,9 @@ private: BR_PROPERTY(Type, type, C_SVC) BR_PROPERTY(float, C, -1) BR_PROPERTY(float, gamma, -1) + BR_PROPERTY(QString, inputVariable, "") + BR_PROPERTY(QString, outputVariable, "") + SVM svm; QHash labelMap; @@ -128,14 +133,15 @@ private: { Mat data = OpenCVUtils::toMat(_data.data()); Mat lab; - // If we are doing regression, assume subject has float values + // If we are doing regression, the input variable should have float + // values if (type == EPS_SVR || type == NU_SVR) { - lab = OpenCVUtils::toMat(File::get(_data, "Subject")); + lab = OpenCVUtils::toMat(File::get(_data, inputVariable)); } - // If we are doing classification, assume subject has discrete values, map them - // and store the mapping data + // If we are doing classification, we should be dealing with discrete + // values. Map them and store the mapping data else { - QList dataLabels = _data.indexProperty("Subject", labelMap, reverseLookup); + QList dataLabels = _data.indexProperty(inputVariable, labelMap, reverseLookup); lab = OpenCVUtils::toMat(dataLabels); } trainSVM(svm, data, lab, kernel, type, C, gamma); @@ -146,9 +152,9 @@ private: dst = src; float prediction = svm.predict(src.m().reshape(1, 1)); if (type == EPS_SVR || type == NU_SVR) - dst.file.set("Subject", prediction); + dst.file.set(outputVariable, prediction); else - dst.file.set("Subject", reverseLookup[prediction]); + dst.file.set(outputVariable, reverseLookup[prediction]); } void store(QDataStream &stream) const @@ -162,6 +168,24 @@ private: loadSVM(svm, stream); stream >> labelMap >> reverseLookup; } + + void init() + { + // Since SVM can do regression or classification, we have to check the problem type before + // specifying target variable names + if (inputVariable.isEmpty()) + { + if (type == EPS_SVR || type == NU_SVR) { + inputVariable = "Regressor"; + if (outputVariable.isEmpty()) + outputVariable = "Regressand"; + } + else + inputVariable = "Label"; + } + if (outputVariable.isEmpty()) + outputVariable = inputVariable; + } }; BR_REGISTER(Transform, SVMTransform) @@ -178,6 +202,8 @@ class SVMDistance : public Distance Q_ENUMS(Type) Q_PROPERTY(Kernel kernel READ get_kernel WRITE set_kernel RESET reset_kernel STORED false) Q_PROPERTY(Type type READ get_type WRITE set_type RESET reset_type STORED false) + Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) + public: enum Kernel { Linear = CvSVM::LINEAR, @@ -194,13 +220,14 @@ public: private: BR_PROPERTY(Kernel, kernel, Linear) BR_PROPERTY(Type, type, EPS_SVR) + BR_PROPERTY(QString, inputVariable, "Label") SVM svm; void train(const TemplateList &src) { const Mat data = OpenCVUtils::toMat(src.data()); - const QList lab = src.indexProperty("Subject"); + const QList lab = src.indexProperty(inputVariable); const int instances = data.rows * (data.rows+1) / 2; Mat deltaData(instances, data.cols, data.type()); diff --git a/scripts/evalAgeRegression-PCSO.sh b/scripts/evalAgeRegression-PCSO.sh index 8d0f286..9e294fe 100755 --- a/scripts/evalAgeRegression-PCSO.sh +++ b/scripts/evalAgeRegression-PCSO.sh @@ -4,8 +4,12 @@ if [ ! -f evalAgeRegression-PCSO.sh ]; then exit fi +export BR="../build/app/br/br -useGui 0" +export PCSO_DIR=/user/pripshare/Databases/FaceDatabases/PCSO/PCSO/ +export ageAlg=AgeRegression + # Create a file list by querying the database -br -quiet -algorithm Identity -enroll "../data/PCSO/PCSO.db[query='SELECT File,Age,PersonID FROM PCSO WHERE Age >= 15 AND AGE <= 75', subset=1:200]" terminal.txt > Input.txt +$BR -quiet -algorithm Identity -enroll "$PCSO_DIR/PCSO.db[query='SELECT File,Age,PersonID FROM PCSO WHERE Age >= 17 AND AGE <= 68', subset=1:200]" terminal.txt > Input.txt # Enroll the file list and evaluate performance -br -algorithm AgeRegression -path ../data/PCSO/img -enroll Input.txt Output.txt -evalRegression Output.txt Input.txt +$BR -algorithm $ageAlg -path $PCSO_DIR/Images -enroll Input.txt Output.txt -evalRegression Output.txt Input.txt Age diff --git a/scripts/evalFaceRecognition-MEDS.sh b/scripts/evalFaceRecognition-MEDS.sh index 231370f..673f674 100755 --- a/scripts/evalFaceRecognition-MEDS.sh +++ b/scripts/evalFaceRecognition-MEDS.sh @@ -20,11 +20,11 @@ if [ ! -e Algorithm_Dataset ]; then fi if [ ! -e MEDS.mask ]; then - br -makeMask ../data/MEDS/sigset/MEDS_frontal_target.xml ../data/MEDS/sigset/MEDS_frontal_query.xml MEDS.mask + br -useGui 0 -makeMask ../data/MEDS/sigset/MEDS_frontal_target.xml ../data/MEDS/sigset/MEDS_frontal_query.xml MEDS.mask fi # Run Algorithm on MEDS -br -algorithm ${ALGORITHM} -path ../data/MEDS/img -compare ../data/MEDS/sigset/MEDS_frontal_target.xml ../data/MEDS/sigset/MEDS_frontal_query.xml ${ALGORITHM}_MEDS.mtx -eval ${ALGORITHM}_MEDS.mtx MEDS.mask Algorithm_Dataset/${ALGORITHM}_MEDS.csv +br -useGui 0 -algorithm ${ALGORITHM} -path ../data/MEDS/img -compare ../data/MEDS/sigset/MEDS_frontal_target.xml ../data/MEDS/sigset/MEDS_frontal_query.xml ${ALGORITHM}_MEDS.mtx -eval ${ALGORITHM}_MEDS.mtx MEDS.mask Algorithm_Dataset/${ALGORITHM}_MEDS.csv # Plot results -br -plot Algorithm_Dataset/*_MEDS.csv MEDS +br -useGui 0 -plot Algorithm_Dataset/*_MEDS.csv MEDS diff --git a/scripts/evalGenderClassification-PCSO.sh b/scripts/evalGenderClassification-PCSO.sh index 881214e..9373645 100755 --- a/scripts/evalGenderClassification-PCSO.sh +++ b/scripts/evalGenderClassification-PCSO.sh @@ -4,8 +4,13 @@ if [ ! -f evalGenderClassification-PCSO.sh ]; then exit fi +export BR=../build/app/br/br +export genderAlg=GenderClassification + +export PCSO_DIR=/user/pripshare/Databases/FaceDatabases/PCSO/PCSO/ + # Create a file list by querying the database -br -quiet -algorithm Identity -enroll "../data/PCSO/PCSO.db[query='SELECT File,Gender,PersonID FROM PCSO', subset=1:8000]" terminal.txt > Input.txt +$BR -useGui 0 -quiet -algorithm Identity -enroll "$PCSO_DIR/PCSO.db[query='SELECT File,Gender,PersonID FROM PCSO', subset=1:8000]" terminal.txt > Input.txt # Enroll the file list and evaluate performance -br -algorithm GenderClassification -path ../data/PCSO/img -enroll Input.txt Output.txt -evalClassification Output.txt Input.txt +$BR -useGui 0 -algorithm $genderAlg -path $PCSO_DIR/Images -enroll Input.txt Output.txt -evalClassification Output.txt Input.txt Gender \ No newline at end of file diff --git a/scripts/trainAgeRegression-PCSO.sh b/scripts/trainAgeRegression-PCSO.sh index c063055..c1f3ef1 100755 --- a/scripts/trainAgeRegression-PCSO.sh +++ b/scripts/trainAgeRegression-PCSO.sh @@ -6,6 +6,11 @@ fi #rm -f ../share/openbr/models/features/FaceClassificationRegistration #rm -f ../share/openbr/models/features/FaceClassificationExtraction -rm -f ../share/openbr/models/algorithms/AgeRegression +#rm -f ../share/openbr/models/algorithms/AgeRegression -br -algorithm AgeRegression -path ../data/PCSO/Images -train "../data/PCSO/PCSO.db[query='SELECT File,Age,PersonID FROM PCSO WHERE Age >= 15 AND AGE <= 75', subset=0:200]" ../share/openbr/models/algorithms/AgeRegression +export BR=../build/app/br/br +export ageAlg=AgeRegression + +export PCSO_DIR=/user/pripshare/Databases/FaceDatabases/PCSO/PCSO/ + +$BR -useGui 0 -algorithm $ageAlg -path $PCSO_DIR/Images -train "$PCSO_DIR/PCSO.db[query='SELECT File,Age,PersonID FROM PCSO WHERE Age >= 17 AND AGE <= 68', subset=0:200]" ../share/openbr/models/algorithms/AgeRegression diff --git a/scripts/trainFaceRecognition-PCSO.sh b/scripts/trainFaceRecognition-PCSO.sh index e67b3aa..ebeb056 100755 --- a/scripts/trainFaceRecognition-PCSO.sh +++ b/scripts/trainFaceRecognition-PCSO.sh @@ -8,6 +8,13 @@ fi #rm -f ../share/openbr/models/features/FaceRecognitionExtraction #rm -f ../share/openbr/models/features/FaceRecognitionEmbedding #rm -f ../share/openbr/models/features/FaceRecognitionQuantization -rm -f ../share/openbr/models/algorithms/FaceRecognition +#rm -f ../share/openbr/models/algorithms/FaceRecognition + +export BR=../build/app/br/br + +export PCSO_DIR=/user/pripshare/Databases/FaceDatabases/PCSO/PCSO/ + + + +$BR -useGui 0 -algorithm FaceRecognition -path "$PCSO_DIR/Images/" -train "$PCSO_DIR/PCSO.db[query='SELECT File,PersonID as Label,PersonID FROM PCSO', subset=0:5:6000]" ../share/openbr/models/algorithms/FaceRecognition -br -algorithm FaceRecognition -path ../data/PCSO/img -train "../data/PCSO/PCSO.db[query='SELECT File,'S'||PersonID,PersonID FROM PCSO', subset=0:5:6000]" ../share/openbr/models/algorithms/FaceRecognition diff --git a/scripts/trainGenderClassification-PCSO.sh b/scripts/trainGenderClassification-PCSO.sh index 91fae04..acb8146 100755 --- a/scripts/trainGenderClassification-PCSO.sh +++ b/scripts/trainGenderClassification-PCSO.sh @@ -6,6 +6,11 @@ fi #rm -f ../share/openbr/models/features/FaceClassificationRegistration #rm -f ../share/openbr/models/features/FaceClassificationExtraction -rm -f ../share/openbr/models/algorithms/GenderClassification +#rm -f ../share/openbr/models/algorithms/GenderClassification -br -algorithm GenderClassification -path ../data/PCSO/Images -train "../data/PCSO/PCSO.db[query='SELECT File,Gender,PersonID FROM PCSO', subset=0:8000]" ../share/openbr/models/algorithms/GenderClassification +export BR=../build/app/br/br +export genderAlg=GenderClassification + +export PCSO_DIR=/user/pripshare/Databases/FaceDatabases/PCSO/PCSO/ + +$BR -useGui 0 -algorithm $genderAlg -path $PCSO_DIR/Images -train "$PCSO_DIR/PCSO.db[query='SELECT File,Gender,PersonID FROM PCSO', subset=0:8000]" ../share/openbr/models/algorithms/GenderClassification diff --git a/share/openbr/models b/share/openbr/models index dccddf4..a73d510 160000 --- a/share/openbr/models +++ b/share/openbr/models @@ -1 +1 @@ -Subproject commit dccddf4dd3a5239911807beeec39308f8890b1e4 +Subproject commit a73d51013ea05f263e88a28539393159fff2183e