Commit ce6fe757524874f2b87f4a8c770fbe9a0f4b2503

Authored by Josh Klontz
2 parents f239ab13 f65ab061

Merge pull request #179 from dgcrouse/master

Added training for cascade classifiers
openbr/plugins/cascade.cpp
... ... @@ -19,12 +19,144 @@
19 19 #include "openbr_internal.h"
20 20 #include "openbr/core/opencvutils.h"
21 21 #include "openbr/core/resource.h"
  22 +#include <QProcess>
22 23  
23 24 using namespace cv;
  25 +
  26 +struct TrainParams
  27 +{
  28 + QString data; // REQUIRED: Filepath to store trained classifier
  29 + QString vec; // REQUIRED: Filepath to store vector of positive samples, default "vector"
  30 + QString img; // Filepath to source object image. Either this or info is REQUIRED
  31 + QString info; // Description file of source images. Either this or img is REQUIRED
  32 + QString bg; // REQUIRED: Filepath to background list file
  33 + int num; // Number of samples to generate
  34 + int bgcolor; // Background color supplied image (via img)
  35 + int bgthresh; // Threshold to determine bgcolor match
  36 + bool inv; // Invert colors
  37 + bool randinv; // Randomly invert colors
  38 + int maxidev; // Max intensity deviation of foreground pixels
  39 + double maxxangle; // Maximum rotation angle (X)
  40 + double maxyangle; // Maximum rotation angle (Y)
  41 + double maxzangle; // Maximum rotation angle (Z)
  42 + bool show; // Show generated samples
  43 + int w; // REQUIRED: Sample width
  44 + int h; // REQUIRED: Sample height
  45 + int numPos; // Number of positive samples
  46 + int numNeg; // Number of negative samples
  47 + int numStages; // Number of stages
  48 + int precalcValBufSize; // Precalculated val buffer size in Mb
  49 + int precalcIdxBufSize; // Precalculated index buffer size in Mb
  50 + bool baseFormatSave; // Save in old format
  51 + QString stageType; // Stage type (BOOST)
  52 + QString featureType; // Feature type (HAAR, LBP)
  53 + QString bt; // Boosted classifier type (DAB, RAB, LB, GAB)
  54 + double minHitRate; // Minimal hit rate per stage
  55 + double maxFalseAlarmRate; // Max false alarm rate per stage
  56 + double weightTrimRate; // Weight for trimming
  57 + int maxDepth; // Max weak tree depth
  58 + int maxWeakCount; // Max weak tree count per stage
  59 + QString mode; // Haar feature mode (BASIC, CORE, ALL)
24 60  
25   -namespace br
  61 + TrainParams()
  62 + {
  63 + num = -1;
  64 + maxidev = -1;
  65 + maxxangle = -1;
  66 + maxyangle = -1;
  67 + maxzangle = -1;
  68 + w = -1;
  69 + h = -1;
  70 + numPos = -1;
  71 + numNeg = -1;
  72 + numStages = -1;
  73 + precalcValBufSize = -1;
  74 + precalcIdxBufSize = -1;
  75 + minHitRate = -1;
  76 + maxFalseAlarmRate = -1;
  77 + weightTrimRate = -1;
  78 + maxDepth = -1;
  79 + maxWeakCount = -1;
  80 + inv = false;
  81 + randinv = false;
  82 + show = false;
  83 + baseFormatSave = false;
  84 + vec = "vector.vec";
  85 + bgcolor = -1;
  86 + bgthresh = -1;
  87 + }
  88 +};
  89 +
  90 +static QStringList buildTrainingArgs(const TrainParams &params)
  91 +{
  92 + QStringList args;
  93 + if (params.data != "") args << "-data" << params.data;
  94 + else qFatal("Must specify storage location for cascade");
  95 + if (params.vec != "") args << "-vec" << params.vec;
  96 + else qFatal("Must specify location of positive vector");
  97 + if (params.bg != "") args << "-bg" << params.bg;
  98 + else qFatal("Must specify negative images");
  99 + if (params.numPos >= 0) args << "-numPos" << QString::number(params.numPos);
  100 + if (params.numNeg >= 0) args << "-numNeg" << QString::number(params.numNeg);
  101 + if (params.numStages >= 0) args << "-numStages" << QString::number(params.numStages);
  102 + if (params.precalcValBufSize >= 0) args << "-precalcValBufSize" << QString::number(params.precalcValBufSize);
  103 + if (params.precalcIdxBufSize >= 0) args << "-precalcIdxBufSize" << QString::number(params.precalcIdxBufSize);
  104 + if (params.baseFormatSave) args << "-baseFormatSave";
  105 + if (params.stageType != "") args << "-stageType" << params.stageType;
  106 + if (params.featureType != "") args << "-featureType" << params.featureType;
  107 + if (params.w >= 0) args << "-w" << QString::number(params.w);
  108 + else qFatal("Must specify width");
  109 + if (params.h >= 0) args << "-h" << QString::number(params.h);
  110 + else qFatal("Must specify height");
  111 + if (params.bt != "") args << "-bt" << params.bt;
  112 + if (params.minHitRate >= 0) args << "-minHitRate" << QString::number(params.minHitRate);
  113 + if (params.maxFalseAlarmRate >= 0) args << "-maxFalseAlarmRate" << QString::number(params.maxFalseAlarmRate);
  114 + if (params.weightTrimRate >= 0) args << "-weightTrimRate" << QString::number(params.weightTrimRate);
  115 + if (params.maxDepth >= 0) args << "-maxDepth" << QString::number(params.maxDepth);
  116 + if (params.maxWeakCount >= 0) args << "-maxWeakCount" << QString::number(params.maxWeakCount);
  117 + if (params.mode != "") args << "-mode" << params.mode;
  118 + return args;
  119 +}
  120 +
  121 +static QStringList buildSampleArgs(const TrainParams &params)
  122 +{
  123 + QStringList args;
  124 + if (params.vec != "") args << "-vec" << params.vec;
  125 + else qFatal("Must specify location of positive vector");
  126 + if (params.img != "") args << "-img" << params.img;
  127 + else if (params.info != "") args << "-info" << params.info;
  128 + else qFatal("Must specify positive images");
  129 + if (params.bg != "") args << "-bg" << params.bg;
  130 + if (params.num > 0) args << "-num" << QString::number(params.num);
  131 + if (params.bgcolor >=0 ) args << "-bgcolor" << QString::number(params.bgcolor);
  132 + if (params.bgthresh >= 0) args << "-bgthresh" << QString::number(params.bgthresh);
  133 + if (params.maxidev >= 0) args << "-maxidev" << QString::number(params.maxidev);
  134 + if (params.maxxangle >= 0) args << "-maxxangle" << QString::number(params.maxxangle);
  135 + if (params.maxyangle >= 0) args << "-maxyangle" << QString::number(params.maxyangle);
  136 + if (params.maxzangle >= 0) args << "-maxzangle" << QString::number(params.maxzangle);
  137 + if (params.w >= 0) args << "-w" << QString::number(params.w);
  138 + if (params.h >= 0) args << "-h" << QString::number(params.h);
  139 + if (params.show) args << "-show";
  140 + if (params.inv) args << "-inv";
  141 + if (params.randinv) args << "-randinv";
  142 + return args;
  143 +}
  144 +
  145 +static void genSamples(const TrainParams &params)
26 146 {
  147 + const QStringList cmdArgs = buildSampleArgs(params);
  148 + QProcess::execute("opencv_createsamples",cmdArgs);
  149 +}
27 150  
  151 +static void trainCascade(const TrainParams &params)
  152 +{
  153 + const QStringList cmdArgs = buildTrainingArgs(params);
  154 + QProcess::execute("opencv_traincascade", cmdArgs);
  155 +}
  156 +
  157 +namespace br
  158 +{
  159 +
28 160 class CascadeResourceMaker : public ResourceMaker<CascadeClassifier>
29 161 {
30 162 QString file;
... ... @@ -37,7 +169,20 @@ public:
37 169 else if (model == "Eye") file += "haarcascades/haarcascade_eye_tree_eyeglasses.xml";
38 170 else if (model == "FrontalFace") file += "haarcascades/haarcascade_frontalface_alt2.xml";
39 171 else if (model == "ProfileFace") file += "haarcascades/haarcascade_profileface.xml";
40   - else qFatal("Invalid model.");
  172 + else{
  173 + // Create temp folder if does not exist
  174 + file = model+QDir::separator()+"cascade.xml";
  175 + QDir dir(model);
  176 + if (!dir.exists())
  177 + if (!QDir::current().mkdir(model)) qFatal("Cannot create model.");
  178 +
  179 + // Make sure file can be created
  180 + QFile pathTest(file);
  181 + if (pathTest.exists()) pathTest.remove();
  182 +
  183 + if (!pathTest.open(QIODevice::WriteOnly | QIODevice::Text)) qFatal("Cannot create model.");
  184 + pathTest.remove();
  185 + }
41 186 }
42 187  
43 188 private:
... ... @@ -54,16 +199,60 @@ private:
54 199 * \ingroup transforms
55 200 * \brief Wraps OpenCV cascade classifier
56 201 * \author Josh Klontz \cite jklontz
  202 + * \author David Crouse \cite dgcrouse
57 203 */
58   -class CascadeTransform : public UntrainableMetaTransform
  204 +class CascadeTransform : public MetaTransform
59 205 {
60 206 Q_OBJECT
61 207 Q_PROPERTY(QString model READ get_model WRITE set_model RESET reset_model STORED false)
62 208 Q_PROPERTY(int minSize READ get_minSize WRITE set_minSize RESET reset_minSize STORED false)
63 209 Q_PROPERTY(bool ROCMode READ get_ROCMode WRITE set_ROCMode RESET reset_ROCMode STORED false)
  210 +
  211 + // Training parameters
  212 + Q_PROPERTY(int numStages READ get_numStages WRITE set_numStages RESET reset_numStages STORED false)
  213 + Q_PROPERTY(int w READ get_w WRITE set_w RESET reset_w STORED false)
  214 + Q_PROPERTY(int h READ get_h WRITE set_h RESET reset_h STORED false)
  215 + Q_PROPERTY(int numPos READ get_numPos WRITE set_numPos RESET reset_numPos STORED false)
  216 + Q_PROPERTY(int numNeg READ get_numNeg WRITE set_numNeg RESET reset_numNeg STORED false)
  217 + Q_PROPERTY(int precalcValBufSize READ get_precalcValBufSize WRITE set_precalcValBufSize RESET reset_precalcValBufSize STORED false)
  218 + Q_PROPERTY(int precalcIdxBufSize READ get_precalcIdxBufSize WRITE set_precalcIdxBufSize RESET reset_precalcIdxBufSize STORED false)
  219 + Q_PROPERTY(double minHitRate READ get_minHitRate WRITE set_minHitRate RESET reset_minHitRate STORED false)
  220 + Q_PROPERTY(double maxFalseAlarmRate READ get_maxFalseAlarmRate WRITE set_maxFalseAlarmRate RESET reset_maxFalseAlarmRate STORED false)
  221 + Q_PROPERTY(double weightTrimRate READ get_weightTrimRate WRITE set_weightTrimRate RESET reset_weightTrimRate STORED false)
  222 + Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED false)
  223 + Q_PROPERTY(int maxWeakCount READ get_maxWeakCount WRITE set_maxWeakCount RESET reset_maxWeakCount STORED false)
  224 + Q_PROPERTY(QString stageType READ get_stageType WRITE set_stageType RESET reset_stageType STORED false)
  225 + Q_PROPERTY(QString featureType READ get_featureType WRITE set_featureType RESET reset_featureType STORED false)
  226 + Q_PROPERTY(QString bt READ get_bt WRITE set_bt RESET reset_bt STORED false)
  227 + Q_PROPERTY(QString mode READ get_mode WRITE set_mode RESET reset_mode STORED false)
  228 + Q_PROPERTY(bool show READ get_show WRITE set_show RESET reset_show STORED false)
  229 + Q_PROPERTY(bool baseFormatSave READ get_baseFormatSave WRITE set_baseFormatSave RESET reset_baseFormatSave STORED false)
  230 + Q_PROPERTY(bool overwrite READ get_overwrite WRITE set_overwrite RESET reset_overwrite STORED false)
  231 +
64 232 BR_PROPERTY(QString, model, "FrontalFace")
65 233 BR_PROPERTY(int, minSize, 64)
66 234 BR_PROPERTY(bool, ROCMode, false)
  235 +
  236 + // Training parameters - Default values provided trigger OpenCV defaults
  237 + BR_PROPERTY(int, numStages, -1)
  238 + BR_PROPERTY(int, w, -1)
  239 + BR_PROPERTY(int, h, -1)
  240 + BR_PROPERTY(int, numPos, -1)
  241 + BR_PROPERTY(int, numNeg, -1)
  242 + BR_PROPERTY(int, precalcValBufSize, -1)
  243 + BR_PROPERTY(int, precalcIdxBufSize, -1)
  244 + BR_PROPERTY(double, minHitRate, -1)
  245 + BR_PROPERTY(double, maxFalseAlarmRate, -1)
  246 + BR_PROPERTY(double, weightTrimRate, -1)
  247 + BR_PROPERTY(int, maxDepth, -1)
  248 + BR_PROPERTY(int, maxWeakCount, -1)
  249 + BR_PROPERTY(QString, stageType, "")
  250 + BR_PROPERTY(QString, featureType, "")
  251 + BR_PROPERTY(QString, bt, "")
  252 + BR_PROPERTY(QString, mode, "")
  253 + BR_PROPERTY(bool, show, false)
  254 + BR_PROPERTY(bool, baseFormatSave, false)
  255 + BR_PROPERTY(bool, overwrite, false)
67 256  
68 257 Resource<CascadeClassifier> cascadeResource;
69 258  
... ... @@ -71,6 +260,136 @@ class CascadeTransform : public UntrainableMetaTransform
71 260 {
72 261 cascadeResource.setResourceMaker(new CascadeResourceMaker(model));
73 262 }
  263 +
  264 + // Train transform
  265 + void train(const TemplateList& data)
  266 + {
  267 + if (overwrite) {
  268 + QDir dataDir(model);
  269 + if (dataDir.exists()) {
  270 + dataDir.removeRecursively();
  271 + QDir::current().mkdir(model);
  272 + }
  273 + }
  274 +
  275 + const FileList files = data.files();
  276 +
  277 + // Open positive and negative list files
  278 + const QString posFName = "pos.txt";
  279 + const QString negFName = "neg.txt";
  280 + QFile posFile(posFName);
  281 + QFile negFile(negFName);
  282 + posFile.open(QIODevice::WriteOnly | QIODevice::Text);
  283 + negFile.open(QIODevice::WriteOnly | QIODevice::Text);
  284 + QTextStream posStream(&posFile);
  285 + QTextStream negStream(&negFile);
  286 +
  287 + const QString endln = "\r\n";
  288 +
  289 + int posCount = 0;
  290 + int negCount = 0;
  291 +
  292 + bool buildPos = false; // If true, build positive vector from single image
  293 +
  294 + TrainParams params;
  295 +
  296 + // Fill in from params (param defaults are same as struct defaults, so no checks are needed)
  297 + params.numStages = numStages;
  298 + params.w = w;
  299 + params.h = h;
  300 + params.numPos = numPos;
  301 + params.numNeg = numNeg;
  302 + params.precalcValBufSize = precalcValBufSize;
  303 + params.precalcIdxBufSize = precalcIdxBufSize;
  304 + params.minHitRate = minHitRate;
  305 + params.maxFalseAlarmRate = maxFalseAlarmRate;
  306 + params.weightTrimRate = weightTrimRate;
  307 + params.maxDepth = maxDepth;
  308 + params.maxWeakCount = maxWeakCount;
  309 + params.stageType = stageType;
  310 + params.featureType = featureType;
  311 + params.bt = bt;
  312 + params.mode = mode;
  313 + params.show = show;
  314 + params.baseFormatSave = baseFormatSave;
  315 + if (params.w < 0) params.w = minSize;
  316 + if (params.h < 0) params.h = minSize;
  317 +
  318 + for (int i = 0; i < files.length(); i++) {
  319 + File f = files[i];
  320 + if (f.contains("training-set")) {
  321 + QString tset = f.get<QString>("training-set",QString()).toLower();
  322 +
  323 + // Negative samples
  324 + if (tset == "neg") {
  325 + if (negCount > 0) negStream<<endln;
  326 + negStream << f.path() << QDir::separator() << f.fileName();
  327 + negCount++;
  328 +
  329 + // Positive samples for crop/rescale
  330 + }else if (tset == "pos") {
  331 +
  332 + if (posCount > 0) posStream<<endln;
  333 + QString rects = "";
  334 +
  335 + // Extract rectangles
  336 + for (int j = 0; j < f.rects().length(); j++) {
  337 + QRectF r = f.rects()[j];
  338 + rects += " " + QString::number(r.x()) + " " + QString::number(r.y()) + " " + QString::number(r.width()) + " "+ QString::number(r.height());
  339 + posCount++;
  340 + }
  341 + if (f.rects().length() > 0)
  342 + posStream << f.path() << QDir::separator() << f.fileName() << " " << f.rects().length() << " " << rects;
  343 +
  344 + // Single positive sample for background removal and overlay on negatives
  345 + }else if (tset == "pos-base") {
  346 +
  347 + buildPos = true;
  348 + params.img = f.path() + QDir::separator() + f.fileName();
  349 +
  350 + // Parse settings (unique to this one tag)
  351 + if (f.contains("num")) params.num = f.get<int>("num",0);
  352 + if (f.contains("bgcolor")) params.bgcolor = f.get<int>("bgcolor",0);
  353 + if (f.contains("bgthresh")) params.bgthresh =f.get<int>("bgthresh",0);
  354 + if (f.contains("inv")) params.inv = f.get<bool>("inv",false);
  355 + if (f.contains("randinv")) params.randinv = f.get<bool>("randinv",false);
  356 + if (f.contains("maxidev")) params.maxidev = f.get<int>("maxidev",0);
  357 + if (f.contains("maxxangle")) params.maxxangle = f.get<double>("maxxangle",0);
  358 + if (f.contains("maxyangle")) params.maxyangle = f.get<double>("maxyangle",0);
  359 + if (f.contains("maxzangle")) params.maxzangle = f.get<double>("maxzangle",0);
  360 + }
  361 + }
  362 + }
  363 +
  364 + // Fill in remaining params conditionally
  365 + posFile.close();
  366 + negFile.close();
  367 + if (buildPos) {
  368 + if (params.numPos < 0) {
  369 + if (params.num > 0) params.numPos = (int)(params.num*.95);
  370 + else params.numPos = 950;
  371 + posFile.remove();
  372 + }
  373 + }else{
  374 + params.info = posFName;
  375 + if (params.numPos < 0) {
  376 + params.numPos = (int)(posCount*.95);
  377 + }
  378 + }
  379 + params.bg = negFName;
  380 + params.data = model;
  381 + if (params.num < 0) {
  382 + params.num = posCount;
  383 + }
  384 + if (params.numNeg < 0) {
  385 + params.numNeg = negCount*10;
  386 + }
  387 +
  388 + genSamples(params);
  389 + trainCascade(params);
  390 + if (posFile.exists()) posFile.remove();
  391 + negFile.remove();
  392 + }
74 393  
75 394 void project(const Template &src, Template &dst) const
76 395 {
... ...
share/openbr/openbr.bib
... ... @@ -38,6 +38,11 @@
38 38 Author = {Austin Van Blanton},
39 39 Howpublished = {https://github.com/imaus10},
40 40 Title = {imaus10 at gmail.com}}
  41 +
  42 + @misc{dgcrouse,
  43 + Author = {David G. Crouse},
  44 + Howpublished = {https://github.com/dgcrouse},
  45 + Title = {dgcrouse at gmail.com}}
41 46  
42 47 % Software
43 48 @misc{libface,
... ...