Commit 1246894606c8960d15496aa0dffb99afe5747d46

Authored by Brendan Klare
1 parent 29c8c3f5

LDA support for two class hyperplane classificiation

Showing 1 changed file with 34 additions and 1 deletions
openbr/plugins/eigen3.cpp
@@ -302,11 +302,13 @@ class LDATransform : public Transform @@ -302,11 +302,13 @@ class LDATransform : public Transform
302 Q_PROPERTY(int directLDA READ get_directLDA WRITE set_directLDA RESET reset_directLDA STORED false) 302 Q_PROPERTY(int directLDA READ get_directLDA WRITE set_directLDA RESET reset_directLDA STORED false)
303 Q_PROPERTY(float directDrop READ get_directDrop WRITE set_directDrop RESET reset_directDrop STORED false) 303 Q_PROPERTY(float directDrop READ get_directDrop WRITE set_directDrop RESET reset_directDrop STORED false)
304 Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false) 304 Q_PROPERTY(QString inputVariable READ get_inputVariable WRITE set_inputVariable RESET reset_inputVariable STORED false)
  305 + Q_PROPERTY(bool isBinary READ get_isBinary WRITE set_isBinary RESET reset_isBinary STORED false)
305 BR_PROPERTY(float, pcaKeep, 0.98) 306 BR_PROPERTY(float, pcaKeep, 0.98)
306 BR_PROPERTY(bool, pcaWhiten, false) 307 BR_PROPERTY(bool, pcaWhiten, false)
307 BR_PROPERTY(int, directLDA, 0) 308 BR_PROPERTY(int, directLDA, 0)
308 BR_PROPERTY(float, directDrop, 0.1) 309 BR_PROPERTY(float, directDrop, 0.1)
309 BR_PROPERTY(QString, inputVariable, "Label") 310 BR_PROPERTY(QString, inputVariable, "Label")
  311 + BR_PROPERTY(bool, isBinary, false)
310 312
311 int dimsOut; 313 int dimsOut;
312 Eigen::VectorXf mean; 314 Eigen::VectorXf mean;
@@ -316,7 +318,6 @@ class LDATransform : public Transform @@ -316,7 +318,6 @@ class LDATransform : public Transform
316 { 318 {
317 // creates "Label" 319 // creates "Label"
318 TemplateList trainingSet = TemplateList::relabel(_trainingSet, inputVariable); 320 TemplateList trainingSet = TemplateList::relabel(_trainingSet, inputVariable);
319 -  
320 int instances = trainingSet.size(); 321 int instances = trainingSet.size();
321 322
322 // Perform PCA dimensionality reduction 323 // Perform PCA dimensionality reduction
@@ -450,6 +451,34 @@ class LDATransform : public Transform @@ -450,6 +451,34 @@ class LDATransform : public Transform
450 // Compute final projection matrix 451 // Compute final projection matrix
451 projection = ((space2.eVecs.transpose() * space1.eVecs.transpose()) * pca.eVecs.transpose()).transpose(); 452 projection = ((space2.eVecs.transpose() * space1.eVecs.transpose()) * pca.eVecs.transpose()).transpose();
452 dimsOut = dim2; 453 dimsOut = dim2;
  454 +
  455 + if (isBinary) {
  456 + assert(dimsOut == 1);
  457 + TemplateList projected;
  458 + float posVal = 0;
  459 + float negVal = 0;
  460 + for (int i = 0; i < trainingSet.size(); i++) {
  461 + Template t;
  462 + project(trainingSet[i],t);
  463 + //Note: the positive class is assumed to be 0 b/c it will
  464 + // typically be the first gallery template in the TemplateList structure
  465 + if (classes[i] == 0)
  466 + posVal += t.m().at<float>(0,0);
  467 + else if (classes[i] == 1)
  468 + negVal += t.m().at<float>(0,0);
  469 + else
  470 + qFatal("Binary mode only supports two class problems.");
  471 + }
  472 + posVal /= classCounts[0];
  473 + negVal /= classCounts[1];
  474 +
  475 + if (posVal < negVal) {
  476 + //Ensure positive value is supposed to be > 0 after projection
  477 + Eigen::MatrixXf invert = Eigen::MatrixXf::Ones(dimsIn,1);
  478 + invert *= -1;
  479 + projection = invert.transpose() * projection;
  480 + }
  481 + }
453 } 482 }
454 483
455 void project(const Template &src, Template &dst) const 484 void project(const Template &src, Template &dst) const
@@ -462,6 +491,10 @@ class LDATransform : public Transform @@ -462,6 +491,10 @@ class LDATransform : public Transform
462 491
463 // Do projection 492 // Do projection
464 outMap = projection.transpose() * (inMap - mean); 493 outMap = projection.transpose() * (inMap - mean);
  494 +
  495 + if (isBinary) {
  496 + dst.file.set("conf",dst.m().at<float>(0,0));
  497 + }
465 } 498 }
466 499
467 void store(QDataStream &stream) const 500 void store(QDataStream &stream) const