Commit 1f041fdb4e41dc1cd94806634244184efe553f69

Authored by dgcrouse
1 parent f4ffecd3

Added cascade training support

Needs to be documented
Showing 1 changed file with 354 additions and 4 deletions
openbr/plugins/cascade.cpp
... ... @@ -19,11 +19,164 @@
19 19 #include "openbr_internal.h"
20 20 #include "openbr/core/opencvutils.h"
21 21 #include "openbr/core/resource.h"
  22 +#include <QtXml>
  23 +#include <stdlib.h>
  24 +
  25 +#ifndef _win32
  26 + const QString foldersep = "/";
  27 +#else
  28 + const QString foldersep = "\\";
  29 +#endif
22 30  
23 31 using namespace cv;
24 32  
25 33 namespace br
26 34 {
  35 +
  36 + struct trainParams{
  37 + public:
  38 + QString data; // REQUIRED: Filepath to store trained classifier
  39 + QString vec; // REQUIRED: Filepath to store vector of positive samples, default "vector"
  40 + QString img; // Filepath to source object image. Either this or info is REQUIRED
  41 + QString info; // Description file of source images. Either this or img is REQUIRED
  42 + QString bg; // REQUIRED: Filepath to background list file
  43 + int num; // Number of samples to generate
  44 + int bgcolor; // Background color supplied image (via img)
  45 + int bgthresh; // Threshold to determine bgcolor match
  46 + bool inv; // Invert colors
  47 + bool randinv; // Randomly invert colors
  48 + int maxidev; // Max intensity deviation of foreground pixels
  49 + double maxxangle; // Maximum rotation angle (X)
  50 + double maxyangle; // Maximum rotation angle (Y)
  51 + double maxzangle; // Maximum rotation angle (Z)
  52 + bool show; // Show generated samples
  53 + int w; // REQUIRED: Sample width
  54 + int h; // REQUIRED: Sample height
  55 + int numPos; // Number of positive samples
  56 + int numNeg; // Number of negative samples
  57 + int numStages; // Number of stages
  58 + int precalcValBufSize; // Precalculated val buffer size in Mb
  59 + int precalcIdxBufSize; // Precalculated index buffer size in Mb
  60 + bool baseFormatSave; // Save in old format
  61 + QString stageType; // Stage type (BOOST)
  62 + QString featureType; // Feature type (HAAR, LBP)
  63 + QString bt; // Boosted classifier type (DAB, RAB, LB, GAB)
  64 + double minHitRate; // Minimal hit rate per stage
  65 + double maxFalseAlarmRate; // Max false alarm rate per stage
  66 + double weightTrimRate; // Weight for trimming
  67 + int maxDepth; // Max weak tree depth
  68 + int maxWeakCount; // Max weak tree count per stage
  69 + QString mode; // Haar feature mode (BASIC, CORE, ALL)
  70 +
  71 + trainParams(){
  72 + num = -1;
  73 + maxidev = -1;
  74 + maxxangle = -1;
  75 + maxyangle = -1;
  76 + maxzangle = -1;
  77 + w = -1;
  78 + h = -1;
  79 + numPos = -1;
  80 + numNeg = -1;
  81 + numStages = -1;
  82 + precalcValBufSize = -1;
  83 + precalcIdxBufSize = -1;
  84 + minHitRate = -1;
  85 + maxFalseAlarmRate = -1;
  86 + weightTrimRate = -1;
  87 + maxDepth = -1;
  88 + maxWeakCount = -1;
  89 + inv = false;
  90 + randinv = false;
  91 + show = false;
  92 + baseFormatSave = false;
  93 + vec = "vector.vec";
  94 + bgcolor = -1;
  95 + bgthresh = -1;
  96 + }
  97 + };
  98 +
  99 + QString buildTrainingArgs(trainParams params){
  100 + QString args = "";
  101 + if (params.data != "") args += "-data " + params.data + " ";
  102 + else return "";
  103 + if (params.vec != "") args += "-vec " + params.vec + " ";
  104 + else return "";
  105 + if (params.bg != "") args += "-bg " + params.bg + " ";
  106 + else return "";
  107 + if (params.numPos >= 0) args += "-numPos " + QString::number(params.numPos) + " ";
  108 + if (params.numNeg >= 0) args += "-numNeg " + QString::number(params.numNeg) + " ";
  109 + if (params.numStages >= 0) args += "-numStages " + QString::number(params.numStages) + " ";
  110 + if (params.precalcValBufSize >= 0) args += "-precalcValBufSize " + QString::number(params.precalcValBufSize) + " ";
  111 + if (params.precalcIdxBufSize >= 0) args += "-precalcIdxBufSize " + QString::number(params.precalcIdxBufSize) + " ";
  112 + if (params.baseFormatSave) args += "-baseFormatSave ";
  113 + if (params.stageType != "") args += "-stageType " + params.stageType + " ";
  114 + if (params.featureType != "") args += "-featureType " + params.featureType + " ";
  115 + if (params.w >= 0) args += "-w " + QString::number(params.w) + " ";
  116 + else return "";
  117 + if (params.h >= 0) args += "-h " + QString::number(params.h) + " ";
  118 + else return "";
  119 + if (params.bt != "") args += "-bt " + params.bt + " ";
  120 + if (params.minHitRate >= 0) args += "-minHitRate " + QString::number(params.minHitRate) + " ";
  121 + if (params.maxFalseAlarmRate >= 0) args += "-maxFalseAlarmRate " + QString::number(params.maxFalseAlarmRate) + " ";
  122 + if (params.weightTrimRate >= 0) args += "-weightTrimRate " + QString::number(params.weightTrimRate) + " ";
  123 + if (params.maxDepth >= 0) args += "-maxDepth " + QString::number(params.maxDepth) + " ";
  124 + if (params.maxWeakCount >= 0) args += "-maxWeakCount " + QString::number(params.maxWeakCount) + " ";
  125 + if (params.mode != "") args += "-mode " + params.mode + " ";
  126 + return args;
  127 + }
  128 +
  129 +
  130 + QString buildSampleArgs(trainParams params){
  131 + QString args = "";
  132 + if (params.vec != "") args += "-vec "+params.vec+" ";
  133 + else return "";
  134 + if (params.img != "") args += "-img " + params.img + " ";
  135 + else if (params.info != "") args += "-info " + params.info + " ";
  136 + else return "";
  137 + if (params.bg != "") args += "-bg " + params.bg + " ";
  138 + if (params.num > 0) args += "-num " + QString::number(params.num) + " ";
  139 + if (params.bgcolor >=0 ) args += "-bgcolor " + QString::number(params.bgcolor) + " ";
  140 + if (params.bgthresh >= 0) args += "-bgthresh " + QString::number(params.bgthresh) + " ";
  141 + if (params.maxidev >= 0) args += "-maxidev " + QString::number(params.maxidev) + " ";
  142 + if (params.maxxangle >= 0) args += "-maxxangle " + QString::number(params.maxxangle) + " ";
  143 + if (params.maxyangle >= 0) args += "-maxyangle " + QString::number(params.maxyangle) + " ";
  144 + if (params.maxzangle >= 0) args += "-maxzangle " + QString::number(params.maxzangle) + " ";
  145 + if (params.w >= 0) args += "-w " + QString::number(params.w) + " ";
  146 + if (params.h >= 0) args += "-h " + QString::number(params.h) + " ";
  147 + if (params.show) args += "-show ";
  148 + if (params.inv) args += "-inv ";
  149 + if (params.randinv) args += "-randinv ";
  150 + return args;
  151 + }
  152 +
  153 + void execCommand(QString cmd, QString args){
  154 + #ifdef _WIN32
  155 + cmd += ".exe";
  156 + #endif
  157 + cmd += " " + args;
  158 + system(cmd.toLocal8Bit().data());
  159 + }
  160 +
  161 + void genSamples(trainParams params, QString argStr = ""){
  162 + QString cmdArgs = buildSampleArgs(params);
  163 + if (argStr != "") cmdArgs += " " + argStr;
  164 + execCommand("opencv_createsamples",cmdArgs);
  165 + }
  166 +
  167 +
  168 + void trainCascade(trainParams params,QString argStr = ""){
  169 + QString cmdArgs = buildTrainingArgs(params);
  170 + if (argStr != "") cmdArgs += " " + argStr;
  171 +
  172 + execCommand("opencv_traincascade", cmdArgs);
  173 + }
  174 +
  175 +QString rectToString(QRectF r){
  176 + QString out = " " + QString::number(r.x()) + " " + QString::number(r.y()) + " " + QString::number(r.width()) + " "+ QString::number(r.height());
  177 + return out;
  178 +}
  179 +
27 180  
28 181 class CascadeResourceMaker : public ResourceMaker<CascadeClassifier>
29 182 {
... ... @@ -32,13 +185,26 @@ class CascadeResourceMaker : public ResourceMaker&lt;CascadeClassifier&gt;
32 185 public:
33 186 CascadeResourceMaker(const QString &model)
34 187 {
35   - file = Globals->sdkPath + "/share/openbr/models/";
  188 + file = Globals->sdkPath + "/share/openbr/models/";
36 189 if (model == "Ear") file += "haarcascades/haarcascade_ear.xml";
37 190 else if (model == "Eye") file += "haarcascades/haarcascade_eye_tree_eyeglasses.xml";
38 191 else if (model == "FrontalFace") file += "haarcascades/haarcascade_frontalface_alt2.xml";
39 192 else if (model == "ProfileFace") file += "haarcascades/haarcascade_profileface.xml";
40   - else qFatal("Invalid model.");
41   - }
  193 + else{
  194 + // Create temp folder if does not exist
  195 + file = model+foldersep+"cascade.xml";
  196 + QDir dir(model);
  197 + if (!dir.exists())
  198 + if (!QDir::current().mkdir(model)) qFatal("Cannot create model.");
  199 +
  200 + // Make sure file can be created
  201 + QFile pathTest(file);
  202 + if (pathTest.exists()) pathTest.remove();
  203 +
  204 + if (!pathTest.open(QIODevice::WriteOnly | QIODevice::Text)) qFatal("Cannot create model.");
  205 + pathTest.remove();
  206 + }
  207 + }
42 208  
43 209 private:
44 210 CascadeClassifier *make() const
... ... @@ -50,20 +216,66 @@ private:
50 216 }
51 217 };
52 218  
  219 +
53 220 /*!
54 221 * \ingroup transforms
55 222 * \brief Wraps OpenCV cascade classifier
56 223 * \author Josh Klontz \cite jklontz
57 224 */
58   -class CascadeTransform : public UntrainableMetaTransform
  225 +class CascadeTransform : public MetaTransform
59 226 {
60 227 Q_OBJECT
61 228 Q_PROPERTY(QString model READ get_model WRITE set_model RESET reset_model STORED false)
62 229 Q_PROPERTY(int minSize READ get_minSize WRITE set_minSize RESET reset_minSize STORED false)
63 230 Q_PROPERTY(bool ROCMode READ get_ROCMode WRITE set_ROCMode RESET reset_ROCMode STORED false)
  231 +
  232 + // Training parameters
  233 + Q_PROPERTY(int numStages READ get_numStages WRITE set_numStages RESET reset_numStages STORED false)
  234 + Q_PROPERTY(int w READ get_w WRITE set_w RESET reset_w STORED false)
  235 + Q_PROPERTY(int h READ get_h WRITE set_h RESET reset_h STORED false)
  236 + Q_PROPERTY(int numPos READ get_numPos WRITE set_numPos RESET reset_numPos STORED false)
  237 + Q_PROPERTY(int numNeg READ get_numNeg WRITE set_numNeg RESET reset_numNeg STORED false)
  238 + Q_PROPERTY(int precalcValBufSize READ get_precalcValBufSize WRITE set_precalcValBufSize RESET reset_precalcValBufSize STORED false)
  239 + Q_PROPERTY(int precalcIdxBufSize READ get_precalcIdxBufSize WRITE set_precalcIdxBufSize RESET reset_precalcIdxBufSize STORED false)
  240 + Q_PROPERTY(double minHitRate READ get_minHitRate WRITE set_minHitRate RESET reset_minHitRate STORED false)
  241 + Q_PROPERTY(double maxFalseAlarmRate READ get_maxFalseAlarmRate WRITE set_maxFalseAlarmRate RESET reset_maxFalseAlarmRate STORED false)
  242 + Q_PROPERTY(double weightTrimRate READ get_weightTrimRate WRITE set_weightTrimRate RESET reset_weightTrimRate STORED false)
  243 + Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED false)
  244 + Q_PROPERTY(int maxWeakCount READ get_maxWeakCount WRITE set_maxWeakCount RESET reset_maxWeakCount STORED false)
  245 + Q_PROPERTY(QString stageType READ get_stageType WRITE set_stageType RESET reset_stageType STORED false)
  246 + Q_PROPERTY(QString featureType READ get_featureType WRITE set_featureType RESET reset_featureType STORED false)
  247 + Q_PROPERTY(QString bt READ get_bt WRITE set_bt RESET reset_bt STORED false)
  248 + Q_PROPERTY(QString mode READ get_mode WRITE set_mode RESET reset_mode STORED false)
  249 + Q_PROPERTY(bool show READ get_show WRITE set_show RESET reset_show STORED false)
  250 + Q_PROPERTY(bool baseFormatSave READ get_baseFormatSave WRITE set_baseFormatSave RESET reset_baseFormatSave STORED false)
  251 + Q_PROPERTY(bool overwrite READ get_overwrite WRITE set_overwrite RESET reset_overwrite STORED false)
  252 +
  253 +
64 254 BR_PROPERTY(QString, model, "FrontalFace")
65 255 BR_PROPERTY(int, minSize, 64)
66 256 BR_PROPERTY(bool, ROCMode, false)
  257 +
  258 + // Training parameters - Default values provided trigger OpenCV defaults
  259 + BR_PROPERTY(int, numStages, -1)
  260 + BR_PROPERTY(int,w,-1)
  261 + BR_PROPERTY(int,h,-1)
  262 + BR_PROPERTY(int,numPos,-1)
  263 + BR_PROPERTY(int,numNeg,-1)
  264 + BR_PROPERTY(int,precalcValBufSize,-1)
  265 + BR_PROPERTY(int,precalcIdxBufSize,-1)
  266 + BR_PROPERTY(double,minHitRate,-1)
  267 + BR_PROPERTY(double,maxFalseAlarmRate,-1)
  268 + BR_PROPERTY(double,weightTrimRate,-1)
  269 + BR_PROPERTY(int,maxDepth,-1)
  270 + BR_PROPERTY(int,maxWeakCount,-1)
  271 + BR_PROPERTY(QString,stageType,"")
  272 + BR_PROPERTY(QString,featureType,"")
  273 + BR_PROPERTY(QString,bt,"")
  274 + BR_PROPERTY(QString,mode,"")
  275 + BR_PROPERTY(bool,show,false)
  276 + BR_PROPERTY(bool,baseFormatSave,false)
  277 + BR_PROPERTY(bool,overwrite,false)
  278 +
67 279  
68 280 Resource<CascadeClassifier> cascadeResource;
69 281  
... ... @@ -71,6 +283,144 @@ class CascadeTransform : public UntrainableMetaTransform
71 283 {
72 284 cascadeResource.setResourceMaker(new CascadeResourceMaker(model));
73 285 }
  286 +
  287 + // Train transform
  288 + void train(const TemplateList& data){
  289 + if (overwrite){
  290 + QDir dataDir(model);
  291 + if (dataDir.exists()){
  292 + dataDir.removeRecursively();
  293 + QDir::current().mkdir(model);
  294 + }
  295 + }
  296 +
  297 + FileList files = data.files();
  298 +
  299 +
  300 + // Open positive and negative list files
  301 + QString posFName = "pos.txt";
  302 + QString negFName = "neg.txt";
  303 + QFile posFile(posFName);
  304 + QFile negFile(negFName);
  305 + posFile.open(QIODevice::WriteOnly | QIODevice::Text);
  306 + negFile.open(QIODevice::WriteOnly | QIODevice::Text);
  307 + QTextStream posStream(&posFile);
  308 + QTextStream negStream(&negFile);
  309 +
  310 +
  311 + const QString endln = "\r\n";
  312 +
  313 + int posCount = 0;
  314 + int negCount = 0;
  315 +
  316 + bool buildPos = false; // If true, build positive vector from single image
  317 +
  318 + QList<QString> vectors;
  319 +
  320 + trainParams params;
  321 +
  322 + // Fill in from params (param defaults are same as struct defaults, so no checks are needed)
  323 + params.numStages = numStages;
  324 + params.w = w;
  325 + params.h = h;
  326 + params.numPos = numPos;
  327 + params.numNeg = numNeg;
  328 + params.precalcValBufSize = precalcValBufSize;
  329 + params.precalcIdxBufSize = precalcIdxBufSize;
  330 + params.minHitRate = minHitRate;
  331 + params.maxFalseAlarmRate = maxFalseAlarmRate;
  332 + params.weightTrimRate = weightTrimRate;
  333 + params.maxDepth = maxDepth;
  334 + params.maxWeakCount = maxWeakCount;
  335 + params.stageType = stageType;
  336 + params.featureType = featureType;
  337 + params.bt = bt;
  338 + params.mode = mode;
  339 + params.show = show;
  340 + params.baseFormatSave = baseFormatSave;
  341 + if (params.w < 0) params.w = minSize;
  342 + if (params.h < 0) params.h = minSize;
  343 +
  344 + for (int i = 0; i < files.length(); i++){
  345 + File f = files[i];
  346 + if (f.localKeys().contains("training-set")){
  347 + QString tset = f.localMetadata()["training-set"].toString().toLower();
  348 +
  349 + // Negative samples
  350 + if (tset == "neg"){
  351 + if (negCount > 0) negStream<<endln;
  352 + negStream << f.path() << foldersep << f.fileName();
  353 + negCount++;
  354 +
  355 + // Positive samples for crop/rescale
  356 + }else if (tset == "pos"){
  357 +
  358 + if (posCount > 0) posStream<<endln;
  359 + QString rects = "";
  360 +
  361 + // Extract rectangles
  362 + for (int j = 0; j < f.rects().length(); j++){
  363 + rects += rectToString(f.rects()[j]);
  364 + posCount++;
  365 + }
  366 + if (f.rects().length() > 0)
  367 + posStream << f.path() << foldersep << f.fileName() << " " << f.rects().length() << " " << rects;
  368 +
  369 + // Single positive sample for background removal and overlay on negatives
  370 + }else if (tset == "pos-base"){
  371 +
  372 + buildPos = true;
  373 + params.img = f.path() + foldersep + f.fileName();
  374 +
  375 + // Parse settings (unique to this one tag)
  376 + if (f.localKeys().contains("num")) params.num = f.localMetadata()["num"].toInt();
  377 + if (f.localKeys().contains("bgcolor")) params.bgcolor = f.localMetadata()["bgcolor"].toInt();
  378 + if (f.localKeys().contains("bgthresh")) params.bgthresh = f.localMetadata()["bgthresh"].toInt();
  379 + if (f.localKeys().contains("inv")) params.inv = f.localMetadata()["inv"].toBool();
  380 + if (f.localKeys().contains("randinv")) params.randinv = f.localMetadata()["randinv"].toBool();
  381 + if (f.localKeys().contains("maxidev")) params.maxidev = f.localMetadata()["maxidev"].toInt();
  382 + if (f.localKeys().contains("maxxangle")) params.maxxangle = f.localMetadata()["maxxangle"].toDouble();
  383 + if (f.localKeys().contains("maxyangle")) params.maxyangle = f.localMetadata()["maxyangle"].toDouble();
  384 + if (f.localKeys().contains("maxzangle")) params.maxzangle = f.localMetadata()["maxzangle"].toDouble();
  385 +
  386 + }else if (tset == "vector"){
  387 + vectors.append(f.path() + foldersep + f.fileName());
  388 + }
  389 + }
  390 + }
  391 +
  392 +
  393 + // Fill in remaining params conditionally
  394 + posFile.close();
  395 + negFile.close();
  396 + if (buildPos){
  397 + if (params.numPos < 0){
  398 + if (params.num > 0) params.numPos = (int)(params.num*.95);
  399 + else params.numPos = 950;
  400 + posFile.remove();
  401 + }
  402 + }else{
  403 + params.info = posFName;
  404 + if (params.numPos < 0){
  405 + params.numPos = (int)(posCount*.95);
  406 + }
  407 + }
  408 + params.bg = negFName;
  409 + params.data = model;
  410 + if (params.num < 0){
  411 + params.num = posCount;
  412 + }
  413 + if (params.numNeg < 0){
  414 + params.numNeg = negCount*10;
  415 + }
  416 +
  417 +
  418 + genSamples(params);
  419 + trainCascade(params);
  420 + if (posFile.exists()) posFile.remove();
  421 + negFile.remove();
  422 + }
  423 +
74 424  
75 425 void project(const Template &src, Template &dst) const
76 426 {
... ...