Commit 1246894606c8960d15496aa0dffb99afe5747d46
1 parent
29c8c3f5
LDA support for two class hyperplane classificiation
Showing
1 changed file
with
34 additions
and
1 deletions
openbr/plugins/eigen3.cpp
| @@ -302,11 +302,13 @@ class LDATransform : public Transform | @@ -302,11 +302,13 @@ class LDATransform : public Transform | ||
| 302 | Q_PROPERTY(int directLDA READ get_directLDA WRITE set_directLDA RESET reset_directLDA STORED false) | 302 | Q_PROPERTY(int directLDA READ get_directLDA WRITE set_directLDA RESET reset_directLDA STORED false) |
| 303 | Q_PROPERTY(float directDrop READ get_directDrop WRITE set_directDrop RESET reset_directDrop STORED false) | 303 | Q_PROPERTY(float directDrop READ get_directDrop WRITE set_directDrop RESET reset_directDrop STORED false) |
| 304 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) | 304 | Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) |
| 305 | + Q_PROPERTY(bool isBinary READ get_isBinary WRITE set_isBinary RESET reset_isBinary STORED false) | ||
| 305 | BR_PROPERTY(float, pcaKeep, 0.98) | 306 | BR_PROPERTY(float, pcaKeep, 0.98) |
| 306 | BR_PROPERTY(bool, pcaWhiten, false) | 307 | BR_PROPERTY(bool, pcaWhiten, false) |
| 307 | BR_PROPERTY(int, directLDA, 0) | 308 | BR_PROPERTY(int, directLDA, 0) |
| 308 | BR_PROPERTY(float, directDrop, 0.1) | 309 | BR_PROPERTY(float, directDrop, 0.1) |
| 309 | BR_PROPERTY(QString, inputVariable, "Label") | 310 | BR_PROPERTY(QString, inputVariable, "Label") |
| 311 | + BR_PROPERTY(bool, isBinary, false) | ||
| 310 | 312 | ||
| 311 | int dimsOut; | 313 | int dimsOut; |
| 312 | Eigen::VectorXf mean; | 314 | Eigen::VectorXf mean; |
| @@ -316,7 +318,6 @@ class LDATransform : public Transform | @@ -316,7 +318,6 @@ class LDATransform : public Transform | ||
| 316 | { | 318 | { |
| 317 | // creates "Label" | 319 | // creates "Label" |
| 318 | TemplateList trainingSet = TemplateList::relabel(_trainingSet, inputVariable); | 320 | TemplateList trainingSet = TemplateList::relabel(_trainingSet, inputVariable); |
| 319 | - | ||
| 320 | int instances = trainingSet.size(); | 321 | int instances = trainingSet.size(); |
| 321 | 322 | ||
| 322 | // Perform PCA dimensionality reduction | 323 | // Perform PCA dimensionality reduction |
| @@ -450,6 +451,34 @@ class LDATransform : public Transform | @@ -450,6 +451,34 @@ class LDATransform : public Transform | ||
| 450 | // Compute final projection matrix | 451 | // Compute final projection matrix |
| 451 | projection = ((space2.eVecs.transpose() * space1.eVecs.transpose()) * pca.eVecs.transpose()).transpose(); | 452 | projection = ((space2.eVecs.transpose() * space1.eVecs.transpose()) * pca.eVecs.transpose()).transpose(); |
| 452 | dimsOut = dim2; | 453 | dimsOut = dim2; |
| 454 | + | ||
| 455 | + if (isBinary) { | ||
| 456 | + assert(dimsOut == 1); | ||
| 457 | + TemplateList projected; | ||
| 458 | + float posVal = 0; | ||
| 459 | + float negVal = 0; | ||
| 460 | + for (int i = 0; i < trainingSet.size(); i++) { | ||
| 461 | + Template t; | ||
| 462 | + project(trainingSet[i],t); | ||
| 463 | + //Note: the positive class is assumed to be 0 b/c it will | ||
| 464 | + // typically be the first gallery template in the TemplateList structure | ||
| 465 | + if (classes[i] == 0) | ||
| 466 | + posVal += t.m().at<float>(0,0); | ||
| 467 | + else if (classes[i] == 1) | ||
| 468 | + negVal += t.m().at<float>(0,0); | ||
| 469 | + else | ||
| 470 | + qFatal("Binary mode only supports two class problems."); | ||
| 471 | + } | ||
| 472 | + posVal /= classCounts[0]; | ||
| 473 | + negVal /= classCounts[1]; | ||
| 474 | + | ||
| 475 | + if (posVal < negVal) { | ||
| 476 | + //Ensure positive value is supposed to be > 0 after projection | ||
| 477 | + Eigen::MatrixXf invert = Eigen::MatrixXf::Ones(dimsIn,1); | ||
| 478 | + invert *= -1; | ||
| 479 | + projection = invert.transpose() * projection; | ||
| 480 | + } | ||
| 481 | + } | ||
| 453 | } | 482 | } |
| 454 | 483 | ||
| 455 | void project(const Template &src, Template &dst) const | 484 | void project(const Template &src, Template &dst) const |
| @@ -462,6 +491,10 @@ class LDATransform : public Transform | @@ -462,6 +491,10 @@ class LDATransform : public Transform | ||
| 462 | 491 | ||
| 463 | // Do projection | 492 | // Do projection |
| 464 | outMap = projection.transpose() * (inMap - mean); | 493 | outMap = projection.transpose() * (inMap - mean); |
| 494 | + | ||
| 495 | + if (isBinary) { | ||
| 496 | + dst.file.set("conf",dst.m().at<float>(0,0)); | ||
| 497 | + } | ||
| 465 | } | 498 | } |
| 466 | 499 | ||
| 467 | void store(QDataStream &stream) const | 500 | void store(QDataStream &stream) const |