Commit 9567b9759c1683a5854f6af9fc56dd424d97993a (parent ff22117c) — "Step 1 of Cascade integration".
Showing 7 changed files with 2,927 additions and 277 deletions.

File: openbr/core/boost.cpp (new file, mode 100644)
| 1 | +#include "boost.h" | |
| 2 | +#include <queue> | |
| 3 | +#include "cxmisc.h" | |
| 4 | + | |
| 5 | +using namespace std; | |
| 6 | +using namespace br; | |
| 7 | +using namespace cv; | |
| 8 | + | |
// Log-odds (logit) transform: log(p / (1 - p)).
// The input probability is clamped to [eps, 1 - eps] so the result is
// always finite, even for p == 0 or p == 1.
static inline double
logRatio( double val )
{
    const double eps = 1e-5;

    if( val < eps )
        val = eps;
    else if( val > 1. - eps )
        val = 1. - eps;
    return log( val/(1. - val) );
}
| 18 | + | |
// Comparator macros consumed by CV_IMPLEMENT_QSORT_EX (cxmisc.h), which
// expands into a static quicksort routine specialized for the given element
// type, comparison expression, and auxiliary-data pointer type.
#define CV_CMP_FLT(i,j) (i < j)
// icvSortFlt: sorts a float array ascending (the aux pointer is unused).
static CV_IMPLEMENT_QSORT_EX( icvSortFlt, float, CV_CMP_FLT, const float* )

// Compares two indices by the values they select from the float array `aux`,
// i.e. sorts an index permutation by its backing feature values.
#define CV_CMP_NUM_IDX(i,j) (aux[i] < aux[j])
// Index-permutation sorters for 32-bit and 16-bit index buffers respectively.
static CV_IMPLEMENT_QSORT_EX( icvSortIntAux, int, CV_CMP_NUM_IDX, const float* )
static CV_IMPLEMENT_QSORT_EX( icvSortUShAux, unsigned short, CV_CMP_NUM_IDX, const float* )

// Tolerance for split-threshold comparisons. (presumably — its uses are
// outside this view; confirm against the stage/threshold selection code)
#define CV_THRESHOLD_EPS (0.00001F)

// Initial size / growth granularity for the CvMemStorage pools allocated
// below for tree nodes, splits and temporaries.
static const int MinBlockSize = 1 << 16;
static const int BlockSizeDelta = 1 << 10;
| 30 | + | |
| 31 | +// TODO remove this code duplication with ml/precomp.hpp | |
| 32 | + | |
| 33 | +static int CV_CDECL icvCmpIntegers( const void* a, const void* b ) | |
| 34 | +{ | |
| 35 | + return *(const int*)a - *(const int*)b; | |
| 36 | +} | |
| 37 | + | |
// Validates a user-supplied sample/component selection and converts it to a
// canonical 1 x N CV_32SC1 row of sorted selected indices. Two encodings are
// accepted:
//   - CV_8UC1 / CV_8SC1: a 0/1 mask with one entry per element of the data
//     array (length must equal data_arr_size);
//   - CV_32SC1: an explicit list of indices into [0, data_arr_size).
// Returns a newly allocated matrix owned by the caller, or 0 on error
// (the legacy CV_ERROR machinery sets the error status and jumps to __END__).
static CvMat* cvPreprocessIndexArray( const CvMat* idx_arr, int data_arr_size, bool check_for_duplicates=false )
{
    CvMat* idx = 0;

    CV_FUNCNAME( "cvPreprocessIndexArray" );

    __BEGIN__;

    int i, idx_total, idx_selected = 0, step, type, prev = INT_MIN, is_sorted = 1;
    uchar* srcb = 0;
    int* srci = 0;
    int* dsti;

    if( !CV_IS_MAT(idx_arr) )
        CV_ERROR( CV_StsBadArg, "Invalid index array" );

    if( idx_arr->rows != 1 && idx_arr->cols != 1 )
        CV_ERROR( CV_StsBadSize, "the index array must be 1-dimensional" );

    // Length of the 1-D vector (works for both row and column vectors).
    idx_total = idx_arr->rows + idx_arr->cols - 1;
    srcb = idx_arr->data.ptr;
    srci = idx_arr->data.i;

    type = CV_MAT_TYPE(idx_arr->type);
    // Element stride in elements (1 when the matrix data is continuous).
    step = CV_IS_MAT_CONT(idx_arr->type) ? 1 : idx_arr->step/CV_ELEM_SIZE(type);

    switch( type )
    {
    case CV_8UC1:
    case CV_8SC1:
        // idx_arr is array of 1's and 0's -
        // i.e. it is a mask of the selected components
        if( idx_total != data_arr_size )
            CV_ERROR( CV_StsUnmatchedSizes,
            "Component mask should contain as many elements as the total number of input variables" );

        // Count the selected (non-zero) entries.
        for( i = 0; i < idx_total; i++ )
            idx_selected += srcb[i*step] != 0;

        if( idx_selected == 0 )
            CV_ERROR( CV_StsOutOfRange, "No components/input_variables is selected!" );

        break;
    case CV_32SC1:
        // idx_arr is array of integer indices of selected components
        if( idx_total > data_arr_size )
            CV_ERROR( CV_StsOutOfRange,
            "index array may not contain more elements than the total number of input variables" );
        idx_selected = idx_total;
        // check if sorted already
        // NOTE(review): with prev starting at INT_MIN, the very first element
        // satisfies val >= prev, so is_sorted is always cleared and the qsort
        // below always runs. Harmless (sorting sorted data), but the condition
        // looks inverted relative to the comment — confirm against the
        // ml/precomp.hpp original this was copied from.
        for( i = 0; i < idx_total; i++ )
        {
            int val = srci[i*step];
            if( val >= prev )
            {
                is_sorted = 0;
                break;
            }
            prev = val;
        }
        break;
    default:
        CV_ERROR( CV_StsUnsupportedFormat, "Unsupported index array data type "
                                           "(it should be 8uC1, 8sC1 or 32sC1)" );
    }

    // Output row: one entry per selected index.
    CV_CALL( idx = cvCreateMat( 1, idx_selected, CV_32SC1 ));
    dsti = idx->data.i;

    if( type < CV_32SC1 )
    {
        // Mask form: emit the position of every non-zero entry (already sorted).
        for( i = 0; i < idx_total; i++ )
            if( srcb[i*step] )
                *dsti++ = i;
    }
    else
    {
        // Explicit-index form: copy, sort if needed, then range-check.
        for( i = 0; i < idx_total; i++ )
            dsti[i] = srci[i*step];

        if( !is_sorted )
            qsort( dsti, idx_total, sizeof(dsti[0]), icvCmpIntegers );

        if( dsti[0] < 0 || dsti[idx_total-1] >= data_arr_size )
            CV_ERROR( CV_StsOutOfRange, "the index array elements are out of range" );

        if( check_for_duplicates )
        {
            // After sorting, duplicates must be adjacent.
            for( i = 1; i < idx_total; i++ )
                if( dsti[i] <= dsti[i-1] )
                    CV_ERROR( CV_StsBadArg, "There are duplicated index array elements" );
        }
    }

    __END__;

    // On error, release the partially built result so the caller gets 0.
    if( cvGetErrStatus() < 0 )
        cvReleaseMat( &idx );

    return idx;
}
| 139 | + | |
| 140 | +//----------------------------- CascadeBoostParams ------------------------------------------------- | |
| 141 | + | |
| 142 | +CascadeBoostParams::CascadeBoostParams() : minHitRate( 0.995F), maxFalseAlarm( 0.5F ) | |
| 143 | +{ | |
| 144 | + boost_type = CvBoost::GENTLE; | |
| 145 | + use_surrogates = use_1se_rule = truncate_pruned_tree = false; | |
| 146 | +} | |
| 147 | + | |
| 148 | +CascadeBoostParams::CascadeBoostParams( int _boostType, | |
| 149 | + float _minHitRate, float _maxFalseAlarm, | |
| 150 | + double _weightTrimRate, int _maxDepth, int _maxWeakCount ) : | |
| 151 | + CvBoostParams( _boostType, _maxWeakCount, _weightTrimRate, _maxDepth, false, 0 ) | |
| 152 | +{ | |
| 153 | + boost_type = CvBoost::GENTLE; | |
| 154 | + minHitRate = _minHitRate; | |
| 155 | + maxFalseAlarm = _maxFalseAlarm; | |
| 156 | + use_surrogates = use_1se_rule = truncate_pruned_tree = false; | |
| 157 | +} | |
| 158 | + | |
| 159 | +void CascadeBoostParams::write( FileStorage &fs ) const | |
| 160 | +{ | |
| 161 | + string boostTypeStr = boost_type == CvBoost::DISCRETE ? CC_DISCRETE_BOOST : | |
| 162 | + boost_type == CvBoost::REAL ? CC_REAL_BOOST : | |
| 163 | + boost_type == CvBoost::LOGIT ? CC_LOGIT_BOOST : | |
| 164 | + boost_type == CvBoost::GENTLE ? CC_GENTLE_BOOST : string(); | |
| 165 | + CV_Assert( !boostTypeStr.empty() ); | |
| 166 | + fs << CC_BOOST_TYPE << boostTypeStr; | |
| 167 | + fs << CC_MINHITRATE << minHitRate; | |
| 168 | + fs << CC_MAXFALSEALARM << maxFalseAlarm; | |
| 169 | + fs << CC_TRIM_RATE << weight_trim_rate; | |
| 170 | + fs << CC_MAX_DEPTH << max_depth; | |
| 171 | + fs << CC_WEAK_COUNT << weak_count; | |
| 172 | +} | |
| 173 | + | |
| 174 | +bool CascadeBoostParams::read( const FileNode &node ) | |
| 175 | +{ | |
| 176 | + string boostTypeStr; | |
| 177 | + FileNode rnode = node[CC_BOOST_TYPE]; | |
| 178 | + rnode >> boostTypeStr; | |
| 179 | + boost_type = !boostTypeStr.compare( CC_DISCRETE_BOOST ) ? CvBoost::DISCRETE : | |
| 180 | + !boostTypeStr.compare( CC_REAL_BOOST ) ? CvBoost::REAL : | |
| 181 | + !boostTypeStr.compare( CC_LOGIT_BOOST ) ? CvBoost::LOGIT : | |
| 182 | + !boostTypeStr.compare( CC_GENTLE_BOOST ) ? CvBoost::GENTLE : -1; | |
| 183 | + if (boost_type == -1) | |
| 184 | + CV_Error( CV_StsBadArg, "unsupported Boost type" ); | |
| 185 | + node[CC_MINHITRATE] >> minHitRate; | |
| 186 | + node[CC_MAXFALSEALARM] >> maxFalseAlarm; | |
| 187 | + node[CC_TRIM_RATE] >> weight_trim_rate ; | |
| 188 | + node[CC_MAX_DEPTH] >> max_depth ; | |
| 189 | + node[CC_WEAK_COUNT] >> weak_count ; | |
| 190 | + if ( minHitRate <= 0 || minHitRate > 1 || | |
| 191 | + maxFalseAlarm <= 0 || maxFalseAlarm > 1 || | |
| 192 | + weight_trim_rate <= 0 || weight_trim_rate > 1 || | |
| 193 | + max_depth <= 0 || weak_count <= 0 ) | |
| 194 | + CV_Error( CV_StsBadArg, "bad parameters range"); | |
| 195 | + return true; | |
| 196 | +} | |
| 197 | + | |
| 198 | +void CascadeBoostParams::printDefaults() const | |
| 199 | +{ | |
| 200 | + cout << "--boostParams--" << endl; | |
| 201 | + cout << " [-bt <{" << CC_DISCRETE_BOOST << ", " | |
| 202 | + << CC_REAL_BOOST << ", " | |
| 203 | + << CC_LOGIT_BOOST ", " | |
| 204 | + << CC_GENTLE_BOOST << "(default)}>]" << endl; | |
| 205 | + cout << " [-minHitRate <min_hit_rate> = " << minHitRate << ">]" << endl; | |
| 206 | + cout << " [-maxFalseAlarmRate <max_false_alarm_rate = " << maxFalseAlarm << ">]" << endl; | |
| 207 | + cout << " [-weightTrimRate <weight_trim_rate = " << weight_trim_rate << ">]" << endl; | |
| 208 | + cout << " [-maxDepth <max_depth_of_weak_tree = " << max_depth << ">]" << endl; | |
| 209 | + cout << " [-maxWeakCount <max_weak_tree_count = " << weak_count << ">]" << endl; | |
| 210 | +} | |
| 211 | + | |
| 212 | +void CascadeBoostParams::printAttrs() const | |
| 213 | +{ | |
| 214 | + string boostTypeStr = boost_type == CvBoost::DISCRETE ? CC_DISCRETE_BOOST : | |
| 215 | + boost_type == CvBoost::REAL ? CC_REAL_BOOST : | |
| 216 | + boost_type == CvBoost::LOGIT ? CC_LOGIT_BOOST : | |
| 217 | + boost_type == CvBoost::GENTLE ? CC_GENTLE_BOOST : string(); | |
| 218 | + CV_Assert( !boostTypeStr.empty() ); | |
| 219 | + cout << "boostType: " << boostTypeStr << endl; | |
| 220 | + cout << "minHitRate: " << minHitRate << endl; | |
| 221 | + cout << "maxFalseAlarmRate: " << maxFalseAlarm << endl; | |
| 222 | + cout << "weightTrimRate: " << weight_trim_rate << endl; | |
| 223 | + cout << "maxDepth: " << max_depth << endl; | |
| 224 | + cout << "maxWeakCount: " << weak_count << endl; | |
| 225 | +} | |
| 226 | + | |
| 227 | +bool CascadeBoostParams::scanAttr( const string prmName, const string val) | |
| 228 | +{ | |
| 229 | + bool res = true; | |
| 230 | + | |
| 231 | + if( !prmName.compare( "-bt" ) ) | |
| 232 | + { | |
| 233 | + boost_type = !val.compare( CC_DISCRETE_BOOST ) ? CvBoost::DISCRETE : | |
| 234 | + !val.compare( CC_REAL_BOOST ) ? CvBoost::REAL : | |
| 235 | + !val.compare( CC_LOGIT_BOOST ) ? CvBoost::LOGIT : | |
| 236 | + !val.compare( CC_GENTLE_BOOST ) ? CvBoost::GENTLE : -1; | |
| 237 | + if (boost_type == -1) | |
| 238 | + res = false; | |
| 239 | + } | |
| 240 | + else if( !prmName.compare( "-minHitRate" ) ) | |
| 241 | + { | |
| 242 | + minHitRate = (float) atof( val.c_str() ); | |
| 243 | + } | |
| 244 | + else if( !prmName.compare( "-maxFalseAlarmRate" ) ) | |
| 245 | + { | |
| 246 | + maxFalseAlarm = (float) atof( val.c_str() ); | |
| 247 | + } | |
| 248 | + else if( !prmName.compare( "-weightTrimRate" ) ) | |
| 249 | + { | |
| 250 | + weight_trim_rate = (float) atof( val.c_str() ); | |
| 251 | + } | |
| 252 | + else if( !prmName.compare( "-maxDepth" ) ) | |
| 253 | + { | |
| 254 | + max_depth = atoi( val.c_str() ); | |
| 255 | + } | |
| 256 | + else if( !prmName.compare( "-maxWeakCount" ) ) | |
| 257 | + { | |
| 258 | + weak_count = atoi( val.c_str() ); | |
| 259 | + } | |
| 260 | + else | |
| 261 | + res = false; | |
| 262 | + | |
| 263 | + return res; | |
| 264 | +} | |
| 265 | + | |
// Builds a decision-tree root node restricted to the samples selected by
// _subsample_idx (mask or explicit index list; duplicates allowed). When the
// selection is NULL or the identity over all samples, the existing data_root
// is shallow-copied instead of re-subsampling the internal buffers.
CvDTreeNode* CascadeBoostTrainData::subsample_data( const CvMat* _subsample_idx )
{
    CvDTreeNode* root = 0;
    CvMat* isubsample_idx = 0;   // canonicalized (sorted int) selection
    CvMat* subsample_co = 0;     // per-sample (count, offset) pairs

    bool isMakeRootCopy = true;

    if( !data_root )
        CV_Error( CV_StsError, "No training data has been set" );

    if( _subsample_idx )
    {
        CV_Assert( (isubsample_idx = cvPreprocessIndexArray( _subsample_idx, sample_count )) != 0 );

        // A full identity selection is equivalent to "no subsampling":
        // keep the cheap root-copy path in that case.
        if( isubsample_idx->cols + isubsample_idx->rows - 1 == sample_count )
        {
            const int* sidx = isubsample_idx->data.i;
            for( int i = 0; i < sample_count; i++ )
            {
                if( sidx[i] != i )
                {
                    isMakeRootCopy = false;
                    break;
                }
            }
        }
        else
            isMakeRootCopy = false;
    }

    if( isMakeRootCopy )
    {
        // make a copy of the root node
        CvDTreeNode temp;
        int i;
        root = new_node( 0, 1, 0, 0 );
        temp = *root;
        *root = *data_root;
        // Keep the freshly allocated node's own buffers (num_valid and the
        // cross-validation arrays) instead of aliasing data_root's.
        root->num_valid = temp.num_valid;
        if( root->num_valid )
        {
            for( i = 0; i < var_count; i++ )
                root->num_valid[i] = data_root->num_valid[i];
        }
        root->cv_Tn = temp.cv_Tn;
        root->cv_node_risk = temp.cv_node_risk;
        root->cv_node_error = temp.cv_node_error;
    }
    else
    {
        int* sidx = isubsample_idx->data.i;
        // co - array of count/offset pairs (to handle duplicated values in _subsample_idx)
        int* co, cur_ofs = 0;
        int workVarCount = get_work_var_count();
        int count = isubsample_idx->rows + isubsample_idx->cols - 1;

        root = new_node( 0, count, 1, 0 );

        // co[2i]   = how many times original sample i was selected,
        // co[2i+1] = its first position in the subsampled ordering (-1 if unused).
        CV_Assert( (subsample_co = cvCreateMat( 1, sample_count*2, CV_32SC1 )) != 0);
        cvZero( subsample_co );
        co = subsample_co->data.i;
        for( int i = 0; i < count; i++ )
            co[sidx[i]*2]++;
        for( int i = 0; i < sample_count; i++ )
        {
            if( co[i*2] )
            {
                co[i*2+1] = cur_ofs;
                cur_ofs += co[i*2];
            }
            else
                co[i*2+1] = -1;
        }

        // Scratch area: int indices + float values + int sample indices per sample.
        cv::AutoBuffer<uchar> inn_buf(sample_count*(2*sizeof(int) + sizeof(float)));
        // subsample ordered variables
        for( int vi = 0; vi < numPrecalcIdx; vi++ )
        {
            int ci = get_var_type(vi);
            CV_Assert( ci < 0 );   // ordered variables carry negative type markers

            int *src_idx_buf = (int*)(uchar*)inn_buf;
            float *src_val_buf = (float*)(src_idx_buf + sample_count);
            int* sample_indices_buf = (int*)(src_val_buf + sample_count);
            const int* src_idx = 0;
            const float* src_val = 0;
            get_ord_var_data( data_root, vi, src_val_buf, src_idx_buf, &src_val, &src_idx, sample_indices_buf );

            int j = 0, idx, count_i;
            int num_valid = data_root->get_num_valid(vi);
            CV_Assert( num_valid == sample_count );

            // Walk the full data in presorted order; each surviving sample is
            // emitted once per duplicate, so the subsampled order stays sorted.
            if (is_buf_16u)
            {
                // NOTE(review): this branch adds data_root->offset while the
                // 32-bit branch below adds root->offset — confirm the
                // asymmetry is intentional (it is present upstream as well).
                unsigned short* udst_idx = (unsigned short*)(buf->data.s + root->buf_idx*get_length_subbuf() +
                    vi*sample_count + data_root->offset);
                for( int i = 0; i < num_valid; i++ )
                {
                    idx = src_idx[i];
                    count_i = co[idx*2];
                    if( count_i )
                        for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                            udst_idx[j] = (unsigned short)cur_ofs;
                }
            }
            else
            {
                int* idst_idx = buf->data.i + root->buf_idx*get_length_subbuf() +
                    vi*sample_count + root->offset;
                for( int i = 0; i < num_valid; i++ )
                {
                    idx = src_idx[i];
                    count_i = co[idx*2];
                    if( count_i )
                        for( cur_ofs = co[idx*2+1]; count_i > 0; count_i--, j++, cur_ofs++ )
                            idst_idx[j] = cur_ofs;
                }
            }
        }

        // subsample cv_labels (work variable workVarCount-1)
        const int* src_lbls = get_cv_labels(data_root, (int*)(uchar*)inn_buf);
        if (is_buf_16u)
        {
            unsigned short* udst = (unsigned short*)(buf->data.s + root->buf_idx*get_length_subbuf() +
                (workVarCount-1)*sample_count + root->offset);
            for( int i = 0; i < count; i++ )
                udst[i] = (unsigned short)src_lbls[sidx[i]];
        }
        else
        {
            int* idst = buf->data.i + root->buf_idx*get_length_subbuf() +
                (workVarCount-1)*sample_count + root->offset;
            for( int i = 0; i < count; i++ )
                idst[i] = src_lbls[sidx[i]];
        }

        // subsample sample_indices (work variable workVarCount)
        const int* sample_idx_src = get_sample_indices(data_root, (int*)(uchar*)inn_buf);
        if (is_buf_16u)
        {
            unsigned short* sample_idx_dst = (unsigned short*)(buf->data.s + root->buf_idx*get_length_subbuf() +
                workVarCount*sample_count + root->offset);
            for( int i = 0; i < count; i++ )
                sample_idx_dst[i] = (unsigned short)sample_idx_src[sidx[i]];
        }
        else
        {
            int* sample_idx_dst = buf->data.i + root->buf_idx*get_length_subbuf() +
                workVarCount*sample_count + root->offset;
            for( int i = 0; i < count; i++ )
                sample_idx_dst[i] = sample_idx_src[sidx[i]];
        }

        // Every variable now has `count` valid entries in the new node.
        for( int vi = 0; vi < var_count; vi++ )
            root->set_num_valid(vi, count);
    }

    cvReleaseMat( &isubsample_idx );
    cvReleaseMat( &subsample_co );

    return root;
}
| 430 | + | |
| 431 | +//---------------------------- CascadeBoostTrainData ----------------------------- | |
| 432 | + | |
// Lightweight constructor used when no per-sample buffers are needed (e.g.
// when loading a trained cascade): binds the feature evaluator, builds the
// variable-type table, and allocates only the node/split memory storage.
CascadeBoostTrainData::CascadeBoostTrainData( const FeatureEvaluator* _featureEvaluator,
                                             const CvDTreeParams& _params )
{
    is_classifier = true;
    var_all = var_count = (int)_featureEvaluator->getNumFeatures();

    featureEvaluator = _featureEvaluator;
    shared = true;
    set_params( _params );
    max_c_count = MAX( 2, featureEvaluator->getMaxCatCount() );
    // var_type has one entry per variable plus two trailing entries for the
    // cv-labels and sample-indices work variables.
    var_type = cvCreateMat( 1, var_count + 2, CV_32SC1 );
    if ( featureEvaluator->getMaxCatCount() > 0 )
    {
        // Categorical features: entry vi is the (nonnegative) category index.
        numPrecalcIdx = 0;
        cat_var_count = var_count;
        ord_var_count = 0;
        for( int vi = 0; vi < var_count; vi++ )
        {
            var_type->data.i[vi] = vi;
        }
    }
    else
    {
        // Ordered features: marked with negative values (-1..-var_count).
        cat_var_count = 0;
        ord_var_count = var_count;
        for( int vi = 1; vi <= var_count; vi++ )
        {
            var_type->data.i[vi-1] = -vi;
        }
    }
    var_type->data.i[var_count] = cat_var_count;
    var_type->data.i[var_count+1] = cat_var_count+1;

    // Size the tree storage so a single block can hold the largest split
    // (bitmask splits need extra ints beyond 32 categories) or 8 nodes.
    int maxSplitSize = cvAlign(sizeof(CvDTreeSplit) + (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
    int treeBlockSize = MAX((int)sizeof(CvDTreeNode)*8, maxSplitSize);
    treeBlockSize = MAX(treeBlockSize + BlockSizeDelta, MinBlockSize);
    tree_storage = cvCreateMemStorage( treeBlockSize );
    node_heap = cvCreateSet( 0, sizeof(node_heap[0]), sizeof(CvDTreeNode), tree_storage );
    split_heap = cvCreateSet( 0, sizeof(split_heap[0]), maxSplitSize, tree_storage );
}
| 473 | + | |
// Full training constructor: delegates to setData(), which allocates the
// per-sample buffers and precalculates feature values/sorted indices within
// the given megabyte budgets.
CascadeBoostTrainData::CascadeBoostTrainData( const FeatureEvaluator* _featureEvaluator,
                                             int _numSamples,
                                             int _precalcValBufSize, int _precalcIdxBufSize,
                                             const CvDTreeParams& _params )
{
    setData( _featureEvaluator, _numSamples, _precalcValBufSize, _precalcIdxBufSize, _params );
}
| 481 | + | |
// (Re)initializes the training data: binds the feature evaluator, decides
// how many feature values (numPrecalcVal) and sorted index permutations
// (numPrecalcIdx) fit into the given buffer budgets (in MB), allocates the
// big work buffer `buf`, precalculates caches, and creates the root node.
void CascadeBoostTrainData::setData( const FeatureEvaluator* _featureEvaluator,
                                     int _numSamples,
                                     int _precalcValBufSize, int _precalcIdxBufSize,
                                     const CvDTreeParams& _params )
{
    int* idst = 0;
    unsigned short* udst = 0;

    uint64 effective_buf_size = 0;
    int effective_buf_height = 0, effective_buf_width = 0;


    clear();
    shared = true;
    have_labels = true;
    have_priors = false;
    is_classifier = true;

    rng = &cv::theRNG();

    set_params( _params );

    CV_Assert( _featureEvaluator );
    featureEvaluator = _featureEvaluator;

    max_c_count = MAX( 2, featureEvaluator->getMaxCatCount() );
    // Responses (class labels per sample) come straight from the evaluator.
    _resp = featureEvaluator->getCls();
    responses = &_resp;
    // TODO: check responses: elements must be 0 or 1

    if( _precalcValBufSize < 0 || _precalcIdxBufSize < 0)
        CV_Error( CV_StsOutOfRange, "_numPrecalcVal and _numPrecalcIdx must be positive or 0" );

    var_count = var_all = featureEvaluator->getNumFeatures() * featureEvaluator->getFeatureSize();
    sample_count = _numSamples;

    // 16-bit index storage halves the buffer when every index fits.
    is_buf_16u = false;
    if (sample_count < 65536)
        is_buf_16u = true;

    // Convert the MB budgets into a number of features that can be cached.
    numPrecalcVal = min( cvRound((double)_precalcValBufSize*1048576. / (sizeof(float)*sample_count)), var_count );
    numPrecalcIdx = min( cvRound((double)_precalcIdxBufSize*1048576. /
                                 ((is_buf_16u ? sizeof(unsigned short) : sizeof (int))*sample_count)), var_count );

    assert( numPrecalcIdx >= 0 && numPrecalcVal >= 0 );

    valCache.create( numPrecalcVal, sample_count, CV_32FC1 );
    // var_type: one entry per variable plus two work-variable entries.
    var_type = cvCreateMat( 1, var_count + 2, CV_32SC1 );
    if ( featureEvaluator->getMaxCatCount() > 0 )
    {
        // Categorical features (no index precalc possible for them).
        numPrecalcIdx = 0;
        cat_var_count = var_count;
        ord_var_count = 0;
        for( int vi = 0; vi < var_count; vi++ )
        {
            var_type->data.i[vi] = vi;
        }
    }
    else
    {
        // Ordered features: marked with negative values (-1..-var_count).
        cat_var_count = 0;
        ord_var_count = var_count;
        for( int vi = 1; vi <= var_count; vi++ )
        {
            var_type->data.i[vi-1] = -vi;
        }
    }
    var_type->data.i[var_count] = cat_var_count;
    var_type->data.i[var_count+1] = cat_var_count+1;
    // Work variables: precalculated index rows (ordered case only) + cv labels.
    work_var_count = ( cat_var_count ? 0 : numPrecalcIdx ) + 1/*cv_lables*/;
    buf_count = 2;

    buf_size = -1; // the member buf_size is obsolete

    // Guard against the row*col product overflowing int dimensions.
    effective_buf_size = (uint64)(work_var_count + 1)*(uint64)sample_count * buf_count; // this is the total size of "CvMat buf" to be allocated
    effective_buf_width = sample_count;
    effective_buf_height = work_var_count+1;

    if (effective_buf_width >= effective_buf_height)
        effective_buf_height *= buf_count;
    else
        effective_buf_width *= buf_count;

    if ((uint64)effective_buf_width * (uint64)effective_buf_height != effective_buf_size)
    {
        CV_Error(CV_StsBadArg, "The memory buffer cannot be allocated since its size exceeds integer fields limit");
    }
    if ( is_buf_16u )
        buf = cvCreateMat( effective_buf_height, effective_buf_width, CV_16UC1 );
    else
        buf = cvCreateMat( effective_buf_height, effective_buf_width, CV_32SC1 );

    cat_count = cvCreateMat( 1, cat_var_count + 1, CV_32SC1 );

    // precalculate valCache and set indices in buf
    precalculate();

    // now calculate the maximum size of split,
    // create memory storage that will keep nodes and splits of the decision tree
    // allocate root node and the buffer for the whole training data
    int maxSplitSize = cvAlign(sizeof(CvDTreeSplit) +
                               (MAX(0,sample_count - 33)/32)*sizeof(int),sizeof(void*));
    int treeBlockSize = MAX((int)sizeof(CvDTreeNode)*8, maxSplitSize);
    treeBlockSize = MAX(treeBlockSize + BlockSizeDelta, MinBlockSize);
    tree_storage = cvCreateMemStorage( treeBlockSize );
    node_heap = cvCreateSet( 0, sizeof(*node_heap), sizeof(CvDTreeNode), tree_storage );

    int nvSize = var_count*sizeof(int);
    nvSize = cvAlign(MAX( nvSize, (int)sizeof(CvSetElem) ), sizeof(void*));
    int tempBlockSize = nvSize;
    tempBlockSize = MAX( tempBlockSize + BlockSizeDelta, MinBlockSize );
    temp_storage = cvCreateMemStorage( tempBlockSize );
    nv_heap = cvCreateSet( 0, sizeof(*nv_heap), nvSize, temp_storage );

    data_root = new_node( 0, sample_count, 0, 0 );

    // set sample labels: the sample-indices work row starts as the identity.
    if (is_buf_16u)
        udst = (unsigned short*)(buf->data.s + work_var_count*sample_count);
    else
        idst = buf->data.i + work_var_count*sample_count;

    for (int si = 0; si < sample_count; si++)
    {
        if (udst)
            udst[si] = (unsigned short)si;
        else
            idst[si] = si;
    }
    for( int vi = 0; vi < var_count; vi++ )
        data_root->set_num_valid(vi, sample_count);
    for( int vi = 0; vi < cat_var_count; vi++ )
        cat_count->data.i[vi] = max_c_count;

    cat_count->data.i[cat_var_count] = 2;   // the response has 2 classes

    maxSplitSize = cvAlign(sizeof(CvDTreeSplit) +
                           (MAX(0,max_c_count - 33)/32)*sizeof(int),sizeof(void*));
    split_heap = cvCreateSet( 0, sizeof(*split_heap), maxSplitSize, tree_storage );

    // Uniform class priors; counts/direction/split_buf are per-sample scratch.
    priors = cvCreateMat( 1, get_num_classes(), CV_64F );
    cvSet(priors, cvScalar(1));
    priors_mult = cvCloneMat( priors );
    counts = cvCreateMat( 1, get_num_classes(), CV_32SC1 );
    direction = cvCreateMat( 1, sample_count, CV_8UC1 );
    split_buf = cvCreateMat( 1, sample_count, CV_32SC1 );//TODO: make a pointer
}
| 629 | + | |
// Releases per-training buffers: the base-class training data plus the
// precalculated feature-value cache.
void CascadeBoostTrainData::free_train_data()
{
    CvDTreeTrainData::free_train_data();
    valCache.release();
}
| 635 | + | |
| 636 | +const int* CascadeBoostTrainData::get_class_labels( CvDTreeNode* n, int* labelsBuf) | |
| 637 | +{ | |
| 638 | + int nodeSampleCount = n->sample_count; | |
| 639 | + int rStep = CV_IS_MAT_CONT( responses->type ) ? 1 : responses->step / CV_ELEM_SIZE( responses->type ); | |
| 640 | + | |
| 641 | + int* sampleIndicesBuf = labelsBuf; // | |
| 642 | + const int* sampleIndices = get_sample_indices(n, sampleIndicesBuf); | |
| 643 | + for( int si = 0; si < nodeSampleCount; si++ ) | |
| 644 | + { | |
| 645 | + int sidx = sampleIndices[si]; | |
| 646 | + labelsBuf[si] = (int)responses->data.fl[sidx*rStep]; | |
| 647 | + } | |
| 648 | + return labelsBuf; | |
| 649 | +} | |
| 650 | + | |
// Sample indices are stored as the last work-variable row of the internal
// buffer; fetch them through the base class's categorical accessor.
const int* CascadeBoostTrainData::get_sample_indices( CvDTreeNode* n, int* indicesBuf )
{
    return CvDTreeTrainData::get_cat_var_data( n, get_work_var_count(), indicesBuf );
}
| 655 | + | |
// Cross-validation labels occupy the next-to-last work-variable row of the
// internal buffer.
const int* CascadeBoostTrainData::get_cv_labels( CvDTreeNode* n, int* labels_buf )
{
    return CvDTreeTrainData::get_cat_var_data( n, get_work_var_count() - 1, labels_buf );
}
| 660 | + | |
// Returns, for ordered variable vi at node n, the sorted feature values
// (*ordValues) and the node-local sample permutation that sorts them
// (*sortedIndices). Uses the precalculated index rows / value cache when vi
// is below numPrecalcIdx / numPrecalcVal, and falls back to evaluating and
// sorting on the fly otherwise. The *Buf arguments are caller scratch that
// the returned pointers may alias.
void CascadeBoostTrainData::get_ord_var_data( CvDTreeNode* n, int vi, float* ordValuesBuf, int* sortedIndicesBuf,
                                              const float** ordValues, const int** sortedIndices, int* sampleIndicesBuf )
{
    int nodeSampleCount = n->sample_count;
    const int* sampleIndices = get_sample_indices(n, sampleIndicesBuf);

    if ( vi < numPrecalcIdx )
    {
        // Sorted order was precalculated into row vi of buf.
        if( !is_buf_16u )
            *sortedIndices = buf->data.i + n->buf_idx*get_length_subbuf() + vi*sample_count + n->offset;
        else
        {
            // 16-bit storage: widen into the caller's int buffer.
            const unsigned short* shortIndices = (const unsigned short*)(buf->data.s + n->buf_idx*get_length_subbuf() +
                                                                         vi*sample_count + n->offset );
            for( int i = 0; i < nodeSampleCount; i++ )
                sortedIndicesBuf[i] = shortIndices[i];

            *sortedIndices = sortedIndicesBuf;
        }

        if( vi < numPrecalcVal )
        {
            // Values are cached: gather them in sorted order.
            for( int i = 0; i < nodeSampleCount; i++ )
            {
                int idx = (*sortedIndices)[i];
                idx = sampleIndices[idx];
                ordValuesBuf[i] = valCache.at<float>( vi, idx);
            }
        }
        else
        {
            // Values not cached: evaluate the feature per sample.
            for( int i = 0; i < nodeSampleCount; i++ )
            {
                int idx = (*sortedIndices)[i];
                idx = sampleIndices[idx];
                ordValuesBuf[i] = (*featureEvaluator)( vi, idx);
            }
        }
    }
    else // vi >= numPrecalcIdx
    {
        // No precalculated order: gather values, then sort an identity
        // permutation by them.
        cv::AutoBuffer<float> abuf(nodeSampleCount);
        float* sampleValues = &abuf[0];

        if ( vi < numPrecalcVal )
        {
            for( int i = 0; i < nodeSampleCount; i++ )
            {
                sortedIndicesBuf[i] = i;
                sampleValues[i] = valCache.at<float>( vi, sampleIndices[i] );
            }
        }
        else
        {
            for( int i = 0; i < nodeSampleCount; i++ )
            {
                sortedIndicesBuf[i] = i;
                sampleValues[i] = (*featureEvaluator)( vi, sampleIndices[i]);
            }
        }
        icvSortIntAux( sortedIndicesBuf, nodeSampleCount, &sampleValues[0] );
        for( int i = 0; i < nodeSampleCount; i++ )
            ordValuesBuf[i] = (&sampleValues[0])[sortedIndicesBuf[i]];
        *sortedIndices = sortedIndicesBuf;
    }

    *ordValues = ordValuesBuf;
}
| 729 | + | |
| 730 | +const int* CascadeBoostTrainData::get_cat_var_data( CvDTreeNode* n, int vi, int* catValuesBuf ) | |
| 731 | +{ | |
| 732 | + int nodeSampleCount = n->sample_count; | |
| 733 | + int* sampleIndicesBuf = catValuesBuf; // | |
| 734 | + const int* sampleIndices = get_sample_indices(n, sampleIndicesBuf); | |
| 735 | + | |
| 736 | + if ( vi < numPrecalcVal ) | |
| 737 | + { | |
| 738 | + for( int i = 0; i < nodeSampleCount; i++ ) | |
| 739 | + catValuesBuf[i] = (int) valCache.at<float>( vi, sampleIndices[i]); | |
| 740 | + } | |
| 741 | + else | |
| 742 | + { | |
| 743 | + if( vi >= numPrecalcVal && vi < var_count ) | |
| 744 | + { | |
| 745 | + for( int i = 0; i < nodeSampleCount; i++ ) | |
| 746 | + catValuesBuf[i] = (int)(*featureEvaluator)( vi, sampleIndices[i] ); | |
| 747 | + } | |
| 748 | + else | |
| 749 | + { | |
| 750 | + get_cv_labels( n, catValuesBuf ); | |
| 751 | + } | |
| 752 | + } | |
| 753 | + | |
| 754 | + return catValuesBuf; | |
| 755 | +} | |
| 756 | + | |
| 757 | +float CascadeBoostTrainData::getVarValue( int vi, int si ) | |
| 758 | +{ | |
| 759 | + if ( vi < numPrecalcVal && !valCache.empty() ) | |
| 760 | + return valCache.at<float>( vi, si ); | |
| 761 | + return (*featureEvaluator)( vi, si ); | |
| 762 | +} | |
| 763 | + | |
| 764 | + | |
| 765 | +struct FeatureIdxOnlyPrecalc : ParallelLoopBody | |
| 766 | +{ | |
| 767 | + FeatureIdxOnlyPrecalc( const FeatureEvaluator* _featureEvaluator, CvMat* _buf, int _sample_count, bool _is_buf_16u ) | |
| 768 | + { | |
| 769 | + featureEvaluator = _featureEvaluator; | |
| 770 | + sample_count = _sample_count; | |
| 771 | + udst = (unsigned short*)_buf->data.s; | |
| 772 | + idst = _buf->data.i; | |
| 773 | + is_buf_16u = _is_buf_16u; | |
| 774 | + } | |
| 775 | + void operator()( const Range& range ) const | |
| 776 | + { | |
| 777 | + cv::AutoBuffer<float> valCache(sample_count); | |
| 778 | + float* valCachePtr = (float*)valCache; | |
| 779 | + for ( int fi = range.start; fi < range.end; fi++) | |
| 780 | + { | |
| 781 | + for( int si = 0; si < sample_count; si++ ) | |
| 782 | + { | |
| 783 | + valCachePtr[si] = (*featureEvaluator)( fi, si ); | |
| 784 | + if ( is_buf_16u ) | |
| 785 | + *(udst + fi*sample_count + si) = (unsigned short)si; | |
| 786 | + else | |
| 787 | + *(idst + fi*sample_count + si) = si; | |
| 788 | + } | |
| 789 | + if ( is_buf_16u ) | |
| 790 | + icvSortUShAux( udst + fi*sample_count, sample_count, valCachePtr ); | |
| 791 | + else | |
| 792 | + icvSortIntAux( idst + fi*sample_count, sample_count, valCachePtr ); | |
| 793 | + } | |
| 794 | + } | |
| 795 | + const FeatureEvaluator* featureEvaluator; | |
| 796 | + int sample_count; | |
| 797 | + int* idst; | |
| 798 | + unsigned short* udst; | |
| 799 | + bool is_buf_16u; | |
| 800 | +}; | |
| 801 | + | |
| 802 | +struct FeatureValAndIdxPrecalc : ParallelLoopBody | |
| 803 | +{ | |
| 804 | + FeatureValAndIdxPrecalc( const FeatureEvaluator* _featureEvaluator, CvMat* _buf, Mat* _valCache, int _sample_count, bool _is_buf_16u ) | |
| 805 | + { | |
| 806 | + featureEvaluator = _featureEvaluator; | |
| 807 | + valCache = _valCache; | |
| 808 | + sample_count = _sample_count; | |
| 809 | + udst = (unsigned short*)_buf->data.s; | |
| 810 | + idst = _buf->data.i; | |
| 811 | + is_buf_16u = _is_buf_16u; | |
| 812 | + } | |
| 813 | + void operator()( const Range& range ) const | |
| 814 | + { | |
| 815 | + for ( int fi = range.start; fi < range.end; fi++) | |
| 816 | + { | |
| 817 | + for( int si = 0; si < sample_count; si++ ) | |
| 818 | + { | |
| 819 | + valCache->at<float>(fi,si) = (*featureEvaluator)( fi, si ); | |
| 820 | + if ( is_buf_16u ) | |
| 821 | + *(udst + fi*sample_count + si) = (unsigned short)si; | |
| 822 | + else | |
| 823 | + *(idst + fi*sample_count + si) = si; | |
| 824 | + } | |
| 825 | + if ( is_buf_16u ) | |
| 826 | + icvSortUShAux( udst + fi*sample_count, sample_count, valCache->ptr<float>(fi) ); | |
| 827 | + else | |
| 828 | + icvSortIntAux( idst + fi*sample_count, sample_count, valCache->ptr<float>(fi) ); | |
| 829 | + } | |
| 830 | + } | |
| 831 | + const FeatureEvaluator* featureEvaluator; | |
| 832 | + Mat* valCache; | |
| 833 | + int sample_count; | |
| 834 | + int* idst; | |
| 835 | + unsigned short* udst; | |
| 836 | + bool is_buf_16u; | |
| 837 | +}; | |
| 838 | + | |
| 839 | +struct FeatureValOnlyPrecalc : ParallelLoopBody | |
| 840 | +{ | |
| 841 | + FeatureValOnlyPrecalc( const FeatureEvaluator* _featureEvaluator, Mat* _valCache, int _sample_count ) | |
| 842 | + { | |
| 843 | + featureEvaluator = _featureEvaluator; | |
| 844 | + valCache = _valCache; | |
| 845 | + sample_count = _sample_count; | |
| 846 | + } | |
| 847 | + void operator()( const Range& range ) const | |
| 848 | + { | |
| 849 | + for ( int fi = range.start; fi < range.end; fi++) | |
| 850 | + for( int si = 0; si < sample_count; si++ ) | |
| 851 | + valCache->at<float>(fi,si) = (*featureEvaluator)( fi, si ); | |
| 852 | + } | |
| 853 | + const FeatureEvaluator* featureEvaluator; | |
| 854 | + Mat* valCache; | |
| 855 | + int sample_count; | |
| 856 | +}; | |
| 857 | + | |
| 858 | +void CascadeBoostTrainData::precalculate() | |
| 859 | +{ | |
| 860 | + int minNum = MIN( numPrecalcVal, numPrecalcIdx); | |
| 861 | + | |
| 862 | + double proctime = -TIME( 0 ); | |
| 863 | + parallel_for_( Range(numPrecalcVal, numPrecalcIdx), | |
| 864 | + FeatureIdxOnlyPrecalc(featureEvaluator, buf, sample_count, is_buf_16u!=0) ); | |
| 865 | + parallel_for_( Range(0, minNum), | |
| 866 | + FeatureValAndIdxPrecalc(featureEvaluator, buf, &valCache, sample_count, is_buf_16u!=0) ); | |
| 867 | + parallel_for_( Range(minNum, numPrecalcVal), | |
| 868 | + FeatureValOnlyPrecalc(featureEvaluator, &valCache, sample_count) ); | |
| 869 | + cout << "Precalculation time: " << (proctime + TIME( 0 )) << endl; | |
| 870 | +} | |
| 871 | + | |
| 872 | +//-------------------------------- CascadeBoostTree ---------------------------------------- | |
| 873 | + | |
| 874 | +CvDTreeNode* CascadeBoostTree::predict( int sampleIdx ) const | |
| 875 | +{ | |
| 876 | + CvDTreeNode* node = root; | |
| 877 | + if( !node ) | |
| 878 | + CV_Error( CV_StsError, "The tree has not been trained yet" ); | |
| 879 | + | |
| 880 | + if ( ((CascadeBoostTrainData*)data)->featureEvaluator->getMaxCatCount() == 0 ) // ordered | |
| 881 | + { | |
| 882 | + while( node->left ) | |
| 883 | + { | |
| 884 | + CvDTreeSplit* split = node->split; | |
| 885 | + float val = ((CascadeBoostTrainData*)data)->getVarValue( split->var_idx, sampleIdx ); | |
| 886 | + node = val <= split->ord.c ? node->left : node->right; | |
| 887 | + } | |
| 888 | + } | |
| 889 | + else // categorical | |
| 890 | + { | |
| 891 | + while( node->left ) | |
| 892 | + { | |
| 893 | + CvDTreeSplit* split = node->split; | |
| 894 | + int c = (int)((CascadeBoostTrainData*)data)->getVarValue( split->var_idx, sampleIdx ); | |
| 895 | + node = CV_DTREE_CAT_DIR(c, split->subset) < 0 ? node->left : node->right; | |
| 896 | + } | |
| 897 | + } | |
| 898 | + return node; | |
| 899 | +} | |
| 900 | + | |
| 901 | +void CascadeBoostTree::write( FileStorage &fs, const Mat& featureMap ) | |
| 902 | +{ | |
| 903 | + int maxCatCount = ((CascadeBoostTrainData*)data)->featureEvaluator->getMaxCatCount(); | |
| 904 | + int subsetN = (maxCatCount + 31)/32; | |
| 905 | + queue<CvDTreeNode*> internalNodesQueue; | |
| 906 | + int size = (int)pow( 2.f, (float)ensemble->get_params().max_depth); | |
| 907 | + Ptr<float> leafVals = new float[size]; | |
| 908 | + int leafValIdx = 0; | |
| 909 | + int internalNodeIdx = 1; | |
| 910 | + CvDTreeNode* tempNode; | |
| 911 | + | |
| 912 | + CV_DbgAssert( root ); | |
| 913 | + internalNodesQueue.push( root ); | |
| 914 | + | |
| 915 | + fs << "{"; | |
| 916 | + fs << CC_INTERNAL_NODES << "[:"; | |
| 917 | + while (!internalNodesQueue.empty()) | |
| 918 | + { | |
| 919 | + tempNode = internalNodesQueue.front(); | |
| 920 | + CV_Assert( tempNode->left ); | |
| 921 | + if ( !tempNode->left->left && !tempNode->left->right) // left node is leaf | |
| 922 | + { | |
| 923 | + leafVals[-leafValIdx] = (float)tempNode->left->value; | |
| 924 | + fs << leafValIdx-- ; | |
| 925 | + } | |
| 926 | + else | |
| 927 | + { | |
| 928 | + internalNodesQueue.push( tempNode->left ); | |
| 929 | + fs << internalNodeIdx++; | |
| 930 | + } | |
| 931 | + CV_Assert( tempNode->right ); | |
| 932 | + if ( !tempNode->right->left && !tempNode->right->right) // right node is leaf | |
| 933 | + { | |
| 934 | + leafVals[-leafValIdx] = (float)tempNode->right->value; | |
| 935 | + fs << leafValIdx--; | |
| 936 | + } | |
| 937 | + else | |
| 938 | + { | |
| 939 | + internalNodesQueue.push( tempNode->right ); | |
| 940 | + fs << internalNodeIdx++; | |
| 941 | + } | |
| 942 | + int fidx = tempNode->split->var_idx; | |
| 943 | + fidx = featureMap.empty() ? fidx : featureMap.at<int>(0, fidx); | |
| 944 | + fs << fidx; | |
| 945 | + if ( !maxCatCount ) | |
| 946 | + fs << tempNode->split->ord.c; | |
| 947 | + else | |
| 948 | + for( int i = 0; i < subsetN; i++ ) | |
| 949 | + fs << tempNode->split->subset[i]; | |
| 950 | + internalNodesQueue.pop(); | |
| 951 | + } | |
| 952 | + fs << "]"; // CC_INTERNAL_NODES | |
| 953 | + | |
| 954 | + fs << CC_LEAF_VALUES << "[:"; | |
| 955 | + for (int ni = 0; ni < -leafValIdx; ni++) | |
| 956 | + fs << leafVals[ni]; | |
| 957 | + fs << "]"; // CC_LEAF_VALUES | |
| 958 | + fs << "}"; | |
| 959 | +} | |
| 960 | + | |
// Deserialize a tree previously saved by write(): reconstruct the node
// hierarchy from the flat breadth-first arrays.
//
// Each internal-node record is `step` values long: left ref, right ref,
// feature index, then either one threshold (ordered) or subsetN ints
// (categorical). Records are consumed BACKWARDS (iterators start at end()),
// so children are always rebuilt before their parent; the queue holds
// completed subtrees waiting to be attached. A child reference <= 0 denotes
// a leaf whose value comes from the (also reversed) leaf-values array.
//
// @param node      FileNode containing CC_INTERNAL_NODES / CC_LEAF_VALUES
// @param _ensemble owning boosted ensemble
// @param _data     training-data descriptor (provides maxCatCount)
void CascadeBoostTree::read( const FileNode &node, CvBoost* _ensemble,
                                CvDTreeTrainData* _data )
{
    int maxCatCount = ((CascadeBoostTrainData*)_data)->featureEvaluator->getMaxCatCount();
    int subsetN = (maxCatCount + 31)/32;
    int step = 3 + ( maxCatCount>0 ? subsetN : 1 );

    queue<CvDTreeNode*> internalNodesQueue;
    FileNodeIterator internalNodesIt, leafValsuesIt;
    CvDTreeNode* prntNode, *cldNode;

    clear();
    data = _data;
    ensemble = _ensemble;
    pruned_tree_idx = 0;

    // read tree nodes
    FileNode rnode = node[CC_INTERNAL_NODES];
    // Start both iterators at the LAST element; the arrays are walked in
    // reverse so that children precede parents.
    internalNodesIt = rnode.end();
    leafValsuesIt = node[CC_LEAF_VALUES].end();
    internalNodesIt--; leafValsuesIt--;
    for( size_t i = 0; i < rnode.size()/step; i++ )
    {
        prntNode = data->new_node( 0, 0, 0, 0 );
        if ( maxCatCount > 0 )
        {
            // Categorical split: subset words were written in ascending order,
            // so reading backwards fills them from the highest index down.
            prntNode->split = data->new_split_cat( 0, 0 );
            for( int j = subsetN-1; j>=0; j--)
            {
                *internalNodesIt >> prntNode->split->subset[j]; internalNodesIt--;
            }
        }
        else
        {
            // Ordered split: a single threshold value.
            float split_value;
            *internalNodesIt >> split_value; internalNodesIt--;
            prntNode->split = data->new_split_ord( 0, split_value, 0, 0, 0);
        }
        *internalNodesIt >> prntNode->split->var_idx; internalNodesIt--;
        int ridx, lidx;
        // Reverse order: right reference was written after left, so it is
        // encountered first here.
        *internalNodesIt >> ridx; internalNodesIt--;
        *internalNodesIt >> lidx;internalNodesIt--;
        if ( ridx <= 0)
        {
            // Non-positive reference => leaf; take its value from the
            // (reversed) leaf-values stream.
            prntNode->right = cldNode = data->new_node( 0, 0, 0, 0 );
            *leafValsuesIt >> cldNode->value; leafValsuesIt--;
            cldNode->parent = prntNode;
        }
        else
        {
            // Positive reference => attach the next completed subtree.
            prntNode->right = internalNodesQueue.front();
            prntNode->right->parent = prntNode;
            internalNodesQueue.pop();
        }

        if ( lidx <= 0)
        {
            prntNode->left = cldNode = data->new_node( 0, 0, 0, 0 );
            *leafValsuesIt >> cldNode->value; leafValsuesIt--;
            cldNode->parent = prntNode;
        }
        else
        {
            prntNode->left = internalNodesQueue.front();
            prntNode->left->parent = prntNode;
            internalNodesQueue.pop();
        }

        internalNodesQueue.push( prntNode );
    }

    // The last subtree assembled is the root of the whole tree.
    root = internalNodesQueue.front();
    internalNodesQueue.pop();
}
| 1035 | + | |
// Partition node's per-sample buffers between its two new children according
// to the direction array computed by the split (dir[i]==0 -> left, 1 -> right),
// preserving the sort order of the precalculated ordered-variable index rows.
// Also relocates the cv-labels row and the sample-indices row, sets valid
// counts on the children, and frees the parent's per-node data.
//
// NOTE(review): this mirrors OpenCV's CvDTree/CvBoost buffer layout very
// closely; the pointer arithmetic below depends on that exact layout
// (rows of length `scount` per work variable inside `buf`). Do not reorder.
void CascadeBoostTree::split_node_data( CvDTreeNode* node )
{
    int n = node->sample_count, nl, nr, scount = data->sample_count;
    char* dir = (char*)data->direction->data.ptr;
    CvDTreeNode *left = 0, *right = 0;
    int* newIdx = data->split_buf->data.i;
    int newBufIdx = data->get_child_buf_idx( node );
    int workVarCount = data->get_work_var_count();
    CvMat* buf = data->buf;
    size_t length_buf_row = data->get_length_subbuf();
    // Scratch: one int row (tempBuf) plus room for get_ord_var_data's three
    // per-call buffers (float values + two int index rows).
    cv::AutoBuffer<uchar> inn_buf(n*(3*sizeof(int)+sizeof(float)));
    int* tempBuf = (int*)(uchar*)inn_buf;
    bool splitInputData;

    complete_node_dir(node);

    // Count left/right populations and build the relocation table:
    // newIdx[i] is sample i's position within its destination child.
    for( int i = nl = nr = 0; i < n; i++ )
    {
        int d = dir[i];
        // initialize new indices for splitting ordered variables
        newIdx[i] = (nl & (d-1)) | (nr & -d); // d ? ri : li
        nr += d;
        nl += d^1;
    }

    // Left child occupies [offset, offset+nl), right child the next nr slots.
    node->left = left = data->new_node( node, nl, newBufIdx, node->offset );
    node->right = right = data->new_node( node, nr, newBufIdx, node->offset + nl );

    // Only bother splitting the per-variable buffers if either child can
    // still be split further.
    splitInputData = node->depth + 1 < data->params.max_depth &&
        (node->left->sample_count > data->params.min_sample_count ||
        node->right->sample_count > data->params.min_sample_count);

    // split ordered variables, keep both halves sorted.
    for( int vi = 0; vi < ((CascadeBoostTrainData*)data)->numPrecalcIdx; vi++ )
    {
        int ci = data->get_var_type(vi);
        if( ci >= 0 || !splitInputData )
            continue;

        int n1 = node->get_num_valid(vi);
        float *src_val_buf = (float*)(tempBuf + n);
        int *src_sorted_idx_buf = (int*)(src_val_buf + n);
        int *src_sample_idx_buf = src_sorted_idx_buf + n;
        const int* src_sorted_idx = 0;
        const float* src_val = 0;
        data->get_ord_var_data(node, vi, src_val_buf, src_sorted_idx_buf, &src_val, &src_sorted_idx, src_sample_idx_buf);

        // Copy the sorted order aside before overwriting the buffer rows.
        for(int i = 0; i < n; i++)
            tempBuf[i] = src_sorted_idx[i];

        if (data->is_buf_16u)
        {
            ushort *ldst, *rdst;
            ldst = (ushort*)(buf->data.s + left->buf_idx*length_buf_row +
                vi*scount + left->offset);
            // Right half starts immediately after the left half's nl slots.
            rdst = (ushort*)(ldst + nl);

            // split sorted
            // Walking the parent's rows in sorted order and appending to each
            // side keeps both halves sorted without re-sorting.
            for( int i = 0; i < n1; i++ )
            {
                int idx = tempBuf[i];
                int d = dir[idx];
                idx = newIdx[idx];
                if (d)
                {
                    *rdst = (ushort)idx;
                    rdst++;
                }
                else
                {
                    *ldst = (ushort)idx;
                    ldst++;
                }
            }
            CV_Assert( n1 == n );
        }
        else
        {
            int *ldst, *rdst;
            ldst = buf->data.i + left->buf_idx*length_buf_row +
                vi*scount + left->offset;
            rdst = buf->data.i + right->buf_idx*length_buf_row +
                vi*scount + right->offset;

            // split sorted
            for( int i = 0; i < n1; i++ )
            {
                int idx = tempBuf[i];
                int d = dir[idx];
                idx = newIdx[idx];
                if (d)
                {
                    *rdst = idx;
                    rdst++;
                }
                else
                {
                    *ldst = idx;
                    ldst++;
                }
            }
            CV_Assert( n1 == n );
        }
    }

    // split cv_labels using newIdx relocation table
    int *src_lbls_buf = tempBuf + n;
    const int* src_lbls = data->get_cv_labels(node, src_lbls_buf);

    for(int i = 0; i < n; i++)
        tempBuf[i] = src_lbls[i];

    if (data->is_buf_16u)
    {
        // cv-labels live in row (workVarCount-1) of each node's sub-buffer.
        unsigned short *ldst = (unsigned short *)(buf->data.s + left->buf_idx*length_buf_row +
            (workVarCount-1)*scount + left->offset);
        unsigned short *rdst = (unsigned short *)(buf->data.s + right->buf_idx*length_buf_row +
            (workVarCount-1)*scount + right->offset);

        for( int i = 0; i < n; i++ )
        {
            int idx = tempBuf[i];
            if (dir[i])
            {
                *rdst = (unsigned short)idx;
                rdst++;
            }
            else
            {
                *ldst = (unsigned short)idx;
                ldst++;
            }
        }

    }
    else
    {
        int *ldst = buf->data.i + left->buf_idx*length_buf_row +
            (workVarCount-1)*scount + left->offset;
        int *rdst = buf->data.i + right->buf_idx*length_buf_row +
            (workVarCount-1)*scount + right->offset;

        for( int i = 0; i < n; i++ )
        {
            int idx = tempBuf[i];
            if (dir[i])
            {
                *rdst = idx;
                rdst++;
            }
            else
            {
                *ldst = idx;
                ldst++;
            }
        }
    }

    // split sample indices
    int *sampleIdx_src_buf = tempBuf + n;
    const int* sampleIdx_src = data->get_sample_indices(node, sampleIdx_src_buf);

    for(int i = 0; i < n; i++)
        tempBuf[i] = sampleIdx_src[i];

    if (data->is_buf_16u)
    {
        // sample indices live in row workVarCount of each node's sub-buffer.
        unsigned short* ldst = (unsigned short*)(buf->data.s + left->buf_idx*length_buf_row +
            workVarCount*scount + left->offset);
        unsigned short* rdst = (unsigned short*)(buf->data.s + right->buf_idx*length_buf_row +
            workVarCount*scount + right->offset);
        for (int i = 0; i < n; i++)
        {
            unsigned short idx = (unsigned short)tempBuf[i];
            if (dir[i])
            {
                *rdst = idx;
                rdst++;
            }
            else
            {
                *ldst = idx;
                ldst++;
            }
        }
    }
    else
    {
        int* ldst = buf->data.i + left->buf_idx*length_buf_row +
            workVarCount*scount + left->offset;
        int* rdst = buf->data.i + right->buf_idx*length_buf_row +
            workVarCount*scount + right->offset;
        for (int i = 0; i < n; i++)
        {
            int idx = tempBuf[i];
            if (dir[i])
            {
                *rdst = idx;
                rdst++;
            }
            else
            {
                *ldst = idx;
                ldst++;
            }
        }
    }

    // All samples are valid for every variable in this data layout, so the
    // children's valid counts are simply their populations.
    for( int vi = 0; vi < data->var_count; vi++ )
    {
        left->set_num_valid(vi, (int)(nl));
        right->set_num_valid(vi, (int)(nr));
    }

    // deallocate the parent node data that is not needed anymore
    data->free_node_data(node);
}
| 1253 | + | |
| 1254 | +static void auxMarkFeaturesInMap( const CvDTreeNode* node, Mat& featureMap) | |
| 1255 | +{ | |
| 1256 | + if ( node && node->split ) | |
| 1257 | + { | |
| 1258 | + featureMap.ptr<int>(0)[node->split->var_idx] = 1; | |
| 1259 | + auxMarkFeaturesInMap( node->left, featureMap ); | |
| 1260 | + auxMarkFeaturesInMap( node->right, featureMap ); | |
| 1261 | + } | |
| 1262 | +} | |
| 1263 | + | |
// Mark (set to 1) every feature index used by a split in this tree in the
// first row of featureMap — presumably used to build a compact feature
// subset when saving the cascade; confirm against the caller.
void CascadeBoostTree::markFeaturesInMap( Mat& featureMap )
{
    auxMarkFeaturesInMap( root, featureMap );
}
| 1268 | + | |
| 1269 | +//----------------------------------- CascadeBoost -------------------------------------- | |
| 1270 | + | |
// Train one boosted stage: grow weak trees until the stage reaches the
// desired hit-rate/false-alarm trade-off (isErrDesired) or the weak-count
// limit, updating sample weights after each tree.
//
// @return true if at least one weak classifier was trained.
bool CascadeBoost::train( const FeatureEvaluator* _featureEvaluator,
                          int _numSamples,
                          int _precalcValBufSize, int _precalcIdxBufSize,
                          const CascadeBoostParams& _params )
{
    bool isTrained = false;
    CV_Assert( !data );
    clear();

    data = new CascadeBoostTrainData( _featureEvaluator, _numSamples,
                                      _precalcValBufSize, _precalcIdxBufSize, _params );

    // The sequence takes ownership of the storage; drop our local pointer so
    // it cannot be released twice.
    CvMemStorage *storage = cvCreateMemStorage();
    weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage );
    storage = 0;

    set_params( _params );
    // LogitBoost / Gentle AdaBoost fit regression trees against transformed
    // responses, so they need a writable copy of the response vector.
    if ( (_params.boost_type == LOGIT) || (_params.boost_type == GENTLE) )
        data->do_responses_copy();

    // tree == 0 means: initialize weights and per-sample bookkeeping.
    update_weights( 0 );

    cout << "+----+---------+---------+" << endl;
    cout << "| N | HR | FA |" << endl;
    cout << "+----+---------+---------+" << endl;

    do
    {
        CascadeBoostTree* tree = new CascadeBoostTree;
        if( !tree->train( data, subsample_mask, this ) )
        {
            delete tree;
            break;
        }

        cvSeqPush( weak, &tree );
        update_weights( tree );
        trim_weights();
        // If weight trimming deactivated every sample, no further tree can
        // be trained.
        if( cvCountNonZero(subsample_mask) == 0 )
            break;
    }
    while( !isErrDesired() && (weak->total < params.weak_count) );

    if(weak->total > 0)
    {
        data->is_classifier = true;
        data->free_train_data();
        isTrained = true;
    }
    else
        clear();

    return isTrained;
}
| 1325 | + | |
| 1326 | +float CascadeBoost::predict( int sampleIdx, bool returnSum ) const | |
| 1327 | +{ | |
| 1328 | + CV_Assert( weak ); | |
| 1329 | + double sum = 0; | |
| 1330 | + CvSeqReader reader; | |
| 1331 | + cvStartReadSeq( weak, &reader ); | |
| 1332 | + cvSetSeqReaderPos( &reader, 0 ); | |
| 1333 | + for( int i = 0; i < weak->total; i++ ) | |
| 1334 | + { | |
| 1335 | + CvBoostTree* wtree; | |
| 1336 | + CV_READ_SEQ_ELEM( wtree, reader ); | |
| 1337 | + sum += ((CascadeBoostTree*)wtree)->predict(sampleIdx)->value; | |
| 1338 | + } | |
| 1339 | + if( !returnSum ) | |
| 1340 | + sum = sum < threshold - CV_THRESHOLD_EPS ? 0.0 : 1.0; | |
| 1341 | + return (float)sum; | |
| 1342 | +} | |
| 1343 | + | |
| 1344 | +bool CascadeBoost::set_params( const CvBoostParams& _params ) | |
| 1345 | +{ | |
| 1346 | + minHitRate = ((CascadeBoostParams&)_params).minHitRate; | |
| 1347 | + maxFalseAlarm = ((CascadeBoostParams&)_params).maxFalseAlarm; | |
| 1348 | + return ( ( minHitRate > 0 ) && ( minHitRate < 1) && | |
| 1349 | + ( maxFalseAlarm > 0 ) && ( maxFalseAlarm < 1) && | |
| 1350 | + CvBoost::set_params( _params )); | |
| 1351 | +} | |
| 1352 | + | |
// Update per-sample boosting weights after training weak classifier `tree`.
// Called with tree == 0 once before the first tree to initialize weights,
// response encoding ({0,1} -> {-1,1}), the subsample mask, and the label row.
// Afterwards, applies the weight-update rule of the configured boosting
// variant (Discrete, Real, Logit, Gentle) and renormalizes weights to sum 1.
void CascadeBoost::update_weights( CvBoostTree* tree )
{
    int n = data->sample_count;
    double sumW = 0.;
    int step = 0;
    float* fdata = 0;
    int *sampleIdxBuf;
    const int* sampleIdx = 0;
    // Scratch sized for the pieces actually needed below: sample indices for
    // LOGIT/GENTLE, plus class labels on the initialization call.
    int inn_buf_size = ((params.boost_type == LOGIT) || (params.boost_type == GENTLE) ? n*sizeof(int) : 0) +
                       ( !tree ? n*sizeof(int) : 0 );
    cv::AutoBuffer<uchar> inn_buf(inn_buf_size);
    uchar* cur_inn_buf_pos = (uchar*)inn_buf;
    if ( (params.boost_type == LOGIT) || (params.boost_type == GENTLE) )
    {
        // These variants rewrite the (copied) responses in place; fdata/step
        // address that copy.
        step = CV_IS_MAT_CONT(data->responses_copy->type) ?
            1 : data->responses_copy->step / CV_ELEM_SIZE(data->responses_copy->type);
        fdata = data->responses_copy->data.fl;
        sampleIdxBuf = (int*)cur_inn_buf_pos; cur_inn_buf_pos = (uchar*)(sampleIdxBuf + n);
        sampleIdx = data->get_sample_indices( data->data_root, sampleIdxBuf );
    }
    CvMat* buf = data->buf;
    size_t length_buf_row = data->get_length_subbuf();
    if( !tree ) // before training the first tree, initialize weights and other parameters
    {
        int* classLabelsBuf = (int*)cur_inn_buf_pos; cur_inn_buf_pos = (uchar*)(classLabelsBuf + n);
        const int* classLabels = data->get_class_labels(data->data_root, classLabelsBuf);
        // in case of logitboost and gentle adaboost each weak tree is a regression tree,
        // so we need to convert class labels to floating-point values
        double w0 = 1./n;
        double p[2] = { 1, 1 };

        cvReleaseMat( &orig_response );
        cvReleaseMat( &sum_response );
        cvReleaseMat( &weak_eval );
        cvReleaseMat( &subsample_mask );
        cvReleaseMat( &weights );

        orig_response = cvCreateMat( 1, n, CV_32S );
        weak_eval = cvCreateMat( 1, n, CV_64F );
        subsample_mask = cvCreateMat( 1, n, CV_8U );
        weights = cvCreateMat( 1, n, CV_64F );
        subtree_weights = cvCreateMat( 1, n + 2, CV_64F );

        if (data->is_buf_16u)
        {
            unsigned short* labels = (unsigned short*)(buf->data.s + data->data_root->buf_idx*length_buf_row +
                data->data_root->offset + (data->work_var_count-1)*data->sample_count);
            for( int i = 0; i < n; i++ )
            {
                // save original categorical responses {0,1}, convert them to {-1,1}
                orig_response->data.i[i] = classLabels[i]*2 - 1;
                // make all the samples active at start.
                // later, in trim_weights() deactivate/reactive again some, if need
                subsample_mask->data.ptr[i] = (uchar)1;
                // make all the initial weights the same.
                weights->data.db[i] = w0*p[classLabels[i]];
                // set the labels to find (from within weak tree learning proc)
                // the particular sample weight, and where to store the response.
                labels[i] = (unsigned short)i;
            }
        }
        else
        {
            int* labels = buf->data.i + data->data_root->buf_idx*length_buf_row +
                data->data_root->offset + (data->work_var_count-1)*data->sample_count;

            for( int i = 0; i < n; i++ )
            {
                // save original categorical responses {0,1}, convert them to {-1,1}
                orig_response->data.i[i] = classLabels[i]*2 - 1;
                subsample_mask->data.ptr[i] = (uchar)1;
                weights->data.db[i] = w0*p[classLabels[i]];
                labels[i] = i;
            }
        }

        if( params.boost_type == LOGIT )
        {
            sum_response = cvCreateMat( 1, n, CV_64F );

            for( int i = 0; i < n; i++ )
            {
                sum_response->data.db[i] = 0;
                // Initial regression targets: +/-2 depending on class.
                fdata[sampleIdx[i]*step] = orig_response->data.i[i] > 0 ? 2.f : -2.f;
            }

            // in case of logitboost each weak tree is a regression tree.
            // the target function values are recalculated for each of the trees
            data->is_classifier = false;
        }
        else if( params.boost_type == GENTLE )
        {
            for( int i = 0; i < n; i++ )
                fdata[sampleIdx[i]*step] = (float)orig_response->data.i[i];

            data->is_classifier = false;
        }
    }
    else
    {
        // at this moment, for all the samples that participated in the training of the most
        // recent weak classifier we know the responses. For other samples we need to compute them
        if( have_subsample )
        {
            // invert the subsample mask
            cvXorS( subsample_mask, cvScalar(1.), subsample_mask );

            // run tree through all the non-processed samples
            for( int i = 0; i < n; i++ )
                if( subsample_mask->data.ptr[i] )
                {
                    weak_eval->data.db[i] = ((CascadeBoostTree*)tree)->predict( i )->value;
                }
        }

        // now update weights and other parameters for each type of boosting
        if( params.boost_type == DISCRETE )
        {
            // Discrete AdaBoost:
            //   weak_eval[i] (=f(x_i)) is in {-1,1}
            //   err = sum(w_i*(f(x_i) != y_i))/sum(w_i)
            //   C = log((1-err)/err)
            //   w_i *= exp(C*(f(x_i) != y_i))

            double C, err = 0.;
            double scale[] = { 1., 0. };

            for( int i = 0; i < n; i++ )
            {
                double w = weights->data.db[i];
                sumW += w;
                err += w*(weak_eval->data.db[i] != orig_response->data.i[i]);
            }

            if( sumW != 0 )
                err /= sumW;
            C = err = -logRatio( err );
            scale[1] = exp(err);

            sumW = 0;
            for( int i = 0; i < n; i++ )
            {
                // scale[0]=1 for correct samples, scale[1]=e^C for mistakes.
                double w = weights->data.db[i]*
                    scale[weak_eval->data.db[i] != orig_response->data.i[i]];
                sumW += w;
                weights->data.db[i] = w;
            }

            tree->scale( C );
        }
        else if( params.boost_type == REAL )
        {
            // Real AdaBoost:
            //   weak_eval[i] = f(x_i) = 0.5*log(p(x_i)/(1-p(x_i))), p(x_i)=P(y=1|x_i)
            //   w_i *= exp(-y_i*f(x_i))

            for( int i = 0; i < n; i++ )
                weak_eval->data.db[i] *= -orig_response->data.i[i];

            cvExp( weak_eval, weak_eval );

            for( int i = 0; i < n; i++ )
            {
                double w = weights->data.db[i]*weak_eval->data.db[i];
                sumW += w;
                weights->data.db[i] = w;
            }
        }
        else if( params.boost_type == LOGIT )
        {
            // LogitBoost:
            //   weak_eval[i] = f(x_i) in [-z_max,z_max]
            //   sum_response = F(x_i).
            //   F(x_i) += 0.5*f(x_i)
            //   p(x_i) = exp(F(x_i))/(exp(F(x_i)) + exp(-F(x_i))=1/(1+exp(-2*F(x_i)))
            //   reuse weak_eval: weak_eval[i] <- p(x_i)
            //   w_i = p(x_i)*1(1 - p(x_i))
            //   z_i = ((y_i+1)/2 - p(x_i))/(p(x_i)*(1 - p(x_i)))
            //   store z_i to the data->data_root as the new target responses

            const double lbWeightThresh = FLT_EPSILON;
            const double lbZMax = 10.;

            for( int i = 0; i < n; i++ )
            {
                double s = sum_response->data.db[i] + 0.5*weak_eval->data.db[i];
                sum_response->data.db[i] = s;
                weak_eval->data.db[i] = -2*s;
            }

            cvExp( weak_eval, weak_eval );

            for( int i = 0; i < n; i++ )
            {
                double p = 1./(1. + weak_eval->data.db[i]);
                double w = p*(1 - p), z;
                // Clamp tiny weights to keep the z targets finite.
                w = MAX( w, lbWeightThresh );
                weights->data.db[i] = w;
                sumW += w;
                if( orig_response->data.i[i] > 0 )
                {
                    z = 1./p;
                    fdata[sampleIdx[i]*step] = (float)min(z, lbZMax);
                }
                else
                {
                    z = 1./(1-p);
                    fdata[sampleIdx[i]*step] = (float)-min(z, lbZMax);
                }
            }
        }
        else
        {
            // Gentle AdaBoost:
            //   weak_eval[i] = f(x_i) in [-1,1]
            //   w_i *= exp(-y_i*f(x_i))
            assert( params.boost_type == GENTLE );

            for( int i = 0; i < n; i++ )
                weak_eval->data.db[i] *= -orig_response->data.i[i];

            cvExp( weak_eval, weak_eval );

            for( int i = 0; i < n; i++ )
            {
                double w = weights->data.db[i] * weak_eval->data.db[i];
                weights->data.db[i] = w;
                sumW += w;
            }
        }
    }

    // renormalize weights
    if( sumW > FLT_EPSILON )
    {
        sumW = 1./sumW;
        for( int i = 0; i < n; ++i )
            weights->data.db[i] *= sumW;
    }
}
| 1593 | + | |
// Checks whether the stage trained so far meets the target false-alarm rate
// while keeping at least minHitRate of the positive samples. As a side
// effect, sets `threshold` — the stage decision threshold on the raw
// boosted sum. Prints one row of the training progress table.
bool CascadeBoost::isErrDesired()
{
    int sCount = data->sample_count,
        numPos = 0, numNeg = 0, numFalse = 0, numPosTrue = 0;
    vector<float> eval(sCount);

    // Collect the raw boosted sums (predict with returnSum=true) for every
    // positive sample.
    for( int i = 0; i < sCount; i++ )
        if( ((CascadeBoostTrainData*)data)->featureEvaluator->getCls( i ) == 1.0F )
            eval[numPos++] = predict( i, true );
    icvSortFlt( &eval[0], numPos, 0 );  // ascending sort of positive responses
    // Choose the threshold so that at most (1 - minHitRate) of the positives
    // fall strictly below it.
    int thresholdIdx = (int)((1.0F - minHitRate) * numPos);
    threshold = eval[ thresholdIdx ];
    numPosTrue = numPos - thresholdIdx;
    // Positives just below the cut whose response ties the threshold (within
    // FLT_EPSILON) still pass, so count them as true positives too.
    for( int i = thresholdIdx - 1; i >= 0; i--)
        if ( abs( eval[i] - threshold) < FLT_EPSILON )
            numPosTrue++;
    float hitRate = ((float) numPosTrue) / ((float) numPos);

    // Count the negatives the stage still accepts (predict uses the
    // threshold set above).
    for( int i = 0; i < sCount; i++ )
    {
        if( ((CascadeBoostTrainData*)data)->featureEvaluator->getCls( i ) == 0.0F )
        {
            numNeg++;
            if( predict( i ) )
                numFalse++;
        }
    }
    float falseAlarm = ((float) numFalse) / ((float) numNeg);

    // Progress row: weak-classifier count | hit rate | false-alarm rate.
    cout << "|"; cout.width(4); cout << right << weak->total;
    cout << "|"; cout.width(9); cout << right << hitRate;
    cout << "|"; cout.width(9); cout << right << falseAlarm;
    cout << "|" << endl;
    cout << "+----+---------+---------+" << endl;

    // The stage is finished once the false-alarm target is reached.
    return falseAlarm <= maxFalseAlarm;
}
| 1631 | + | |
| 1632 | +void CascadeBoost::write( FileStorage &fs, const Mat& featureMap ) const | |
| 1633 | +{ | |
| 1634 | +// char cmnt[30]; | |
| 1635 | + CascadeBoostTree* weakTree; | |
| 1636 | + fs << CC_WEAK_COUNT << weak->total; | |
| 1637 | + fs << CC_STAGE_THRESHOLD << threshold; | |
| 1638 | + fs << CC_WEAK_CLASSIFIERS << "["; | |
| 1639 | + for( int wi = 0; wi < weak->total; wi++) | |
| 1640 | + { | |
| 1641 | + /*sprintf( cmnt, "tree %i", wi ); | |
| 1642 | + cvWriteComment( fs, cmnt, 0 );*/ | |
| 1643 | + weakTree = *((CascadeBoostTree**) cvGetSeqElem( weak, wi )); | |
| 1644 | + weakTree->write( fs, featureMap ); | |
| 1645 | + } | |
| 1646 | + fs << "]"; | |
| 1647 | +} | |
| 1648 | + | |
| 1649 | +bool CascadeBoost::read( const FileNode &node, | |
| 1650 | + const FeatureEvaluator* _featureEvaluator, | |
| 1651 | + const CascadeBoostParams& _params ) | |
| 1652 | +{ | |
| 1653 | + CvMemStorage* storage; | |
| 1654 | + clear(); | |
| 1655 | + data = new CascadeBoostTrainData( _featureEvaluator, _params ); | |
| 1656 | + set_params( _params ); | |
| 1657 | + | |
| 1658 | + node[CC_STAGE_THRESHOLD] >> threshold; | |
| 1659 | + FileNode rnode = node[CC_WEAK_CLASSIFIERS]; | |
| 1660 | + | |
| 1661 | + storage = cvCreateMemStorage(); | |
| 1662 | + weak = cvCreateSeq( 0, sizeof(CvSeq), sizeof(CvBoostTree*), storage ); | |
| 1663 | + for( FileNodeIterator it = rnode.begin(); it != rnode.end(); it++ ) | |
| 1664 | + { | |
| 1665 | + CascadeBoostTree* tree = new CascadeBoostTree(); | |
| 1666 | + tree->read( *it, this, data ); | |
| 1667 | + cvSeqPush( weak, &tree ); | |
| 1668 | + } | |
| 1669 | + return true; | |
| 1670 | +} | |
| 1671 | + | |
| 1672 | +void CascadeBoost::markUsedFeaturesInMap( Mat& featureMap ) | |
| 1673 | +{ | |
| 1674 | + for( int wi = 0; wi < weak->total; wi++ ) | |
| 1675 | + { | |
| 1676 | + CascadeBoostTree* weakTree = *((CascadeBoostTree**) cvGetSeqElem( weak, wi )); | |
| 1677 | + weakTree->markFeaturesInMap( featureMap ); | |
| 1678 | + } | |
| 1679 | +} | |
| 1680 | + | ... | ... |
openbr/core/boost.h
0 → 100644
| 1 | +#ifndef _BOOST_H_ | |
| 2 | +#define _BOOST_H_ | |
| 3 | + | |
| 4 | +#include "features.h" | |
| 5 | +#include "ml.h" | |
| 6 | + | |
| 7 | +namespace br | |
| 8 | +{ | |
| 9 | + | |
// Boosting parameters for one cascade stage. Extends CvBoostParams with the
// per-stage quality targets used by the cascade trainer.
struct CascadeBoostParams : CvBoostParams
{
    float minHitRate;     // minimum fraction of positives a stage must accept
    float maxFalseAlarm;  // maximum fraction of negatives a stage may accept

    CascadeBoostParams();
    CascadeBoostParams( int _boostType, float _minHitRate, float _maxFalseAlarm,
                        double _weightTrimRate, int _maxDepth, int _maxWeakCount );
    virtual ~CascadeBoostParams() {}
    void write( cv::FileStorage &fs ) const;
    bool read( const cv::FileNode &node );
    virtual void printDefaults() const;
    virtual void printAttrs() const;
    // Parses a single "-name value" command-line attribute; returns false
    // when the attribute name is not recognized.
    virtual bool scanAttr( const std::string prmName, const std::string val);
};
| 25 | + | |
// Training-data adapter that serves feature values produced by a
// FeatureEvaluator to OpenCV's decision-tree machinery, instead of a plain
// data matrix. Feature values/sorted indices may be precalculated into
// caches bounded by the given buffer sizes.
struct CascadeBoostTrainData : CvDTreeTrainData
{
    CascadeBoostTrainData( const FeatureEvaluator* _featureEvaluator,
                           const CvDTreeParams& _params );
    CascadeBoostTrainData( const FeatureEvaluator* _featureEvaluator,
                           int _numSamples, int _precalcValBufSize, int _precalcIdxBufSize,
                           const CvDTreeParams& _params = CvDTreeParams() );
    // Re-binds this training data to an evaluator and sample count; buffer
    // sizes limit how much is precalculated.
    virtual void setData( const FeatureEvaluator* _featureEvaluator,
                          int _numSamples, int _precalcValBufSize, int _precalcIdxBufSize,
                          const CvDTreeParams& _params=CvDTreeParams() );
    // Fills the precalculated-value/index caches (see numPrecalcVal/Idx).
    void precalculate();

    virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );

    virtual const int* get_class_labels( CvDTreeNode* n, int* labelsBuf );
    virtual const int* get_cv_labels( CvDTreeNode* n, int* labelsBuf);
    virtual const int* get_sample_indices( CvDTreeNode* n, int* indicesBuf );

    virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ordValuesBuf, int* sortedIndicesBuf,
                                   const float** ordValues, const int** sortedIndices, int* sampleIndicesBuf );
    virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* catValuesBuf );
    virtual float getVarValue( int vi, int si );
    virtual void free_train_data();

    const FeatureEvaluator* featureEvaluator; // source of feature values (not owned)
    cv::Mat valCache;   // precalculated feature values (CV_32FC1)
    CvMat _resp;        // for casting
    // Number of features with precalculated values / sorted indices.
    int numPrecalcVal, numPrecalcIdx;
};
| 55 | + | |
// One weak boosted decision tree of a cascade stage. Reads sample feature
// values through CascadeBoostTrainData and can be (de)serialized with a
// feature-index remap.
class CascadeBoostTree : public CvBoostTree
{
public:
    // Runs sample `sampleIdx` down the tree and returns the reached node.
    virtual CvDTreeNode* predict( int sampleIdx ) const;
    void write( cv::FileStorage &fs, const cv::Mat& featureMap );
    void read( const cv::FileNode &node, CvBoost* _ensemble, CvDTreeTrainData* _data );
    // Marks in featureMap the features used by this tree's splits.
    void markFeaturesInMap( cv::Mat& featureMap );
protected:
    virtual void split_node_data( CvDTreeNode* n );
};
| 66 | + | |
// One boosted stage of the cascade: an ensemble of weak trees plus a stage
// decision threshold tuned so the stage keeps minHitRate of positives while
// reaching maxFalseAlarm on negatives.
class CascadeBoost : public CvBoost
{
public:
    virtual bool train( const FeatureEvaluator* _featureEvaluator,
                        int _numSamples, int _precalcValBufSize, int _precalcIdxBufSize,
                        const CascadeBoostParams& _params=CascadeBoostParams() );
    // Evaluates the stage on a sample; with returnSum=true the raw boosted
    // sum is returned instead of the thresholded decision.
    virtual float predict( int sampleIdx, bool returnSum = false ) const;

    float getThreshold() const { return threshold; }
    void write( cv::FileStorage &fs, const cv::Mat& featureMap ) const;
    bool read( const cv::FileNode &node, const FeatureEvaluator* _featureEvaluator,
               const CascadeBoostParams& _params );
    void markUsedFeaturesInMap( cv::Mat& featureMap );
protected:
    virtual bool set_params( const CvBoostParams& _params );
    virtual void update_weights( CvBoostTree* tree );
    // True when the stage meets its false-alarm target; also sets
    // `threshold` as a side effect.
    virtual bool isErrDesired();

    float threshold;                 // stage decision threshold on the boosted sum
    float minHitRate, maxFalseAlarm; // per-stage quality targets
};
| 88 | + | |
| 89 | +} // namespace br | |
| 90 | + | |
| 91 | +#endif | |
| 92 | + | ... | ... |
openbr/core/cascade.cpp
0 → 100644
| 1 | +#include "cascade.h" | |
| 2 | +#include <stdio.h> | |
| 3 | +#include <iostream> | |
| 4 | +#include <fstream> | |
| 5 | + | |
| 6 | +using namespace std; | |
| 7 | +using namespace br; | |
| 8 | +using namespace cv; | |
| 9 | + | |
| 10 | +bool CascadeImageReader::create( const string _posFilename, const string _negFilename, Size _winSize ) | |
| 11 | +{ | |
| 12 | + return posReader.create(_posFilename) && negReader.create(_negFilename, _winSize); | |
| 13 | +} | |
| 14 | + | |
| 15 | +CascadeImageReader::NegReader::NegReader() | |
| 16 | +{ | |
| 17 | + src.create( 0, 0 , CV_8UC1 ); | |
| 18 | + img.create( 0, 0, CV_8UC1 ); | |
| 19 | + point = offset = Point( 0, 0 ); | |
| 20 | + scale = 1.0F; | |
| 21 | + scaleFactor = 1.4142135623730950488016887242097F; | |
| 22 | + stepFactor = 0.5F; | |
| 23 | +} | |
| 24 | + | |
| 25 | +bool CascadeImageReader::NegReader::create( const string _filename, Size _winSize ) | |
| 26 | +{ | |
| 27 | + string dirname, str; | |
| 28 | + std::ifstream file(_filename.c_str()); | |
| 29 | + if ( !file.is_open() ) | |
| 30 | + return false; | |
| 31 | + | |
| 32 | + size_t pos = _filename.rfind('\\'); | |
| 33 | + char dlmrt = '\\'; | |
| 34 | + if (pos == string::npos) | |
| 35 | + { | |
| 36 | + pos = _filename.rfind('/'); | |
| 37 | + dlmrt = '/'; | |
| 38 | + } | |
| 39 | + dirname = pos == string::npos ? "" : _filename.substr(0, pos) + dlmrt; | |
| 40 | + while( !file.eof() ) | |
| 41 | + { | |
| 42 | + std::getline(file, str); | |
| 43 | + if (str.empty()) break; | |
| 44 | + if (str.at(0) == '#' ) continue; /* comment */ | |
| 45 | + imgFilenames.push_back(dirname + str); | |
| 46 | + } | |
| 47 | + file.close(); | |
| 48 | + | |
| 49 | + winSize = _winSize; | |
| 50 | + last = round = 0; | |
| 51 | + return true; | |
| 52 | +} | |
| 53 | + | |
// Advances to the next readable negative image, cycling through the file
// list at most once. `round` counts completed passes (modulo the window
// area) and selects a per-pass starting offset inside the image, so
// successive passes over the same image begin at different positions. On
// success the image is rescaled so the sampling window (at the chosen
// offset) covers it. Returns false if no image could be loaded.
bool CascadeImageReader::NegReader::nextImg()
{
    Point _offset = Point(0,0);
    size_t count = imgFilenames.size();
    for( size_t i = 0; i < count; i++ )
    {
        src = imread( imgFilenames[last++], 0 );  // 0 => force grayscale
        if( src.empty() )
            continue;
        // One full pass over the list bumps `round`; wrap both counters.
        round += last / count;
        round = round % (winSize.width * winSize.height);
        last %= count;

        // Starting offset derived from `round`, clamped to keep the window
        // inside the image; negative means the image is too small.
        _offset.x = std::min( (int)round % winSize.width, src.cols - winSize.width );
        _offset.y = std::min( (int)round / winSize.width, src.rows - winSize.height );
        if( !src.empty() && src.type() == CV_8UC1
                && _offset.x >= 0 && _offset.y >= 0 )
            break;
    }

    if( src.empty() )
        return false; // no appropriate image
    point = offset = _offset;
    // Smallest magnification at which the window plus its offset spans the
    // source image in both dimensions.
    scale = max( ((float)winSize.width + point.x) / ((float)src.cols),
                 ((float)winSize.height + point.y) / ((float)src.rows) );

    Size sz( (int)(scale*src.cols + 0.5F), (int)(scale*src.rows + 0.5F) );
    resize( src, img, sz );
    return true;
}
| 84 | + | |
// Copies the current window of the current negative image into _img, then
// advances the scan state for the next call: slide right by half a window,
// then down by half a window, then shrink the image by 1/scaleFactor, and
// finally move on to the next image when the pyramid is exhausted.
// Returns false when no more windows can be produced.
bool CascadeImageReader::NegReader::get( Mat& _img )
{
    CV_Assert( !_img.empty() );
    CV_Assert( _img.type() == CV_8UC1 );
    CV_Assert( _img.cols == winSize.width );
    CV_Assert( _img.rows == winSize.height );

    if( img.empty() )
        if ( !nextImg() )
            return false;

    // View onto the current window of the scaled image; copied out so the
    // caller owns the pixels.
    Mat mat( winSize.height, winSize.width, CV_8UC1,
             (void*)(img.data + point.y * img.step + point.x * img.elemSize()), img.step );
    mat.copyTo(_img);

    // Advance: next horizontal position if one more full window still fits.
    if( (int)( point.x + (1.0F + stepFactor ) * winSize.width ) < img.cols )
        point.x += (int)(stepFactor * winSize.width);
    else
    {
        // Row exhausted: rewind x, try the next row.
        point.x = offset.x;
        if( (int)( point.y + (1.0F + stepFactor ) * winSize.height ) < img.rows )
            point.y += (int)(stepFactor * winSize.height);
        else
        {
            // Image exhausted at this scale: rewind y and shrink.
            // NOTE: `scale` starts below 1 and grows toward 1 here, since it
            // was initialized as the minimal covering magnification.
            point.y = offset.y;
            scale *= scaleFactor;
            if( scale <= 1.0F )
                resize( src, img, Size( (int)(scale*src.cols), (int)(scale*src.rows) ) );
            else
            {
                // Pyramid exhausted: move to the next list entry.
                if ( !nextImg() )
                    return false;
            }
        }
    }
    return true;
}
| 122 | + | |
| 123 | +CascadeImageReader::PosReader::PosReader() | |
| 124 | +{ | |
| 125 | + file = 0; | |
| 126 | + vec = 0; | |
| 127 | +} | |
| 128 | + | |
| 129 | +bool CascadeImageReader::PosReader::create( const string _filename ) | |
| 130 | +{ | |
| 131 | + if ( file ) | |
| 132 | + fclose( file ); | |
| 133 | + file = fopen( _filename.c_str(), "rb" ); | |
| 134 | + | |
| 135 | + if( !file ) | |
| 136 | + return false; | |
| 137 | + short tmp = 0; | |
| 138 | + if( fread( &count, sizeof( count ), 1, file ) != 1 || | |
| 139 | + fread( &vecSize, sizeof( vecSize ), 1, file ) != 1 || | |
| 140 | + fread( &tmp, sizeof( tmp ), 1, file ) != 1 || | |
| 141 | + fread( &tmp, sizeof( tmp ), 1, file ) != 1 ) | |
| 142 | + CV_Error_( CV_StsParseError, ("wrong file format for %s\n", _filename.c_str()) ); | |
| 143 | + base = sizeof( count ) + sizeof( vecSize ) + 2*sizeof( tmp ); | |
| 144 | + if( feof( file ) ) | |
| 145 | + return false; | |
| 146 | + last = 0; | |
| 147 | + vec = (short*) cvAlloc( sizeof( *vec ) * vecSize ); | |
| 148 | + CV_Assert( vec ); | |
| 149 | + return true; | |
| 150 | +} | |
| 151 | + | |
// Decodes the next positive sample from the vec-file into _img: skips the
// one-byte record marker, reads vecSize 16-bit values, and truncates each
// to a uchar pixel. Raises a CV error if the file is exhausted or the
// record is short.
bool CascadeImageReader::PosReader::get( Mat &_img )
{
    CV_Assert( _img.rows * _img.cols == vecSize );
    uchar tmp = 0;
    // Each sample record begins with a single marker byte.
    size_t elements_read = fread( &tmp, sizeof( tmp ), 1, file );
    if( elements_read != 1 )
        CV_Error( CV_StsBadArg, "Can not get new positive sample. The most possible reason is "
                                "insufficient count of samples in given vec-file.\n");
    elements_read = fread( vec, sizeof( vec[0] ), vecSize, file );
    if( elements_read != (size_t)(vecSize) )
        CV_Error( CV_StsBadArg, "Can not get new positive sample. Seems that vec-file has incorrect structure.\n");

    // `last` tracks how many samples have been consumed since restart().
    if( feof( file ) || last++ >= count )
        CV_Error( CV_StsBadArg, "Can not get new positive sample. vec-file is over.\n");

    // Pixels are stored as shorts in the file; narrow to 8-bit grayscale.
    for( int r = 0; r < _img.rows; r++ )
    {
        for( int c = 0; c < _img.cols; c++ )
            _img.ptr(r)[c] = (uchar)vec[r * _img.cols + c];
    }
    return true;
}
| 174 | + | |
// Rewinds the vec-file to the first sample record (just past the header)
// so positives can be consumed again for the next stage.
void CascadeImageReader::PosReader::restart()
{
    CV_Assert( file );
    last = 0;
    fseek( file, base, SEEK_SET );
}
| 181 | + | |
// Closes the vec-file (if open) and releases the sample buffer.
CascadeImageReader::PosReader::~PosReader()
{
    if (file)
        fclose( file );
    cvFree( &vec );
}
| 188 | + | |
// -------------------------------------- Cascade --------------------------------------------

// Printable names for the stage/feature type enums; indexed by the numeric
// type values stored in CascadeParams.
static const char* stageTypes[] = { CC_BOOST };
static const char* featureTypes[] = { CC_LBP, CC_HAAR, CC_HOG, CC_HOGMULTI, CC_NPD };
| 193 | + | |
// Default construction: default stage/feature types on a 24x24 window.
CascadeParams::CascadeParams() : stageType( defaultStageType ),
    featureType( defaultFeatureType ), winSize( cvSize(24, 24) )
{
    name = CC_CASCADE_PARAMS;
}
// Explicit stage/feature type selection; window size still defaults to 24x24.
CascadeParams::CascadeParams( int _stageType, int _featureType ) : stageType( _stageType ),
    featureType( _featureType ), winSize( cvSize(24, 24) )
{
    name = CC_CASCADE_PARAMS;
}
| 204 | + | |
| 205 | +//---------------------------- CascadeParams -------------------------------------- | |
| 206 | + | |
| 207 | +void CascadeParams::write( FileStorage &fs ) const | |
| 208 | +{ | |
| 209 | + string stageTypeStr = stageType == BOOST ? CC_BOOST : string(); | |
| 210 | + CV_Assert( !stageTypeStr.empty() ); | |
| 211 | + fs << CC_STAGE_TYPE << stageTypeStr; | |
| 212 | + string featureTypeStr = featureType == FeatureParams::LBP ? CC_HAAR : | |
| 213 | + 0; | |
| 214 | + CV_Assert( !stageTypeStr.empty() ); | |
| 215 | + fs << CC_FEATURE_TYPE << featureTypeStr; | |
| 216 | + fs << CC_HEIGHT << winSize.height; | |
| 217 | + fs << CC_WIDTH << winSize.width; | |
| 218 | +} | |
| 219 | + | |
| 220 | +bool CascadeParams::read( const FileNode &node ) | |
| 221 | +{ | |
| 222 | + if ( node.empty() ) | |
| 223 | + return false; | |
| 224 | + string stageTypeStr, featureTypeStr; | |
| 225 | + FileNode rnode = node[CC_STAGE_TYPE]; | |
| 226 | + if ( !rnode.isString() ) | |
| 227 | + return false; | |
| 228 | + rnode >> stageTypeStr; | |
| 229 | + stageType = !stageTypeStr.compare( CC_BOOST ) ? BOOST : -1; | |
| 230 | + if (stageType == -1) | |
| 231 | + return false; | |
| 232 | + rnode = node[CC_FEATURE_TYPE]; | |
| 233 | + if ( !rnode.isString() ) | |
| 234 | + return false; | |
| 235 | + rnode >> featureTypeStr; | |
| 236 | + featureType = !featureTypeStr.compare( CC_LBP ) ? FeatureParams::LBP : | |
| 237 | + -1; | |
| 238 | + if (featureType == -1) | |
| 239 | + return false; | |
| 240 | + node[CC_HEIGHT] >> winSize.height; | |
| 241 | + node[CC_WIDTH] >> winSize.width; | |
| 242 | + return winSize.height > 0 && winSize.width > 0; | |
| 243 | +} | |
| 244 | + | |
| 245 | +void CascadeParams::printDefaults() const | |
| 246 | +{ | |
| 247 | + Params::printDefaults(); | |
| 248 | + cout << " [-stageType <"; | |
| 249 | + for( int i = 0; i < (int)(sizeof(stageTypes)/sizeof(stageTypes[0])); i++ ) | |
| 250 | + { | |
| 251 | + cout << (i ? " | " : "") << stageTypes[i]; | |
| 252 | + if ( i == defaultStageType ) | |
| 253 | + cout << "(default)"; | |
| 254 | + } | |
| 255 | + cout << ">]" << endl; | |
| 256 | + | |
| 257 | + cout << " [-featureType <{"; | |
| 258 | + for( int i = 0; i < (int)(sizeof(featureTypes)/sizeof(featureTypes[0])); i++ ) | |
| 259 | + { | |
| 260 | + cout << (i ? ", " : "") << featureTypes[i]; | |
| 261 | + if ( i == defaultStageType ) | |
| 262 | + cout << "(default)"; | |
| 263 | + } | |
| 264 | + cout << "}>]" << endl; | |
| 265 | + cout << " [-w <sampleWidth = " << winSize.width << ">]" << endl; | |
| 266 | + cout << " [-h <sampleHeight = " << winSize.height << ">]" << endl; | |
| 267 | +} | |
| 268 | + | |
| 269 | +void CascadeParams::printAttrs() const | |
| 270 | +{ | |
| 271 | + cout << "stageType: " << stageTypes[stageType] << endl; | |
| 272 | + cout << "featureType: " << featureTypes[featureType] << endl; | |
| 273 | + cout << "sampleWidth: " << winSize.width << endl; | |
| 274 | + cout << "sampleHeight: " << winSize.height << endl; | |
| 275 | +} | |
| 276 | + | |
| 277 | +bool CascadeParams::scanAttr( const string prmName, const string val ) | |
| 278 | +{ | |
| 279 | + bool res = true; | |
| 280 | + if( !prmName.compare( "-stageType" ) ) | |
| 281 | + { | |
| 282 | + for( int i = 0; i < (int)(sizeof(stageTypes)/sizeof(stageTypes[0])); i++ ) | |
| 283 | + if( !val.compare( stageTypes[i] ) ) | |
| 284 | + stageType = i; | |
| 285 | + } | |
| 286 | + else if( !prmName.compare( "-featureType" ) ) | |
| 287 | + { | |
| 288 | + for( int i = 0; i < (int)(sizeof(featureTypes)/sizeof(featureTypes[0])); i++ ) | |
| 289 | + if( !val.compare( featureTypes[i] ) ) | |
| 290 | + featureType = i; | |
| 291 | + } | |
| 292 | + else if( !prmName.compare( "-w" ) ) | |
| 293 | + { | |
| 294 | + winSize.width = atoi( val.c_str() ); | |
| 295 | + } | |
| 296 | + else if( !prmName.compare( "-h" ) ) | |
| 297 | + { | |
| 298 | + winSize.height = atoi( val.c_str() ); | |
| 299 | + } | |
| 300 | + else | |
| 301 | + res = false; | |
| 302 | + return res; | |
| 303 | +} | |
| 304 | + | |
| 305 | +//---------------------------- CascadeClassifier -------------------------------------- | |
| 306 | + | |
// Trains the whole cascade: resumes from any stages already saved in
// _cascadeDirName, then trains and saves one boosted stage at a time until
// numStages is reached, the negative pool is exhausted, or the overall
// false-alarm target is met. Returns false when no stage could be trained
// or a file cannot be written.
bool BrCascadeClassifier::train( const string _cascadeDirName,
                const string _posFilename,
                const string _negFilename,
                int _numPos, int _numNeg,
                int _precalcValBufSize, int _precalcIdxBufSize,
                int _numStages,
                const CascadeParams& _cascadeParams,
                const FeatureParams& _featureParams,
                const CascadeBoostParams& _stageParams,
                bool baseFormatSave )
{
    // Start recording clock ticks for training time output
    const clock_t begin_time = clock();

    if( _cascadeDirName.empty() || _posFilename.empty() || _negFilename.empty() )
        CV_Error( CV_StsBadArg, "_cascadeDirName or _bgfileName or _vecFileName is NULL" );

    // Normalize the output directory to end with a separator.
    string dirName;
    if (_cascadeDirName.find_last_of("/\\") == (_cascadeDirName.length() - 1) )
        dirName = _cascadeDirName;
    else
        dirName = _cascadeDirName + '/';

    numPos = _numPos;
    numNeg = _numNeg;
    numStages = _numStages;
    if ( !imgReader.create( _posFilename, _negFilename, _cascadeParams.winSize ) )
    {
        cout << "Image reader can not be created from -vec " << _posFilename
                << " and -bg " << _negFilename << "." << endl;
        return false;
    }
    // load() restores params and previously trained stages from dirName;
    // only when that fails do we initialize a fresh training session.
    if ( !load( dirName ) )
    {
        cascadeParams = _cascadeParams;
        featureParams = FeatureParams::create(cascadeParams.featureType);
        featureParams->init(_featureParams);
        stageParams = new CascadeBoostParams;
        *stageParams = _stageParams;
        featureEvaluator = FeatureEvaluator::create(cascadeParams.featureType);
        featureEvaluator->init( (FeatureParams*)featureParams, numPos + numNeg, cascadeParams.winSize );
        stageClassifiers.reserve( numStages );
    }
    cout << "PARAMETERS:" << endl;
    cout << "cascadeDirName: " << _cascadeDirName << endl;
    cout << "vecFileName: " << _posFilename << endl;
    cout << "bgFileName: " << _negFilename << endl;
    cout << "numPos: " << _numPos << endl;
    cout << "numNeg: " << _numNeg << endl;
    cout << "numStages: " << numStages << endl;
    cout << "precalcValBufSize[Mb] : " << _precalcValBufSize << endl;
    cout << "precalcIdxBufSize[Mb] : " << _precalcIdxBufSize << endl;
    cascadeParams.printAttrs();
    stageParams->printAttrs();
    featureParams->printAttrs();

    int startNumStages = (int)stageClassifiers.size();
    if ( startNumStages > 1 )
        cout << endl << "Stages 0-" << startNumStages-1 << " are loaded" << endl;
    else if ( startNumStages == 1)
        cout << endl << "Stage 0 is loaded" << endl;

    // Overall false-alarm target: per-stage rate compounded over all stages,
    // divided by max_depth. (NOTE(review): the max_depth divisor is unusual —
    // confirm against the upstream traincascade formula.)
    double requiredLeafFARate = pow( (double) stageParams->maxFalseAlarm, (double) numStages ) /
                                (double)stageParams->max_depth;
    double tempLeafFARate;

    for( int i = startNumStages; i < numStages; i++ )
    {
        cout << endl << "===== TRAINING " << i << "-stage =====" << endl;
        cout << "<BEGIN" << endl;
        // Refill the sample pool with positives/negatives that all previous
        // stages still accept; tempLeafFARate is the negative acceptance
        // ratio of the current cascade.
        if ( !updateTrainingSet( tempLeafFARate ) )
        {
            cout << "Train dataset for temp stage can not be filled. "
                "Branch training terminated." << endl;
            break;
        }
        if( tempLeafFARate <= requiredLeafFARate )
        {
            cout << "Required leaf false alarm rate achieved. "
                 "Branch training terminated." << endl;
            break;
        }

        CascadeBoost* tempStage = new CascadeBoost;
        bool isStageTrained = tempStage->train( (FeatureEvaluator*)featureEvaluator,
                                                curNumSamples, _precalcValBufSize, _precalcIdxBufSize,
                                                *((CascadeBoostParams*)stageParams) );
        cout << "END>" << endl;

        if(!isStageTrained)
            break;

        stageClassifiers.push_back( tempStage );

        // save params — written once, alongside the first trained stage.
        if( i == 0)
        {
            std::string paramsFilename = dirName + CC_PARAMS_FILENAME;
            FileStorage fs( paramsFilename, FileStorage::WRITE);
            if ( !fs.isOpened() )
            {
                cout << "Parameters can not be written, because file " << paramsFilename
                    << " can not be opened." << endl;
                return false;
            }
            fs << FileStorage::getDefaultObjectName(paramsFilename) << "{";
            writeParams( fs );
            fs << "}";
        }
        // save current stage to stage<i>.xml so training can be resumed.
        char buf[10];
        sprintf(buf, "%s%d", "stage", i );
        string stageFilename = dirName + buf + ".xml";
        FileStorage fs( stageFilename, FileStorage::WRITE );
        if ( !fs.isOpened() )
        {
            cout << "Current stage can not be written, because file " << stageFilename
                << " can not be opened." << endl;
            return false;
        }
        fs << FileStorage::getDefaultObjectName(stageFilename) << "{";
        tempStage->write( fs, Mat() );
        fs << "}";

        // Output training time up till now
        float seconds = float( clock () - begin_time ) / CLOCKS_PER_SEC;
        int days = int(seconds) / 60 / 60 / 24;
        int hours = (int(seconds) / 60 / 60) % 24;
        int minutes = (int(seconds) / 60) % 60;
        int seconds_left = int(seconds) % 60;
        cout << "Training until now has taken " << days << " days " << hours << " hours " << minutes << " minutes " << seconds_left <<" seconds." << endl;
    }

    if(stageClassifiers.size() == 0)
    {
        cout << "Cascade classifier can't be trained. Check the used training parameters." << endl;
        return false;
    }

    // Write the combined cascade file containing all stages.
    save( dirName + CC_CASCADE_FILENAME, baseFormatSave );

    return true;
}
| 450 | + | |
| 451 | +int BrCascadeClassifier::predict( int sampleIdx ) | |
| 452 | +{ | |
| 453 | + CV_DbgAssert( sampleIdx < numPos + numNeg ); | |
| 454 | + for (vector< Ptr<CascadeBoost> >::iterator it = stageClassifiers.begin(); | |
| 455 | + it != stageClassifiers.end(); it++ ) | |
| 456 | + { | |
| 457 | + if ( (*it)->predict( sampleIdx ) == 0.f ) | |
| 458 | + return 0; | |
| 459 | + } | |
| 460 | + return 1; | |
| 461 | +} | |
| 462 | + | |
// Refills the training set with samples that all already-trained stages
// accept: positives first, then a proportional number of negatives.
// Outputs (via acceptanceRatio) the fraction of scanned negatives that
// passed the cascade — the current cascade's false-alarm estimate.
// Returns false if either class could not be filled at all.
bool BrCascadeClassifier::updateTrainingSet( double& acceptanceRatio)
{
    int64 posConsumed = 0, negConsumed = 0;
    imgReader.restart();  // rewind the positive vec-file

    int posCount = fillPassedSamples( 0, numPos, true, posConsumed );
    if( !posCount )
        return false;
    cout << "POS count : consumed   " << posCount << " : " << (int)posConsumed << endl;

    int proNumNeg = cvRound( ( ((double)numNeg) * ((double)posCount) ) / numPos ); // apply only a fraction of negative samples. double is required since overflow is possible
    int negCount = fillPassedSamples( posCount, proNumNeg, false, negConsumed );
    if ( !negCount )
        return false;

    curNumSamples = posCount + negCount;
    acceptanceRatio = negConsumed == 0 ? 0 : ( (double)negCount/(double)(int64)negConsumed );
    cout << "NEG count : acceptanceRatio    " << negCount << " : " << acceptanceRatio << endl;
    return true;
}
| 483 | + | |
// Fills evaluator slots [first, first+count) with samples of the requested
// class that every trained stage accepts. `consumed` accumulates how many
// raw samples were drawn from the reader (accepted or not). Returns the
// number of slots actually filled; stops early when the reader runs dry.
int BrCascadeClassifier::fillPassedSamples( int first, int count, bool isPositive, int64& consumed )
{
    int getcount = 0;
    Mat img(cascadeParams.winSize, CV_8UC1);
    for( int i = first; i < first + count; i++ )
    {
        // Keep drawing candidates until one survives the current cascade.
        for( ; ; )
        {
            bool isGetImg = isPositive ? imgReader.getPos( img ) :
                                           imgReader.getNeg( img );
            if( !isGetImg )
                return getcount;
            consumed++;

            // Stage the candidate in slot i, then test it with the cascade
            // built so far; only survivors stay in the slot.
            featureEvaluator->setImage( img, isPositive ? 1 : 0, i );
            if( predict( i ) == 1.0F )
            {
                getcount++;
                printf("%s current samples: %d\r", isPositive ? "POS":"NEG", getcount);
                break;
            }
        }
    }
    return getcount;
}
| 509 | + | |
| 510 | +void BrCascadeClassifier::writeParams( FileStorage &fs ) const | |
| 511 | +{ | |
| 512 | + cascadeParams.write( fs ); | |
| 513 | + fs << CC_STAGE_PARAMS << "{"; stageParams->write( fs ); fs << "}"; | |
| 514 | + fs << CC_FEATURE_PARAMS << "{"; featureParams->write( fs ); fs << "}"; | |
| 515 | +} | |
| 516 | + | |
| 517 | +void BrCascadeClassifier::writeFeatures( FileStorage &fs, const Mat& featureMap ) const | |
| 518 | +{ | |
| 519 | + ((FeatureEvaluator*)((Ptr<FeatureEvaluator>)featureEvaluator))->writeFeatures( fs, featureMap ); | |
| 520 | +} | |
| 521 | + | |
| 522 | +void BrCascadeClassifier::writeStages( FileStorage &fs, const Mat& featureMap ) const | |
| 523 | +{ | |
| 524 | + char cmnt[30]; | |
| 525 | + int i = 0; | |
| 526 | + fs << CC_STAGES << "["; | |
| 527 | + for( vector< Ptr<CascadeBoost> >::const_iterator it = stageClassifiers.begin(); | |
| 528 | + it != stageClassifiers.end(); it++, i++ ) | |
| 529 | + { | |
| 530 | + sprintf( cmnt, "stage %d", i ); | |
| 531 | + cvWriteComment( fs.fs, cmnt, 0 ); | |
| 532 | + fs << "{"; | |
| 533 | + ((CascadeBoost*)((Ptr<CascadeBoost>)*it))->write( fs, featureMap ); | |
| 534 | + fs << "}"; | |
| 535 | + } | |
| 536 | + fs << "]"; | |
| 537 | +} | |
| 538 | + | |
| 539 | +bool BrCascadeClassifier::readParams( const FileNode &node ) | |
| 540 | +{ | |
| 541 | + if ( !node.isMap() || !cascadeParams.read( node ) ) | |
| 542 | + return false; | |
| 543 | + stageParams = new CascadeBoostParams; | |
| 544 | + FileNode rnode = node[CC_STAGE_PARAMS]; | |
| 545 | + if ( !stageParams->read( rnode ) ) | |
| 546 | + return false; | |
| 547 | + | |
| 548 | + featureParams = FeatureParams::create(cascadeParams.featureType); | |
| 549 | + rnode = node[CC_FEATURE_PARAMS]; | |
| 550 | + if ( !featureParams->read( rnode ) ) | |
| 551 | + return false; | |
| 552 | + return true; | |
| 553 | +} | |
| 554 | + | |
| 555 | +bool BrCascadeClassifier::readStages( const FileNode &node) | |
| 556 | +{ | |
| 557 | + FileNode rnode = node[CC_STAGES]; | |
| 558 | + if (!rnode.empty() || !rnode.isSeq()) | |
| 559 | + return false; | |
| 560 | + stageClassifiers.reserve(numStages); | |
| 561 | + FileNodeIterator it = rnode.begin(); | |
| 562 | + for( int i = 0; i < min( (int)rnode.size(), numStages ); i++, it++ ) | |
| 563 | + { | |
| 564 | + CascadeBoost* tempStage = new CascadeBoost; | |
| 565 | + if ( !tempStage->read( *it, (FeatureEvaluator *)featureEvaluator, *((CascadeBoostParams*)stageParams) ) ) | |
| 566 | + { | |
| 567 | + delete tempStage; | |
| 568 | + return false; | |
| 569 | + } | |
| 570 | + stageClassifiers.push_back(tempStage); | |
| 571 | + } | |
| 572 | + return true; | |
| 573 | +} | |
| 574 | + | |
| 575 | +void BrCascadeClassifier::save( const string filename, bool baseFormat ) | |
| 576 | +{ | |
| 577 | + FileStorage fs( filename, FileStorage::WRITE ); | |
| 578 | + | |
| 579 | + if ( !fs.isOpened() ) | |
| 580 | + return; | |
| 581 | + | |
| 582 | + fs << FileStorage::getDefaultObjectName(filename) << "{"; | |
| 583 | + if ( !baseFormat ) | |
| 584 | + { | |
| 585 | + Mat featureMap; | |
| 586 | + getUsedFeaturesIdxMap( featureMap ); | |
| 587 | + writeParams( fs ); | |
| 588 | + fs << CC_STAGE_NUM << (int)stageClassifiers.size(); | |
| 589 | + writeStages( fs, featureMap ); | |
| 590 | + writeFeatures( fs, featureMap ); | |
| 591 | + } | |
| 592 | + else | |
| 593 | + { | |
| 594 | + qFatal("Old style cascade. Not sure how you got here but it's not supported"); | |
| 595 | + } | |
| 596 | + fs << "}"; | |
| 597 | +} | |
| 598 | + | |
| 599 | +bool BrCascadeClassifier::load( const string cascadeDirName ) | |
| 600 | +{ | |
| 601 | + FileStorage fs( cascadeDirName + CC_PARAMS_FILENAME, FileStorage::READ ); | |
| 602 | + if ( !fs.isOpened() ) | |
| 603 | + return false; | |
| 604 | + FileNode node = fs.getFirstTopLevelNode(); | |
| 605 | + if ( !readParams( node ) ) | |
| 606 | + return false; | |
| 607 | + | |
| 608 | + featureEvaluator = FeatureEvaluator::create(cascadeParams.featureType); | |
| 609 | + featureEvaluator->init( ((FeatureParams*)featureParams), numPos + numNeg, cascadeParams.winSize ); | |
| 610 | + fs.release(); | |
| 611 | + | |
| 612 | + char buf[10]; | |
| 613 | + for ( int si = 0; si < numStages; si++ ) | |
| 614 | + { | |
| 615 | + sprintf( buf, "%s%d", "stage", si); | |
| 616 | + fs.open( cascadeDirName + buf + ".xml", FileStorage::READ ); | |
| 617 | + node = fs.getFirstTopLevelNode(); | |
| 618 | + if ( !fs.isOpened() ) | |
| 619 | + break; | |
| 620 | + CascadeBoost *tempStage = new CascadeBoost; | |
| 621 | + | |
| 622 | + if ( !tempStage->read( node, (FeatureEvaluator*)featureEvaluator, *((CascadeBoostParams*)stageParams )) ) | |
| 623 | + { | |
| 624 | + delete tempStage; | |
| 625 | + fs.release(); | |
| 626 | + break; | |
| 627 | + } | |
| 628 | + stageClassifiers.push_back(tempStage); | |
| 629 | + } | |
| 630 | + return true; | |
| 631 | +} | |
| 632 | + | |
| 633 | +void BrCascadeClassifier::getUsedFeaturesIdxMap( Mat& featureMap ) | |
| 634 | +{ | |
| 635 | + int varCount = featureEvaluator->getNumFeatures() * featureEvaluator->getFeatureSize(); | |
| 636 | + featureMap.create( 1, varCount, CV_32SC1 ); | |
| 637 | + featureMap.setTo(Scalar(-1)); | |
| 638 | + | |
| 639 | + for( vector< Ptr<CascadeBoost> >::const_iterator it = stageClassifiers.begin(); | |
| 640 | + it != stageClassifiers.end(); it++ ) | |
| 641 | + ((CascadeBoost*)((Ptr<CascadeBoost>)(*it)))->markUsedFeaturesInMap( featureMap ); | |
| 642 | + | |
| 643 | + for( int fi = 0, idx = 0; fi < varCount; fi++ ) | |
| 644 | + if ( featureMap.at<int>(0, fi) >= 0 ) | |
| 645 | + featureMap.ptr<int>(0)[fi] = idx++; | |
| 646 | +} | |
| 647 | + | ... | ... |
openbr/core/cascade.h
0 → 100644
| 1 | +#ifndef CASCADE_H | |
| 2 | +#define CASCADE_H | |
| 3 | + | |
| 4 | +#include <openbr/openbr_plugin.h> | |
| 5 | +#include <opencv2/highgui/highgui.hpp> | |
| 6 | +#include "features.h" | |
| 7 | +#include "boost.h" | |
| 8 | + | |
| 9 | +namespace br | |
| 10 | +{ | |
| 11 | + | |
// Supplies training windows to the cascade trainer: positive samples come
// from a precomputed .vec file, negatives are cut out of full background
// images at varying scales and offsets.
class CascadeImageReader
{
public:
    bool create( const std::string _posFilename, const std::string _negFilename, cv::Size _winSize );
    void restart() { posReader.restart(); }  // rewinds the positive (.vec) stream only
    bool getNeg(cv::Mat &_img) { return negReader.get( _img ); }
    bool getPos(cv::Mat &_img) { return posReader.get( _img ); }

private:
    // Sequential reader over a .vec file of fixed-size positive samples.
    class PosReader
    {
    public:
        PosReader();
        virtual ~PosReader();
        bool create( const std::string _filename );
        bool get( cv::Mat &_img );
        void restart();

        short* vec;     // decode buffer for one sample
        FILE* file;     // open .vec file handle
        int count;      // total number of samples in the file
        int vecSize;    // elements per sample
        int last;       // index of the last sample handed out
        int base;       // presumably the byte offset of sample data -- TODO confirm in the .cpp
    } posReader;

    // Produces negative windows of winSize by scanning the listed
    // background images across positions and scales.
    class NegReader
    {
    public:
        NegReader();
        bool create( const std::string _filename, cv::Size _winSize );
        bool get( cv::Mat& _img );
        bool nextImg();

        cv::Mat src, img;                       // current source image / its working copy
        std::vector<std::string> imgFilenames;  // background image list
        cv::Point offset, point;                // current sampling position state
        float scale;
        float scaleFactor;
        float stepFactor;
        size_t last, round;                     // position in imgFilenames / pass counter
        cv::Size winSize;
    } negReader;
};
| 56 | + | |
// Top-level cascade configuration: stage type (boosting is the only
// option), feature type, and the detection window size. Serializable via
// the Params read/write interface.
class CascadeParams : public Params
{
public:
    enum { BOOST = 0 };
    static const int defaultStageType = BOOST;
    static const int defaultFeatureType = FeatureParams::LBP;

    CascadeParams();
    CascadeParams( int _stageType, int _featureType );
    void write( cv::FileStorage &fs ) const;
    bool read( const cv::FileNode &node );

    void printDefaults() const;
    void printAttrs() const;
    bool scanAttr( const std::string prmName, const std::string val );

    int stageType;      // BOOST (only supported value)
    int featureType;    // one of FeatureParams' feature-type enumerators
    cv::Size winSize;   // detection window size in pixels
};
| 77 | + | |
// OpenBR port of OpenCV's cascade trainer: drives stage-by-stage boosted
// training, checkpoints each stage to disk, and can resume from a
// previously saved training directory.
class BrCascadeClassifier
{
public:
    // Runs the full training loop. Stage files and the final cascade are
    // written under _cascadeDirName; positives are read from the .vec file
    // _posFilename, negatives from the background list _negFilename.
    bool train( const std::string _cascadeDirName,
                const std::string _posFilename,
                const std::string _negFilename,
                int _numPos, int _numNeg,
                int _precalcValBufSize, int _precalcIdxBufSize,
                int _numStages,
                const CascadeParams& _cascadeParams,
                const FeatureParams& _featureParams,
                const CascadeBoostParams& _stageParams,
                bool baseFormatSave = false );
private:
    // Runs sample sampleIdx through the already-trained stages.
    int predict( int sampleIdx );
    // Serializes the whole cascade; baseFormat (legacy) is unsupported.
    void save( const std::string cascadeDirName, bool baseFormat = false );
    // Resumes state from a training directory (params.xml + stageN.xml).
    bool load( const std::string cascadeDirName );
    bool updateTrainingSet( double& acceptanceRatio );
    // Fills [first, first+count) with samples that pass current stages;
    // consumed counts how many candidates were examined.
    int fillPassedSamples( int first, int count, bool isPositive, int64& consumed );

    void writeParams( cv::FileStorage &fs ) const;
    void writeStages( cv::FileStorage &fs, const cv::Mat& featureMap ) const;
    void writeFeatures( cv::FileStorage &fs, const cv::Mat& featureMap ) const;
    bool readParams( const cv::FileNode &node );
    bool readStages( const cv::FileNode &node );

    // Maps each feature index to a compact index (or -1 when unused).
    void getUsedFeaturesIdxMap( cv::Mat& featureMap );

    CascadeParams cascadeParams;
    cv::Ptr<FeatureParams> featureParams;
    cv::Ptr<CascadeBoostParams> stageParams;

    cv::Ptr<FeatureEvaluator> featureEvaluator;
    std::vector< cv::Ptr<CascadeBoost> > stageClassifiers;
    CascadeImageReader imgReader;
    int numStages, curNumSamples;
    int numPos, numNeg;
};
| 116 | + | |
| 117 | +} // namespace br | |
| 118 | + | |
| 119 | +#endif // CASCADE_H | ... | ... |
openbr/core/features.cpp
0 → 100644
| 1 | +#include "features.h" | |
| 2 | + | |
| 3 | +using namespace cv; | |
| 4 | +using namespace br; | |
| 5 | + | |
| 6 | +//------------------------- Params ----------------------------------------------- | |
| 7 | + | |
| 8 | +float calcNormFactor( const Mat& sum, const Mat& sqSum ) | |
| 9 | +{ | |
| 10 | + CV_DbgAssert( sum.cols > 3 && sqSum.rows > 3 ); | |
| 11 | + Rect normrect( 1, 1, sum.cols - 3, sum.rows - 3 ); | |
| 12 | + size_t p0, p1, p2, p3; | |
| 13 | + CV_SUM_OFFSETS( p0, p1, p2, p3, normrect, sum.step1() ) | |
| 14 | + double area = normrect.width * normrect.height; | |
| 15 | + const int *sp = (const int*)sum.data; | |
| 16 | + int valSum = sp[p0] - sp[p1] - sp[p2] + sp[p3]; | |
| 17 | + const double *sqp = (const double *)sqSum.data; | |
| 18 | + double valSqSum = sqp[p0] - sqp[p1] - sqp[p2] + sqp[p3]; | |
| 19 | + return (float) sqrt( (double) (area * valSqSum - (double)valSum * valSum) ); | |
| 20 | +} | |
| 21 | + | |
| 22 | +Params::Params() : name( "params" ) {} | |
| 23 | +void Params::printDefaults() const { std::cout << "--" << name << "--" << endl; } | |
| 24 | +void Params::printAttrs() const {} | |
| 25 | +bool Params::scanAttr( const string, const string ) { return false; } | |
| 26 | + | |
| 27 | + | |
| 28 | +//---------------------------- FeatureParams -------------------------------------- | |
| 29 | + | |
| 30 | +FeatureParams::FeatureParams() : maxCatCount( 0 ), featSize( 1 ) | |
| 31 | +{ | |
| 32 | + name = CC_FEATURE_PARAMS; | |
| 33 | +} | |
| 34 | + | |
| 35 | +void FeatureParams::init( const FeatureParams& fp ) | |
| 36 | +{ | |
| 37 | + maxCatCount = fp.maxCatCount; | |
| 38 | + featSize = fp.featSize; | |
| 39 | +} | |
| 40 | + | |
| 41 | +void FeatureParams::write( FileStorage &fs ) const | |
| 42 | +{ | |
| 43 | + fs << CC_MAX_CAT_COUNT << maxCatCount; | |
| 44 | + fs << CC_FEATURE_SIZE << featSize; | |
| 45 | +} | |
| 46 | + | |
| 47 | +bool FeatureParams::read( const FileNode &node ) | |
| 48 | +{ | |
| 49 | + if ( node.empty() ) | |
| 50 | + return false; | |
| 51 | + maxCatCount = node[CC_MAX_CAT_COUNT]; | |
| 52 | + featSize = node[CC_FEATURE_SIZE]; | |
| 53 | + return ( maxCatCount >= 0 && featSize >= 1 ); | |
| 54 | +} | |
| 55 | + | |
| 56 | +Ptr<FeatureParams> FeatureParams::create( int featureType ) | |
| 57 | +{ | |
| 58 | + return featureType == LBP ? Ptr<FeatureParams>(new LBPFeatureParams) : | |
| 59 | + Ptr<FeatureParams>(); | |
| 60 | +} | |
| 61 | + | |
| 62 | +//------------------------------------- FeatureEvaluator --------------------------------------- | |
| 63 | + | |
| 64 | +void FeatureEvaluator::init(const FeatureParams *_featureParams, | |
| 65 | + int _maxSampleCount, Size _winSize ) | |
| 66 | +{ | |
| 67 | + CV_Assert(_maxSampleCount > 0); | |
| 68 | + featureParams = (FeatureParams *)_featureParams; | |
| 69 | + winSize = _winSize; | |
| 70 | + numFeatures = 0; | |
| 71 | + cls.create( (int)_maxSampleCount, 1, CV_32FC1 ); | |
| 72 | + generateFeatures(); | |
| 73 | +} | |
| 74 | + | |
| 75 | +void FeatureEvaluator::setImage(const Mat &img, uchar clsLabel, int idx) | |
| 76 | +{ | |
| 77 | + CV_Assert(img.cols == winSize.width); | |
| 78 | + CV_Assert(img.rows == winSize.height); | |
| 79 | + CV_Assert(idx < cls.rows); | |
| 80 | + cls.ptr<float>(idx)[0] = clsLabel; | |
| 81 | +} | |
| 82 | + | |
| 83 | +Ptr<FeatureEvaluator> FeatureEvaluator::create(int type) | |
| 84 | +{ | |
| 85 | + return type == FeatureParams::LBP ? Ptr<FeatureEvaluator>(new LBPEvaluator) : | |
| 86 | + Ptr<FeatureEvaluator>(); | |
| 87 | +} | |
| 88 | + | |
| 89 | +// ------------------------------------ LBP ----------------------------------------------- | |
| 90 | + | |
| 91 | +LBPFeatureParams::LBPFeatureParams() | |
| 92 | +{ | |
| 93 | + maxCatCount = 256; | |
| 94 | + name = LBPF_NAME; | |
| 95 | +} | |
| 96 | + | |
| 97 | +void LBPEvaluator::init(const FeatureParams *_featureParams, int _maxSampleCount, Size _winSize) | |
| 98 | +{ | |
| 99 | + CV_Assert( _maxSampleCount > 0); | |
| 100 | + sum.create((int)_maxSampleCount, (_winSize.width + 1) * (_winSize.height + 1), CV_32SC1); | |
| 101 | + FeatureEvaluator::init( _featureParams, _maxSampleCount, _winSize ); | |
| 102 | +} | |
| 103 | + | |
| 104 | +void LBPEvaluator::setImage(const Mat &img, uchar clsLabel, int idx) | |
| 105 | +{ | |
| 106 | + CV_DbgAssert( !sum.empty() ); | |
| 107 | + FeatureEvaluator::setImage( img, clsLabel, idx ); | |
| 108 | + Mat innSum(winSize.height + 1, winSize.width + 1, sum.type(), sum.ptr<int>((int)idx)); | |
| 109 | + integral( img, innSum ); | |
| 110 | +} | |
| 111 | + | |
| 112 | +void LBPEvaluator::writeFeatures( FileStorage &fs, const Mat& featureMap ) const | |
| 113 | +{ | |
| 114 | + _writeFeatures( features, fs, featureMap ); | |
| 115 | +} | |
| 116 | + | |
// Enumerates every 3x3-block LBP feature whose 3w x 3h footprint fits in
// the detection window, precomputing its integral-image offsets.
// NOTE: the iteration order here defines the feature indices used in
// serialized cascades -- do not reorder the loops.
void LBPEvaluator::generateFeatures()
{
    int offset = winSize.width + 1;  // row stride of the integral image
    for( int x = 0; x < winSize.width; x++ )
        for( int y = 0; y < winSize.height; y++ )
            for( int w = 1; w <= winSize.width / 3; w++ )
                for( int h = 1; h <= winSize.height / 3; h++ )
                    // keep only features fully inside the window
                    if ( (x+3*w <= winSize.width) && (y+3*h <= winSize.height) )
                        features.push_back( Feature(offset, x, y, w, h ) );
    numFeatures = (int)features.size();
}
| 128 | + | |
| 129 | +LBPEvaluator::Feature::Feature() | |
| 130 | +{ | |
| 131 | + rect = cvRect(0, 0, 0, 0); | |
| 132 | +} | |
| 133 | + | |
// Precomputes the 16 integral-image offsets (a 4x4 grid of corner points)
// covering the 3x3 neighborhood of w x h blocks anchored at (x, y).
// `offset` is the row stride of the integral image. The walk order below
// (top-left -> top-right -> bottom-right -> bottom-left) fixes which p[]
// slots hold which corners; calc() depends on this exact layout.
LBPEvaluator::Feature::Feature( int offset, int x, int y, int _blockWidth, int _blockHeight )
{
    Rect tr = rect = cvRect(x, y, _blockWidth, _blockHeight);
    // top-left block
    CV_SUM_OFFSETS( p[0], p[1], p[4], p[5], tr, offset )
    tr.x += 2*rect.width;
    // top-right block
    CV_SUM_OFFSETS( p[2], p[3], p[6], p[7], tr, offset )
    tr.y +=2*rect.height;
    // bottom-right block
    CV_SUM_OFFSETS( p[10], p[11], p[14], p[15], tr, offset )
    tr.x -= 2*rect.width;
    // bottom-left block
    CV_SUM_OFFSETS( p[8], p[9], p[12], p[13], tr, offset )
}
| 145 | + | |
| 146 | +void LBPEvaluator::Feature::write(FileStorage &fs) const | |
| 147 | +{ | |
| 148 | + fs << CC_RECT << "[:" << rect.x << rect.y << rect.width << rect.height << "]"; | |
| 149 | +} | ... | ... |
openbr/core/features.h
0 → 100644
| 1 | +#ifndef FEATURE_H | |
| 2 | +#define FEATURE_H | |
| 3 | + | |
| 4 | +#include <openbr/openbr_plugin.h> | |
| 5 | +#include "opencv2/imgproc/imgproc.hpp" | |
| 6 | +#include <iostream> | |
| 7 | + | |
| 8 | +#define CC_CASCADE_FILENAME "cascade.xml" | |
| 9 | +#define CC_PARAMS_FILENAME "params.xml" | |
| 10 | + | |
| 11 | +#define CC_CASCADE_PARAMS "cascadeParams" | |
| 12 | +#define CC_STAGE_TYPE "stageType" | |
| 13 | +#define CC_FEATURE_TYPE "featureType" | |
| 14 | +#define CC_HEIGHT "height" | |
| 15 | +#define CC_WIDTH "width" | |
| 16 | + | |
| 17 | +#define CC_STAGE_NUM "stageNum" | |
| 18 | +#define CC_STAGES "stages" | |
| 19 | +#define CC_STAGE_PARAMS "stageParams" | |
| 20 | + | |
| 21 | +#define CC_BOOST "BOOST" | |
| 22 | +#define CC_BOOST_TYPE "boostType" | |
| 23 | +#define CC_DISCRETE_BOOST "DAB" | |
| 24 | +#define CC_REAL_BOOST "RAB" | |
| 25 | +#define CC_LOGIT_BOOST "LB" | |
| 26 | +#define CC_GENTLE_BOOST "GAB" | |
| 27 | +#define CC_MINHITRATE "minHitRate" | |
| 28 | +#define CC_MAXFALSEALARM "maxFalseAlarm" | |
| 29 | +#define CC_TRIM_RATE "weightTrimRate" | |
| 30 | +#define CC_MAX_DEPTH "maxDepth" | |
| 31 | +#define CC_WEAK_COUNT "maxWeakCount" | |
| 32 | +#define CC_STAGE_THRESHOLD "stageThreshold" | |
| 33 | +#define CC_WEAK_CLASSIFIERS "weakClassifiers" | |
| 34 | +#define CC_INTERNAL_NODES "internalNodes" | |
| 35 | +#define CC_LEAF_VALUES "leafValues" | |
| 36 | + | |
| 37 | +#define CC_FEATURES "features" | |
| 38 | +#define CC_FEATURE_PARAMS "featureParams" | |
| 39 | +#define CC_MAX_CAT_COUNT "maxCatCount" | |
| 40 | +#define CC_FEATURE_SIZE "featSize" | |
| 41 | + | |
| 42 | +#define CC_HAAR "HAAR" | |
| 43 | +#define CC_MODE "mode" | |
| 44 | +#define CC_MODE_BASIC "BASIC" | |
| 45 | +#define CC_MODE_CORE "CORE" | |
| 46 | +#define CC_MODE_ALL "ALL" | |
| 47 | +#define CC_RECTS "rects" | |
| 48 | +#define CC_TILTED "tilted" | |
| 49 | + | |
| 50 | +#define CC_LBP "LBP" | |
| 51 | +#define CC_RECT "rect" | |
| 52 | + | |
| 53 | +#define CC_HOG "HOG" | |
| 54 | +#define CC_HOGMULTI "HOGMulti" | |
| 55 | + | |
| 56 | +#define CC_NPD "NPD" | |
| 57 | +#define CC_POINTS "points" | |
| 58 | +#define CC_POINT "point" | |
| 59 | + | |
| 60 | +#ifdef _WIN32 | |
| 61 | +#define TIME( arg ) (((double) clock()) / CLOCKS_PER_SEC) | |
| 62 | +#else | |
| 63 | +#define TIME( arg ) (time( arg )) | |
| 64 | +#endif | |
| 65 | + | |
/* Computes the four integral-image offsets of a rectangle's corners.
   `step` is the integral image's row stride in elements. */
#define CV_SUM_OFFSETS( p0, p1, p2, p3, rect, step )                      \
    /* (x, y) */                                                          \
    (p0) = (rect).x + (step) * (rect).y;                                  \
    /* (x + w, y) */                                                      \
    (p1) = (rect).x + (rect).width + (step) * (rect).y;                   \
    /* (x, y + h) -- the original comment here wrongly said (x + w, y) */ \
    (p2) = (rect).x + (step) * ((rect).y + (rect).height);                \
    /* (x + w, y + h) */                                                  \
    (p3) = (rect).x + (rect).width + (step) * ((rect).y + (rect).height);
| 75 | + | |
| 76 | +#define CV_TILTED_OFFSETS( p0, p1, p2, p3, rect, step ) \ | |
| 77 | + /* (x, y) */ \ | |
| 78 | + (p0) = (rect).x + (step) * (rect).y; \ | |
| 79 | + /* (x - h, y + h) */ \ | |
| 80 | + (p1) = (rect).x - (rect).height + (step) * ((rect).y + (rect).height);\ | |
| 81 | + /* (x + w, y + w) */ \ | |
| 82 | + (p2) = (rect).x + (rect).width + (step) * ((rect).y + (rect).width); \ | |
| 83 | + /* (x + w - h, y + w + h) */ \ | |
| 84 | + (p3) = (rect).x + (rect).width - (rect).height \ | |
| 85 | + + (step) * ((rect).y + (rect).width + (rect).height); | |
| 86 | + | |
| 87 | +namespace br | |
| 88 | +{ | |
| 89 | + | |
| 90 | +float calcNormFactor( const cv::Mat& sum, const cv::Mat& sqSum ); | |
| 91 | + | |
| 92 | +template<class Feature> | |
| 93 | +void _writeFeatures( const std::vector<Feature> features, cv::FileStorage &fs, const cv::Mat& featureMap ) | |
| 94 | +{ | |
| 95 | + fs << CC_FEATURES << "["; | |
| 96 | + const cv::Mat_<int>& featureMap_ = (const cv::Mat_<int>&)featureMap; | |
| 97 | + for ( int fi = 0; fi < featureMap.cols; fi++ ) | |
| 98 | + if ( featureMap_(0, fi) >= 0 ) | |
| 99 | + { | |
| 100 | + fs << "{"; | |
| 101 | + features[fi].write( fs ); | |
| 102 | + fs << "}"; | |
| 103 | + } | |
| 104 | + fs << "]"; | |
| 105 | +} | |
| 106 | + | |
// Abstract base for serializable parameter bundles: subclasses persist
// themselves to/from cv::FileStorage and can print or parse command-line
// style attributes.
class Params
{
public:
    Params();
    virtual ~Params() {}
    // from|to file
    virtual void write( cv::FileStorage &fs ) const = 0;
    virtual bool read( const cv::FileNode &node ) = 0;
    // from|to screen
    virtual void printDefaults() const;
    virtual void printAttrs() const;
    // Returns false when prmName is not recognized by this parameter set.
    virtual bool scanAttr( const std::string prmName, const std::string val );
    std::string name;  // section name used when printing
};
| 121 | + | |
// Parameters common to all feature generators. Only LBP is supported by
// this port (see create() in features.cpp).
class FeatureParams : public Params
{
public:
    enum { LBP = 0 };
    FeatureParams();
    virtual void init( const FeatureParams& fp );
    virtual void write( cv::FileStorage &fs ) const;
    virtual bool read( const cv::FileNode &node );
    // Factory: returns a null Ptr for unrecognized feature types.
    static cv::Ptr<FeatureParams> create( int featureType );
    int maxCatCount; // 0 in case of numerical features
    int featSize; // 1 in case of simple features (HAAR, LBP) and N_BINS(9)*N_CELLS(4) in case of Dalal's HOG features
};
| 134 | + | |
// Computes feature responses over a set of training samples. A concrete
// evaluator stores each sample's image representation (set via setImage)
// and answers operator()(featureIdx, sampleIdx) queries for the booster.
class FeatureEvaluator
{
public:
    virtual ~FeatureEvaluator() {}
    // Allocates per-sample storage and enumerates all features for winSize.
    virtual void init(const FeatureParams *_featureParams,
                      int _maxSampleCount, cv::Size _winSize );
    // Registers sample `idx` with its class label (window-sized image).
    virtual void setImage(const cv::Mat& img, uchar clsLabel, int idx);
    virtual void writeFeatures( cv::FileStorage &fs, const cv::Mat& featureMap ) const = 0;
    // Response of feature featureIdx on sample sampleIdx.
    virtual float operator()(int featureIdx, int sampleIdx) const = 0;
    static cv::Ptr<FeatureEvaluator> create(int type);

    int getNumFeatures() const { return numFeatures; }
    int getMaxCatCount() const { return featureParams->maxCatCount; }
    int getFeatureSize() const { return featureParams->featSize; }
    const cv::Mat& getCls() const { return cls; }
    float getCls(int si) const { return cls.at<float>(si, 0); }
protected:
    // Populates the feature list; called once from init().
    virtual void generateFeatures() = 0;

    int npos, nneg;            // NOTE(review): not set anywhere visible here -- confirm use
    int numFeatures;
    cv::Size winSize;
    FeatureParams *featureParams;  // non-owning; supplied to init()
    cv::Mat cls;               // one float class label per sample
};
| 160 | + | |
| 161 | + | |
| 162 | +//------------------------- LBP Feature --------------------------------- | |
| 163 | + | |
| 164 | +#define LBPF_NAME "lbpFeatureParams" | |
| 165 | + | |
// LBP-specific parameters: only overrides the defaults in its constructor
// (256 categories per feature).
struct LBPFeatureParams : FeatureParams
{
    LBPFeatureParams();

};
| 171 | + | |
// Evaluates LBP features: each sample is stored as one row of integral-
// image data in `sum`, and each feature reads 16 precomputed offsets from
// that row to produce an 8-bit LBP code.
class LBPEvaluator : public FeatureEvaluator
{
public:
    virtual ~LBPEvaluator() {}
    virtual void init(const FeatureParams *_featureParams,
                      int _maxSampleCount, cv::Size _winSize );
    virtual void setImage(const cv::Mat& img, uchar clsLabel, int idx);
    virtual float operator()(int featureIdx, int sampleIdx) const
    { return (float)features[featureIdx].calc( sum, sampleIdx); }
    virtual void writeFeatures( cv::FileStorage &fs, const cv::Mat& featureMap ) const;
protected:
    virtual void generateFeatures();

    // One LBP feature: a w x h block rectangle plus the 16 integral-image
    // offsets of its 3x3-block neighborhood corners.
    class Feature
    {
    public:
        Feature();
        Feature( int offset, int x, int y, int _block_w, int _block_h );
        // LBP code of this feature for the sample stored at row y of _sum.
        uchar calc( const cv::Mat& _sum, size_t y ) const;
        void write( cv::FileStorage &fs ) const;

        cv::Rect rect;  // top-left block of the 3x3 grid
        int p[16];      // precomputed corner offsets into a sum row
    };
    std::vector<Feature> features;

    cv::Mat sum;  // one integral image per sample, flattened per row
};
| 200 | + | |
// Computes the 8-bit LBP code for the sample stored at integral-image row
// `y`: the sum of each of the 8 surrounding blocks is compared against the
// center block sum (cval); each comparison contributes one bit. The
// trailing numbers name the neighbor position (row-major 0..8, center
// excluded) that each bit encodes.
inline uchar LBPEvaluator::Feature::calc(const cv::Mat &_sum, size_t y) const
{
    const int* psum = _sum.ptr<int>((int)y);
    // center block sum via the standard 4-corner integral-image identity
    int cval = psum[p[5]] - psum[p[6]] - psum[p[9]] + psum[p[10]];

    return (uchar)((psum[p[0]] - psum[p[1]] - psum[p[4]] + psum[p[5]] >= cval ? 128 : 0) |   // 0
        (psum[p[1]] - psum[p[2]] - psum[p[5]] + psum[p[6]] >= cval ? 64 : 0) |    // 1
        (psum[p[2]] - psum[p[3]] - psum[p[6]] + psum[p[7]] >= cval ? 32 : 0) |    // 2
        (psum[p[6]] - psum[p[7]] - psum[p[10]] + psum[p[11]] >= cval ? 16 : 0) |  // 5
        (psum[p[10]] - psum[p[11]] - psum[p[14]] + psum[p[15]] >= cval ? 8 : 0) | // 8
        (psum[p[9]] - psum[p[10]] - psum[p[13]] + psum[p[14]] >= cval ? 4 : 0) |  // 7
        (psum[p[8]] - psum[p[9]] - psum[p[12]] + psum[p[13]] >= cval ? 2 : 0) |   // 6
        (psum[p[4]] - psum[p[5]] - psum[p[8]] + psum[p[9]] >= cval ? 1 : 0));     // 3
}
| 215 | + | |
| 216 | +} // namespace br | |
| 217 | + | |
| 218 | +#endif // FEATURE_H | ... | ... |
openbr/plugins/metadata/cascade.cpp
| ... | ... | @@ -21,139 +21,9 @@ |
| 21 | 21 | #include <openbr/core/opencvutils.h> |
| 22 | 22 | #include <openbr/core/resource.h> |
| 23 | 23 | #include <openbr/core/qtutils.h> |
| 24 | +#include <openbr/core/cascade.h> | |
| 24 | 25 | |
| 25 | 26 | using namespace cv; |
| 26 | - | |
| 27 | -struct TrainParams | |
| 28 | -{ | |
| 29 | - QString data; // REQUIRED: Filepath to store trained classifier | |
| 30 | - QString vec; // REQUIRED: Filepath to store vector of positive samples, default "vector" | |
| 31 | - QString img; // Filepath to source object image. Either this or info is REQUIRED | |
| 32 | - QString info; // Description file of source images. Either this or img is REQUIRED | |
| 33 | - QString bg; // REQUIRED: Filepath to background list file | |
| 34 | - int num; // Number of samples to generate | |
| 35 | - int bgcolor; // Background color supplied image (via img) | |
| 36 | - int bgthresh; // Threshold to determine bgcolor match | |
| 37 | - bool inv; // Invert colors | |
| 38 | - bool randinv; // Randomly invert colors | |
| 39 | - int maxidev; // Max intensity deviation of foreground pixels | |
| 40 | - double maxxangle; // Maximum rotation angle (X) | |
| 41 | - double maxyangle; // Maximum rotation angle (Y) | |
| 42 | - double maxzangle; // Maximum rotation angle (Z) | |
| 43 | - bool show; // Show generated samples | |
| 44 | - int w; // REQUIRED: Sample width | |
| 45 | - int h; // REQUIRED: Sample height | |
| 46 | - int numPos; // Number of positive samples | |
| 47 | - int numNeg; // Number of negative samples | |
| 48 | - int numStages; // Number of stages | |
| 49 | - int precalcValBufSize; // Precalculated val buffer size in Mb | |
| 50 | - int precalcIdxBufSize; // Precalculated index buffer size in Mb | |
| 51 | - bool baseFormatSave; // Save in old format | |
| 52 | - QString stageType; // Stage type (BOOST) | |
| 53 | - QString featureType; // Feature type (HAAR, LBP) | |
| 54 | - QString bt; // Boosted classifier type (DAB, RAB, LB, GAB) | |
| 55 | - double minHitRate; // Minimal hit rate per stage | |
| 56 | - double maxFalseAlarmRate; // Max false alarm rate per stage | |
| 57 | - double weightTrimRate; // Weight for trimming | |
| 58 | - int maxDepth; // Max weak tree depth | |
| 59 | - int maxWeakCount; // Max weak tree count per stage | |
| 60 | - QString mode; // Haar feature mode (BASIC, CORE, ALL) | |
| 61 | - | |
| 62 | - TrainParams() | |
| 63 | - { | |
| 64 | - num = -1; | |
| 65 | - maxidev = -1; | |
| 66 | - maxxangle = -1; | |
| 67 | - maxyangle = -1; | |
| 68 | - maxzangle = -1; | |
| 69 | - w = -1; | |
| 70 | - h = -1; | |
| 71 | - numPos = -1; | |
| 72 | - numNeg = -1; | |
| 73 | - numStages = -1; | |
| 74 | - precalcValBufSize = -1; | |
| 75 | - precalcIdxBufSize = -1; | |
| 76 | - minHitRate = -1; | |
| 77 | - maxFalseAlarmRate = -1; | |
| 78 | - weightTrimRate = -1; | |
| 79 | - maxDepth = -1; | |
| 80 | - maxWeakCount = -1; | |
| 81 | - inv = false; | |
| 82 | - randinv = false; | |
| 83 | - show = false; | |
| 84 | - baseFormatSave = false; | |
| 85 | - vec = "vector.vec"; | |
| 86 | - bgcolor = -1; | |
| 87 | - bgthresh = -1; | |
| 88 | - } | |
| 89 | -}; | |
| 90 | - | |
| 91 | -static QStringList buildTrainingArgs(const TrainParams ¶ms) | |
| 92 | -{ | |
| 93 | - QStringList args; | |
| 94 | - if (params.data != "") args << "-data" << params.data; | |
| 95 | - else qFatal("Must specify storage location for cascade"); | |
| 96 | - if (params.vec != "") args << "-vec" << params.vec; | |
| 97 | - else qFatal("Must specify location of positive vector"); | |
| 98 | - if (params.bg != "") args << "-bg" << params.bg; | |
| 99 | - else qFatal("Must specify negative images"); | |
| 100 | - if (params.numPos >= 0) args << "-numPos" << QString::number(params.numPos); | |
| 101 | - if (params.numNeg >= 0) args << "-numNeg" << QString::number(params.numNeg); | |
| 102 | - if (params.numStages >= 0) args << "-numStages" << QString::number(params.numStages); | |
| 103 | - if (params.precalcValBufSize >= 0) args << "-precalcValBufSize" << QString::number(params.precalcValBufSize); | |
| 104 | - if (params.precalcIdxBufSize >= 0) args << "-precalcIdxBufSize" << QString::number(params.precalcIdxBufSize); | |
| 105 | - if (params.baseFormatSave) args << "-baseFormatSave"; | |
| 106 | - if (params.stageType != "") args << "-stageType" << params.stageType; | |
| 107 | - if (params.featureType != "") args << "-featureType" << params.featureType; | |
| 108 | - if (params.w >= 0) args << "-w" << QString::number(params.w); | |
| 109 | - else qFatal("Must specify width"); | |
| 110 | - if (params.h >= 0) args << "-h" << QString::number(params.h); | |
| 111 | - else qFatal("Must specify height"); | |
| 112 | - if (params.bt != "") args << "-bt" << params.bt; | |
| 113 | - if (params.minHitRate >= 0) args << "-minHitRate" << QString::number(params.minHitRate); | |
| 114 | - if (params.maxFalseAlarmRate >= 0) args << "-maxFalseAlarmRate" << QString::number(params.maxFalseAlarmRate); | |
| 115 | - if (params.weightTrimRate >= 0) args << "-weightTrimRate" << QString::number(params.weightTrimRate); | |
| 116 | - if (params.maxDepth >= 0) args << "-maxDepth" << QString::number(params.maxDepth); | |
| 117 | - if (params.maxWeakCount >= 0) args << "-maxWeakCount" << QString::number(params.maxWeakCount); | |
| 118 | - if (params.mode != "") args << "-mode" << params.mode; | |
| 119 | - return args; | |
| 120 | -} | |
| 121 | - | |
| 122 | -static QStringList buildSampleArgs(const TrainParams ¶ms) | |
| 123 | -{ | |
| 124 | - QStringList args; | |
| 125 | - if (params.vec != "") args << "-vec" << params.vec; | |
| 126 | - else qFatal("Must specify location of positive vector"); | |
| 127 | - if (params.img != "") args << "-img" << params.img; | |
| 128 | - else if (params.info != "") args << "-info" << params.info; | |
| 129 | - else qFatal("Must specify positive images"); | |
| 130 | - if (params.bg != "") args << "-bg" << params.bg; | |
| 131 | - if (params.num > 0) args << "-num" << QString::number(params.num); | |
| 132 | - if (params.bgcolor >=0 ) args << "-bgcolor" << QString::number(params.bgcolor); | |
| 133 | - if (params.bgthresh >= 0) args << "-bgthresh" << QString::number(params.bgthresh); | |
| 134 | - if (params.maxidev >= 0) args << "-maxidev" << QString::number(params.maxidev); | |
| 135 | - if (params.maxxangle >= 0) args << "-maxxangle" << QString::number(params.maxxangle); | |
| 136 | - if (params.maxyangle >= 0) args << "-maxyangle" << QString::number(params.maxyangle); | |
| 137 | - if (params.maxzangle >= 0) args << "-maxzangle" << QString::number(params.maxzangle); | |
| 138 | - if (params.w >= 0) args << "-w" << QString::number(params.w); | |
| 139 | - if (params.h >= 0) args << "-h" << QString::number(params.h); | |
| 140 | - if (params.show) args << "-show"; | |
| 141 | - if (params.inv) args << "-inv"; | |
| 142 | - if (params.randinv) args << "-randinv"; | |
| 143 | - return args; | |
| 144 | -} | |
| 145 | - | |
| 146 | -static void genSamples(const TrainParams ¶ms) | |
| 147 | -{ | |
| 148 | - const QStringList cmdArgs = buildSampleArgs(params); | |
| 149 | - QProcess::execute("opencv_createsamples",cmdArgs); | |
| 150 | -} | |
| 151 | - | |
| 152 | -static void trainCascade(const TrainParams ¶ms) | |
| 153 | -{ | |
| 154 | - const QStringList cmdArgs = buildTrainingArgs(params); | |
| 155 | - QProcess::execute("opencv_traincascade", cmdArgs); | |
| 156 | -} | |
| 157 | 27 | |
| 158 | 28 | namespace br |
| 159 | 29 | { |
| ... | ... | @@ -202,50 +72,21 @@ class CascadeTransform : public MetaTransform |
| 202 | 72 | Q_PROPERTY(int minNeighbors READ get_minNeighbors WRITE set_minNeighbors RESET reset_minNeighbors STORED false) |
| 203 | 73 | Q_PROPERTY(bool ROCMode READ get_ROCMode WRITE set_ROCMode RESET reset_ROCMode STORED false) |
| 204 | 74 | |
| 205 | - // Training parameters | |
| 206 | - Q_PROPERTY(int numStages READ get_numStages WRITE set_numStages RESET reset_numStages STORED false) | |
| 207 | - Q_PROPERTY(int w READ get_w WRITE set_w RESET reset_w STORED false) | |
| 208 | - Q_PROPERTY(int h READ get_h WRITE set_h RESET reset_h STORED false) | |
| 209 | - Q_PROPERTY(int numPos READ get_numPos WRITE set_numPos RESET reset_numPos STORED false) | |
| 210 | - Q_PROPERTY(int numNeg READ get_numNeg WRITE set_numNeg RESET reset_numNeg STORED false) | |
| 211 | - Q_PROPERTY(int precalcValBufSize READ get_precalcValBufSize WRITE set_precalcValBufSize RESET reset_precalcValBufSize STORED false) | |
| 212 | - Q_PROPERTY(int precalcIdxBufSize READ get_precalcIdxBufSize WRITE set_precalcIdxBufSize RESET reset_precalcIdxBufSize STORED false) | |
| 213 | - Q_PROPERTY(double minHitRate READ get_minHitRate WRITE set_minHitRate RESET reset_minHitRate STORED false) | |
| 214 | - Q_PROPERTY(double maxFalseAlarmRate READ get_maxFalseAlarmRate WRITE set_maxFalseAlarmRate RESET reset_maxFalseAlarmRate STORED false) | |
| 215 | - Q_PROPERTY(double weightTrimRate READ get_weightTrimRate WRITE set_weightTrimRate RESET reset_weightTrimRate STORED false) | |
| 216 | - Q_PROPERTY(int maxDepth READ get_maxDepth WRITE set_maxDepth RESET reset_maxDepth STORED false) | |
| 217 | - Q_PROPERTY(int maxWeakCount READ get_maxWeakCount WRITE set_maxWeakCount RESET reset_maxWeakCount STORED false) | |
| 218 | - Q_PROPERTY(QString stageType READ get_stageType WRITE set_stageType RESET reset_stageType STORED false) | |
| 219 | - Q_PROPERTY(QString featureType READ get_featureType WRITE set_featureType RESET reset_featureType STORED false) | |
| 220 | - Q_PROPERTY(QString bt READ get_bt WRITE set_bt RESET reset_bt STORED false) | |
| 221 | - Q_PROPERTY(QString mode READ get_mode WRITE set_mode RESET reset_mode STORED false) | |
| 222 | - Q_PROPERTY(bool show READ get_show WRITE set_show RESET reset_show STORED false) | |
| 223 | - Q_PROPERTY(bool baseFormatSave READ get_baseFormatSave WRITE set_baseFormatSave RESET reset_baseFormatSave STORED false) | |
| 75 | + // Training parameters | |
| 76 | + Q_PROPERTY(QString vecFile READ get_vecFile WRITE set_vecFile RESET reset_vecFile STORED false) | |
| 77 | + Q_PROPERTY(QString negFile READ get_negFile WRITE set_negFile RESET reset_negFile STORED false) | |
| 78 | + Q_PROPERTY(int numPos READ get_numPos WRITE set_numPos RESET reset_numPos STORED false) | |
| 79 | + Q_PROPERTY(int numNeg READ get_numNeg WRITE set_numNeg RESET reset_numNeg STORED false) | |
| 224 | 80 | |
| 225 | 81 | BR_PROPERTY(QString, model, "FrontalFace") |
| 226 | 82 | BR_PROPERTY(int, minSize, 64) |
| 227 | 83 | BR_PROPERTY(int, minNeighbors, 5) |
| 228 | - BR_PROPERTY(bool, ROCMode, false) | |
| 229 | - | |
| 230 | - // Training parameters - Default values provided trigger OpenCV defaults | |
| 231 | - BR_PROPERTY(int, numStages, -1) | |
| 232 | - BR_PROPERTY(int, w, -1) | |
| 233 | - BR_PROPERTY(int, h, -1) | |
| 234 | - BR_PROPERTY(int, numPos, -1) | |
| 235 | - BR_PROPERTY(int, numNeg, -1) | |
| 236 | - BR_PROPERTY(int, precalcValBufSize, -1) | |
| 237 | - BR_PROPERTY(int, precalcIdxBufSize, -1) | |
| 238 | - BR_PROPERTY(double, minHitRate, -1) | |
| 239 | - BR_PROPERTY(double, maxFalseAlarmRate, -1) | |
| 240 | - BR_PROPERTY(double, weightTrimRate, -1) | |
| 241 | - BR_PROPERTY(int, maxDepth, -1) | |
| 242 | - BR_PROPERTY(int, maxWeakCount, -1) | |
| 243 | - BR_PROPERTY(QString, stageType, "") | |
| 244 | - BR_PROPERTY(QString, featureType, "") | |
| 245 | - BR_PROPERTY(QString, bt, "") | |
| 246 | - BR_PROPERTY(QString, mode, "") | |
| 247 | - BR_PROPERTY(bool, show, false) | |
| 248 | - BR_PROPERTY(bool, baseFormatSave, false) | |
| 84 | + BR_PROPERTY(bool, ROCMode, false) | |
| 85 | + | |
| 86 | + BR_PROPERTY(QString, vecFile, "data.vec") | |
| 87 | + BR_PROPERTY(QString, negFile, "neg.txt") | |
| 88 | + BR_PROPERTY(int, numPos, 1000) | |
| 89 | + BR_PROPERTY(int, numNeg, 1000) | |
| 249 | 90 | |
| 250 | 91 | Resource<CascadeClassifier> cascadeResource; |
| 251 | 92 | |
| ... | ... | @@ -259,115 +100,19 @@ class CascadeTransform : public MetaTransform |
| 259 | 100 | // Train transform |
| 260 | 101 | void train(const TemplateList& data) |
| 261 | 102 | { |
| 262 | - // Don't train if we're using OpenCV's prebuilt cascades | |
| 263 | - if (model == "Ear" || model == "Eye" || model == "FrontalFace" || model == "ProfileFace") | |
| 264 | - return; | |
| 265 | - | |
| 266 | - // Open positive and negative list temporary files | |
| 267 | - QTemporaryFile posFile; | |
| 268 | - QTemporaryFile negFile; | |
| 269 | - | |
| 270 | - posFile.open(); | |
| 271 | - negFile.open(); | |
| 272 | - | |
| 273 | - QTextStream posStream(&posFile); | |
| 274 | - QTextStream negStream(&negFile); | |
| 275 | - | |
| 276 | - TrainParams params; | |
| 277 | - | |
| 278 | - // Fill in from params (param defaults are same as struct defaults, so no checks are needed) | |
| 279 | - params.numStages = numStages; | |
| 280 | - params.w = w; | |
| 281 | - params.h = h; | |
| 282 | - params.numPos = numPos; | |
| 283 | - params.numNeg = numNeg; | |
| 284 | - params.precalcValBufSize = precalcValBufSize; | |
| 285 | - params.precalcIdxBufSize = precalcIdxBufSize; | |
| 286 | - params.minHitRate = minHitRate; | |
| 287 | - params.maxFalseAlarmRate = maxFalseAlarmRate; | |
| 288 | - params.weightTrimRate = weightTrimRate; | |
| 289 | - params.maxDepth = maxDepth; | |
| 290 | - params.maxWeakCount = maxWeakCount; | |
| 291 | - params.stageType = stageType; | |
| 292 | - params.featureType = featureType; | |
| 293 | - params.bt = bt; | |
| 294 | - params.mode = mode; | |
| 295 | - params.show = show; | |
| 296 | - params.baseFormatSave = baseFormatSave; | |
| 297 | - if (params.w < 0) params.w = minSize; | |
| 298 | - if (params.h < 0) params.h = minSize; | |
| 299 | - | |
| 300 | - int posCount = 0; | |
| 301 | - int negCount = 0; | |
| 302 | - | |
| 303 | - bool buildPos = false; // If true, build positive vector from single image | |
| 304 | - | |
| 305 | - const FileList files = data.files(); | |
| 306 | - | |
| 307 | - for (int i = 0; i < files.length(); i++) { | |
| 308 | - File f = files[i]; | |
| 309 | - if (f.contains("training-set")) { | |
| 310 | - QString tset = f.get<QString>("training-set",QString()).toLower(); | |
| 311 | - | |
| 312 | - // Negative samples | |
| 313 | - if (tset == "neg") { | |
| 314 | - negStream << f.path() << QDir::separator() << f.fileName() << endl; | |
| 315 | - negCount++; | |
| 316 | - // Positive samples for crop/rescale | |
| 317 | - } else if (tset == "pos") { | |
| 318 | - QString buffer = ""; | |
| 319 | - | |
| 320 | - // Extract rectangles | |
| 321 | - QList<QRectF> rects = f.rects(); | |
| 322 | - for (int j = 0; j < rects.size(); j++) { | |
| 323 | - QRectF r = rects[j]; | |
| 324 | - buffer += " " + QString::number(r.x()) + " " + QString::number(r.y()) + " " + QString::number(r.width()) + " "+ QString::number(r.height()); | |
| 325 | - posCount++; | |
| 326 | - } | |
| 327 | - | |
| 328 | - posStream << f.path() << QDir::separator() << f.fileName() << " " << f.rects().length() << " " << buffer << endl; | |
| 329 | - | |
| 330 | - // Single positive sample for background removal and overlay on negatives | |
| 331 | - } else if (tset == "pos-base") { | |
| 332 | - buildPos = true; | |
| 333 | - params.img = f.path() + QDir::separator() + f.fileName(); | |
| 334 | - | |
| 335 | - // Parse settings (unique to this one tag) | |
| 336 | - if (f.contains("num")) params.num = f.get<int>("num"); | |
| 337 | - if (f.contains("bgcolor")) params.bgcolor = f.get<int>("bgcolor"); | |
| 338 | - if (f.contains("bgthresh")) params.bgthresh =f.get<int>("bgthresh"); | |
| 339 | - if (f.contains("inv")) params.inv = f.get<bool>("inv",false); | |
| 340 | - if (f.contains("randinv")) params.randinv = f.get<bool>("randinv",false); | |
| 341 | - if (f.contains("maxidev")) params.maxidev = f.get<int>("maxidev"); | |
| 342 | - if (f.contains("maxxangle")) params.maxxangle = f.get<double>("maxxangle"); | |
| 343 | - if (f.contains("maxyangle")) params.maxyangle = f.get<double>("maxyangle"); | |
| 344 | - if (f.contains("maxzangle")) params.maxzangle = f.get<double>("maxzangle"); | |
| 345 | - } | |
| 346 | - } | |
| 347 | - } | |
| 348 | - | |
| 349 | - posFile.close(); | |
| 350 | - negFile.close(); | |
| 351 | - | |
| 352 | - // Fill in remaining params conditionally | |
| 353 | - if (buildPos) { | |
| 354 | - if (params.numPos < 0) { | |
| 355 | - if (params.num > 0) params.numPos = params.num*.95; | |
| 356 | - else params.numPos = 950; | |
| 357 | - } | |
| 358 | - } else { | |
| 359 | - params.info = posFile.fileName(); | |
| 360 | - if (params.numPos < 0) params.numPos = posCount*.95; | |
| 361 | - } | |
| 103 | + (void)data; | |
| 362 | 104 | |
| 363 | - if (params.num < 0) params.num = posCount; | |
| 364 | - if (params.numNeg < 0) params.numNeg = negCount*10; | |
| 105 | + BrCascadeClassifier classifier; | |
| 365 | 106 | |
| 366 | - params.bg = negFile.fileName(); | |
| 367 | - params.data = Globals->sdkPath + "/share/openbr/models/openbrcascades/" + model + "/cascade.xml"; | |
| 107 | + CascadeParams cascadeParams(CascadeParams::BOOST, FeatureParams::LBP); | |
| 108 | + CascadeBoostParams stageParams(CvBoost::GENTLE, 0.999, 0.5, 0.95, 1, 200); | |
| 109 | + LBPFeatureParams featureParams; | |
| 368 | 110 | |
| 369 | - genSamples(params); | |
| 370 | - trainCascade(params); | |
| 111 | + QString cascadeDir = Globals->sdkPath + "/share/openbr/models/openbrcascades/" + model; | |
| 112 | + classifier.train(cascadeDir.toStdString(), | |
| 113 | + vecFile.toStdString(), negFile.toStdString(), | |
| 114 | + numPos, numNeg, 1024, 1024, 20, | |
| 115 | + cascadeParams, featureParams, stageParams); | |
| 371 | 116 | } |
| 372 | 117 | |
| 373 | 118 | void project(const Template &src, Template &dst) const | ... | ... |