Commit b6ff0861579e9e4dd3e68261d2f1bb63263675b9
1 parent
49b51c19
progress on cross validation framework
Showing
7 changed files
with
44 additions
and
32 deletions
sdk/core/plot.cpp
| ... | ... | @@ -124,21 +124,23 @@ float Evaluate(const QString &simmat, const QString &mask, const QString &csv) |
| 124 | 124 | |
| 125 | 125 | // Make comparisons |
| 126 | 126 | QList<Comparison> comparisons; comparisons.reserve(scores.rows*scores.cols); |
| 127 | - int genuineCount = 0, impostorCount = 0; | |
| 127 | + int genuineCount = 0, impostorCount = 0, numNaNs = 0; | |
| 128 | 128 | for (int i=0; i<scores.rows; i++) { |
| 129 | 129 | for (int j=0; j<scores.cols; j++) { |
| 130 | 130 | const BEE::Mask_t mask_val = masks.at<BEE::Mask_t>(i,j); |
| 131 | 131 | const BEE::Simmat_t simmat_val = scores.at<BEE::Simmat_t>(i,j); |
| 132 | 132 | if ((mask_val == BEE::DontCare) || |
| 133 | 133 | (simmat_val == -std::numeric_limits<float>::max())) continue; |
| 134 | + if (simmat_val != simmat_val) { numNaNs++; continue; } | |
| 134 | 135 | comparisons.append(Comparison(simmat_val, j, i, mask_val == BEE::Match)); |
| 135 | 136 | if (comparisons.last().genuine) genuineCount++; |
| 136 | 137 | else impostorCount++; |
| 137 | 138 | } |
| 138 | 139 | } |
| 139 | 140 | |
| 140 | - if (genuineCount == 0) qFatal("No genuine scores."); | |
| 141 | - if (impostorCount == 0) qFatal("No impostor scores."); | |
| 141 | + if (numNaNs > 0) qWarning("Encountered %d NaN scores!", numNaNs); | |
| 142 | + if (genuineCount == 0) qFatal("No genuine scores!"); | |
| 143 | + if (impostorCount == 0) qFatal("No impostor scores!"); | |
| 142 | 144 | |
| 143 | 145 | std::sort(comparisons.begin(), comparisons.end()); |
| 144 | 146 | |
| ... | ... | @@ -324,7 +326,8 @@ struct RPlot |
| 324 | 326 | pivotItems = QVector< QSet<QString> >(pivotHeaders.size()); |
| 325 | 327 | foreach (const QString &fileName, files) { |
| 326 | 328 | QStringList pivots = getPivots(fileName, false); |
| 327 | - if (pivots.size() != pivotHeaders.size()) qFatal("plot.cpp::initializeR pivot size mismatch."); | |
| 329 | + if (pivots.size() != pivotHeaders.size()) | |
| 330 | + qFatal("Pivot size mismatch: [%s] [%s]", qPrintable(pivotHeaders.join(",")), qPrintable(pivots.join(","))); | |
| 328 | 331 | file.write(qPrintable(QString("tmp <- read.csv(\"%1\")\n").arg(fileName).replace("\\", "\\\\"))); |
| 329 | 332 | for (int i=0; i<pivots.size(); i++) { |
| 330 | 333 | pivotItems[i].insert(pivots[i]); | ... | ... |
sdk/openbr_plugin.cpp
| ... | ... | @@ -451,6 +451,19 @@ TemplateList TemplateList::fromInput(const br::File &input) |
| 451 | 451 | return templates; |
| 452 | 452 | } |
| 453 | 453 | |
| 454 | +TemplateList TemplateList::relabel(const TemplateList &tl) | |
| 455 | +{ | |
| 456 | + QHash<int,int> labels; | |
| 457 | + foreach (int label, tl.labels<int>()) | |
| 458 | + if (!labels.contains(label)) | |
| 459 | + labels.insert(label, labels.size()); | |
| 460 | + | |
| 461 | + TemplateList result = tl; | |
| 462 | + for (int i=0; i<result.size(); i++) | |
| 463 | + result[i].file.setLabel(labels[result[i].file.label()]); | |
| 464 | + return result; | |
| 465 | +} | |
| 466 | + | |
| 454 | 467 | /* Object - public methods */ |
| 455 | 468 | QStringList Object::parameters() const |
| 456 | 469 | { |
| ... | ... | @@ -1053,10 +1066,8 @@ static TemplateList Downsample(const TemplateList &templates, const Transform *t |
| 1053 | 1066 | |
| 1054 | 1067 | std::random_shuffle(indices.begin(), indices.end()); |
| 1055 | 1068 | const int max = atLeast ? indices.size() : std::min(indices.size(), instances); |
| 1056 | - for (int j=0; j<max; j++) { | |
| 1069 | + for (int j=0; j<max; j++) | |
| 1057 | 1070 | downsample.append(templates.value(indices[j])); |
| 1058 | - if (transform->relabel) downsample.last().file.insert("Label", i); | |
| 1059 | - } | |
| 1060 | 1071 | } |
| 1061 | 1072 | |
| 1062 | 1073 | if (transform->fraction < 1) { |
| ... | ... | @@ -1168,7 +1179,6 @@ Transform::Transform(bool _independent, bool _trainable) |
| 1168 | 1179 | { |
| 1169 | 1180 | independent = _independent; |
| 1170 | 1181 | trainable = _trainable; |
| 1171 | - relabel = false; | |
| 1172 | 1182 | classes = std::numeric_limits<int>::max(); |
| 1173 | 1183 | instances = std::numeric_limits<int>::max(); |
| 1174 | 1184 | fraction = 1; |
| ... | ... | @@ -1222,7 +1232,6 @@ Transform *Transform::make(QString str, QObject *parent) |
| 1222 | 1232 | Transform *Transform::clone() const |
| 1223 | 1233 | { |
| 1224 | 1234 | Transform *clone = Factory<Transform>::make(file.flat()); |
| 1225 | - clone->relabel = relabel; | |
| 1226 | 1235 | clone->classes = classes; |
| 1227 | 1236 | clone->instances = instances; |
| 1228 | 1237 | clone->fraction = fraction; | ... | ... |
sdk/openbr_plugin.h
| ... | ... | @@ -323,7 +323,7 @@ struct TemplateList : public QList<Template> |
| 323 | 323 | TemplateList() : uniform(false) {} |
| 324 | 324 | TemplateList(const QList<Template> &templates) : uniform(false) { append(templates); } /*!< \brief Initialize the template list from another template list. */ |
| 325 | 325 | BR_EXPORT static TemplateList fromInput(const File &input); /*!< \brief Create a template list from a br::Input. */ |
| 326 | - | |
| 326 | + BR_EXPORT static TemplateList relabel(const TemplateList &tl); /*!< \brief Ensure labels are in the range [0,numClasses-1]. */ | |
| 327 | 327 | /*! |
| 328 | 328 | * \brief Returns the total number of bytes in all the templates. |
| 329 | 329 | */ |
| ... | ... | @@ -932,11 +932,9 @@ class BR_EXPORT Transform : public Object |
| 932 | 932 | Q_OBJECT |
| 933 | 933 | |
| 934 | 934 | public: |
| 935 | - Q_PROPERTY(bool relabel READ get_relabel WRITE set_relabel RESET reset_relabel STORED false) | |
| 936 | 935 | Q_PROPERTY(int classes READ get_classes WRITE set_classes RESET reset_classes STORED false) |
| 937 | 936 | Q_PROPERTY(int instances READ get_instances WRITE set_instances RESET reset_instances STORED false) |
| 938 | 937 | Q_PROPERTY(float fraction READ get_fraction WRITE set_fraction RESET reset_fraction STORED false) |
| 939 | - BR_PROPERTY(bool, relabel, false) | |
| 940 | 938 | BR_PROPERTY(int, classes, std::numeric_limits<int>::max()) |
| 941 | 939 | BR_PROPERTY(int, instances, std::numeric_limits<int>::max()) |
| 942 | 940 | BR_PROPERTY(float, fraction, 1) | ... | ... |
sdk/plugins/algorithms.cpp
| ... | ... | @@ -64,7 +64,7 @@ class AlgorithmsInitializer : public Initializer |
| 64 | 64 | Globals->abbreviations.insert("DenseSIFT", "(Grid(10,10)+SIFTDescriptor(12)+ByRow)"); |
| 65 | 65 | Globals->abbreviations.insert("FaceRecognitionRegistration", "(ASEFEyes+Affine(88,88,0.25,0.35)+FTE(DFFS,instances=1))"); |
| 66 | 66 | Globals->abbreviations.insert("FaceRecognitionExtraction", "(Mask+DenseSIFT/DenseLBP+PCA(0.95,instances=1)+Normalize(L2)+Cat)"); |
| 67 | - Globals->abbreviations.insert("FaceRecognitionEmbedding", "(Dup(12)+RndSubspace(0.05,1)+LDA(0.98,instances=-2,relabel=true)+Cat+PCA(768,instances=1))"); | |
| 67 | + Globals->abbreviations.insert("FaceRecognitionEmbedding", "(Dup(12)+RndSubspace(0.05,1)+LDA(0.98,instances=-2)+Cat+PCA(768,instances=1))"); | |
| 68 | 68 | Globals->abbreviations.insert("FaceRecognitionQuantization", "(Normalize(L1)+Quantize)"); |
| 69 | 69 | Globals->abbreviations.insert("FaceClassificationRegistration", "(ASEFEyes+Affine(56,72,0.33,0.45)+FTE(DFFS))"); |
| 70 | 70 | Globals->abbreviations.insert("FaceClassificationExtraction", "((Grid(7,7)+SIFTDescriptor(8)+ByRow)/DenseLBP+PCA(0.95,instances=-1)+Cat)"); | ... | ... |
sdk/plugins/eigen3.cpp
| ... | ... | @@ -264,8 +264,11 @@ class LDATransform : public Transform |
| 264 | 264 | Eigen::VectorXf mean; |
| 265 | 265 | Eigen::MatrixXf projection; |
| 266 | 266 | |
| 267 | - void train(const TemplateList &trainingSet) | |
| 267 | + void train(const TemplateList &_trainingSet) | |
| 268 | 268 | { |
| 269 | + TemplateList trainingSet = _trainingSet; | |
| 270 | + trainingSet = TemplateList::relabel(trainingSet); | |
| 271 | + | |
| 269 | 272 | int instances = trainingSet.size(); |
| 270 | 273 | |
| 271 | 274 | // Perform PCA dimensionality reduction |
| ... | ... | @@ -276,6 +279,7 @@ class LDATransform : public Transform |
| 276 | 279 | |
| 277 | 280 | TemplateList ldaTrainingSet; |
| 278 | 281 | static_cast<Transform*>(&pca)->project(trainingSet, ldaTrainingSet); |
| 282 | + ldaTrainingSet = TemplateList::relabel(ldaTrainingSet); | |
| 279 | 283 | |
| 280 | 284 | int dimsIn = ldaTrainingSet.first().m().rows * ldaTrainingSet.first().m().cols; |
| 281 | 285 | ... | ... |
sdk/plugins/quality.cpp
| ... | ... | @@ -91,7 +91,7 @@ struct KDE |
| 91 | 91 | bins.append(Common::KernelDensityEstimation(scores, min + (max-min)*i/(size-1), h)); |
| 92 | 92 | } |
| 93 | 93 | |
| 94 | - float operator()(float score, bool gaussian = false) const | |
| 94 | + float operator()(float score, bool gaussian = true) const | |
| 95 | 95 | { |
| 96 | 96 | if (gaussian) return 1/(stddev*sqrt(2*CV_PI))*exp(-0.5*pow((score-mean)/stddev, 2)); |
| 97 | 97 | if (score <= min) return bins.first(); |
| ... | ... | @@ -120,12 +120,12 @@ struct MP |
| 120 | 120 | MP() {} |
| 121 | 121 | MP(const QList<float> &genuineScores, const QList<float> &impostorScores) |
| 122 | 122 | : genuine(genuineScores), impostor(impostorScores) {} |
| 123 | - float operator()(float score, bool gaussian = false, bool log = false) const | |
| 123 | + float operator()(float score, bool gaussian = true) const | |
| 124 | 124 | { |
| 125 | 125 | const float g = genuine(score, gaussian); |
| 126 | 126 | const float s = g / (impostor(score, gaussian) + g); |
| 127 | - if (log) return (std::max(std::log10(s), -10.f) + 10)/10; | |
| 128 | - else return s; | |
| 127 | + if (s != s) qDebug() << "!!" << g << impostor(score, gaussian) << score << genuine.mean << genuine.stddev << impostor.mean << impostor.stddev; | |
| 128 | + return s; | |
| 129 | 129 | } |
| 130 | 130 | }; |
| 131 | 131 | |
| ... | ... | @@ -150,11 +150,9 @@ class MPDistance : public Distance |
| 150 | 150 | Q_PROPERTY(br::Distance* distance READ get_distance WRITE set_distance RESET reset_distance STORED false) |
| 151 | 151 | Q_PROPERTY(QString binKey READ get_binKey WRITE set_binKey RESET reset_binKey STORED false) |
| 152 | 152 | Q_PROPERTY(bool gaussian READ get_gaussian WRITE set_gaussian RESET reset_gaussian STORED false) |
| 153 | - Q_PROPERTY(bool log READ get_log WRITE set_log RESET reset_log STORED false) | |
| 154 | 153 | BR_PROPERTY(br::Distance*, distance, make("Dist(L2)")) |
| 155 | 154 | BR_PROPERTY(QString, binKey, "") |
| 156 | - BR_PROPERTY(bool, gaussian, false) | |
| 157 | - BR_PROPERTY(bool, log, false) | |
| 155 | + BR_PROPERTY(bool, gaussian, true) | |
| 158 | 156 | |
| 159 | 157 | QHash<QString, MP> mps; |
| 160 | 158 | |
| ... | ... | @@ -170,6 +168,7 @@ class MPDistance : public Distance |
| 170 | 168 | for (int i=0; i<src.size(); i++) |
| 171 | 169 | for (int j=0; j<i; j++) { |
| 172 | 170 | const float score = memoryOutput.data()->data.at<float>(i, j); |
| 171 | + if (score == -std::numeric_limits<float>::max()) continue; | |
| 173 | 172 | const QString bin = src[i].file.getString(binKey, ""); |
| 174 | 173 | if (labels[i] == labels[j]) genuineScores[bin].append(score); |
| 175 | 174 | else impostorScores[bin].append(score); |
| ... | ... | @@ -181,7 +180,9 @@ class MPDistance : public Distance |
| 181 | 180 | |
| 182 | 181 | float compare(const Template &target, const Template &query) const |
| 183 | 182 | { |
| 184 | - return mps[query.file.getString(binKey, "")](distance->compare(target, query), gaussian, log); | |
| 183 | + float rawScore = distance->compare(target, query); | |
| 184 | + if (rawScore == -std::numeric_limits<float>::max()) return rawScore; | |
| 185 | + return mps[query.file.getString(binKey, "")](rawScore, gaussian); | |
| 185 | 186 | } |
| 186 | 187 | |
| 187 | 188 | void store(QDataStream &stream) const | ... | ... |
sdk/plugins/validate.cpp
| ... | ... | @@ -9,37 +9,33 @@ namespace br |
| 9 | 9 | * \brief Cross validate a trainable transform. |
| 10 | 10 | * \author Josh Klontz \cite jklontz |
| 11 | 11 | */ |
| 12 | -class CrossValidateTransform : public Transform | |
| 12 | +class CrossValidateTransform : public MetaTransform | |
| 13 | 13 | { |
| 14 | 14 | Q_OBJECT |
| 15 | 15 | Q_PROPERTY(QString description READ get_description WRITE set_description RESET reset_description STORED false) |
| 16 | 16 | Q_PROPERTY(QList<br::Transform*> transforms READ get_transforms WRITE set_transforms RESET reset_transforms) |
| 17 | 17 | BR_PROPERTY(QString, description, "Identity") |
| 18 | - BR_PROPERTY(QList<br::Transform*>, transforms, QList<br::Transform*>() << make(description)) | |
| 18 | + BR_PROPERTY(QList<br::Transform*>, transforms, QList<br::Transform*>()) | |
| 19 | 19 | |
| 20 | 20 | void train(const TemplateList &data) |
| 21 | 21 | { |
| 22 | - if (!transforms.first()->trainable) | |
| 23 | - return; | |
| 24 | - | |
| 25 | 22 | int numPartitions = 0; |
| 26 | 23 | QList<int> partitions; partitions.reserve(data.size()); |
| 27 | 24 | foreach (const File &file, data.files()) { |
| 28 | 25 | partitions.append(file.getInt("Cross_Validation_Partition", 0)); |
| 29 | - numPartitions = std::max(numPartitions, partitions.last()); | |
| 26 | + numPartitions = std::max(numPartitions, partitions.last()+1); | |
| 30 | 27 | } |
| 31 | 28 | |
| 29 | + while (transforms.size() < numPartitions) | |
| 30 | + transforms.append(make(description)); | |
| 31 | + | |
| 32 | 32 | if (numPartitions < 2) { |
| 33 | 33 | transforms.first()->train(data); |
| 34 | 34 | return; |
| 35 | 35 | } |
| 36 | 36 | |
| 37 | - while (transforms.size() < numPartitions) | |
| 38 | - transforms.append(make(description)); | |
| 39 | - | |
| 40 | 37 | QList< QFuture<void> > futures; |
| 41 | 38 | for (int i=0; i<numPartitions; i++) { |
| 42 | - qDebug() << "!!" << transforms[i]->description(); | |
| 43 | 39 | TemplateList partitionedData = data; |
| 44 | 40 | for (int j=partitionedData.size()-1; j>=0; j--) |
| 45 | 41 | if (partitions[j] == i) |
| ... | ... | @@ -47,6 +43,7 @@ class CrossValidateTransform : public Transform |
| 47 | 43 | if (Globals->parallelism) futures.append(QtConcurrent::run(transforms[i], &Transform::train, partitionedData)); |
| 48 | 44 | else transforms[i]->train(partitionedData); |
| 49 | 45 | } |
| 46 | + Globals->trackFutures(futures); | |
| 50 | 47 | } |
| 51 | 48 | |
| 52 | 49 | void project(const Template &src, Template &dst) const | ... | ... |