Commit df64be5e2716b969ffe441fa79825e625c80b7d9
1 parent
4691ca21
Revert "Revert "Merge pull request #320 from biometrics/block_compression""
This reverts commit decda7bc5277e839d1bcf3e78754f2400878c969.
Showing
10 changed files
with
269 additions
and
55 deletions
openbr/core/common.cpp
| ... | ... | @@ -16,21 +16,32 @@ |
| 16 | 16 | |
| 17 | 17 | #include "common.h" |
| 18 | 18 | #include <QMutex> |
| 19 | +#include <RandomLib/Random.hpp> | |
| 19 | 20 | |
| 20 | 21 | using namespace std; |
| 21 | 22 | |
| 23 | +static RandomLib::Random g_rand; | |
| 24 | +static QMutex rngLock; | |
| 25 | + | |
| 22 | 26 | /**** GLOBAL ****/ |
| 23 | 27 | void Common::seedRNG() { |
| 24 | - static QMutex seedControl; | |
| 25 | - QMutexLocker lock(&seedControl); | |
| 28 | + QMutexLocker lock(&rngLock); | |
| 26 | 29 | |
| 27 | 30 | static bool seeded = false; |
| 28 | 31 | if (!seeded) { |
| 29 | 32 | srand(0); // We seed with 0 instead of time(NULL) to have reproducible randomness |
| 30 | 33 | seeded = true; |
| 34 | + g_rand.Reseed(0); | |
| 31 | 35 | } |
| 32 | 36 | } |
| 33 | 37 | |
| 38 | +double Common::randN() | |
| 39 | +{ | |
| 40 | + QMutexLocker lock(&rngLock); | |
| 41 | + | |
| 42 | + return g_rand.FloatN(); | |
| 43 | +} | |
| 44 | + | |
| 34 | 45 | QList<int> Common::RandSample(int n, int max, int min, bool unique) |
| 35 | 46 | { |
| 36 | 47 | QList<int> samples; samples.reserve(n); | ... | ... |
openbr/core/common.h
| ... | ... | @@ -220,6 +220,9 @@ double KernelDensityEstimation(const V<T> &vals, double x, double h) |
| 220 | 220 | return y / (vals.size() * h); |
| 221 | 221 | } |
| 222 | 222 | |
| 223 | +// Return a random number, uniformly distributed over 0,1 | |
| 224 | +double randN(); | |
| 225 | + | |
| 223 | 226 | /*! |
| 224 | 227 | * \brief Returns a vector of n integers sampled in the range <min, max]. |
| 225 | 228 | * |
| ... | ... | @@ -236,19 +239,14 @@ QList<int> RandSample(int n, const QSet<int> &values, bool unique = false); |
| 236 | 239 | template <typename T> |
| 237 | 240 | QList<int> RandSample(int n, const QList<T> &weights, bool unique = false) |
| 238 | 241 | { |
| 239 | - static bool seeded = false; | |
| 240 | - if (!seeded) { | |
| 241 | - srand(time(NULL)); | |
| 242 | - seeded = true; | |
| 243 | - } | |
| 244 | - | |
| 245 | 242 | QList<T> cdf = CumSum(weights); |
| 246 | 243 | for (int i=0; i<cdf.size(); i++) // Normalize cdf |
| 247 | 244 | cdf[i] = cdf[i] / cdf.last(); |
| 248 | 245 | |
| 249 | 246 | QList<int> samples; samples.reserve(n); |
| 250 | 247 | while (samples.size() < n) { |
| 251 | - T r = (T)rand() / (T)RAND_MAX; | |
| 248 | + T r = randN(); | |
| 249 | + | |
| 252 | 250 | for (int j=0; j<weights.size(); j++) { |
| 253 | 251 | if ((r >= cdf[j]) && (r <= cdf[j+1])) { |
| 254 | 252 | if (!unique || !samples.contains(j)) | ... | ... |
openbr/core/core.cpp
| ... | ... | @@ -110,9 +110,11 @@ struct AlgorithmCore |
| 110 | 110 | |
| 111 | 111 | void store(const QString &model) const |
| 112 | 112 | { |
| 113 | - // Create stream | |
| 114 | - QByteArray data; | |
| 115 | - QDataStream out(&data, QFile::WriteOnly); | |
| 113 | + QtUtils::BlockCompression compressedWrite; | |
| 114 | + QFile outFile(model); | |
| 115 | + compressedWrite.setBasis(&outFile); | |
| 116 | + QDataStream out(&compressedWrite); | |
| 117 | + compressedWrite.open(QFile::WriteOnly); | |
| 116 | 118 | |
| 117 | 119 | // Serialize algorithm to stream |
| 118 | 120 | transform->serialize(out); |
| ... | ... | @@ -131,18 +133,16 @@ struct AlgorithmCore |
| 131 | 133 | if (mode == TransformCompare) |
| 132 | 134 | comparison->serialize(out); |
| 133 | 135 | |
| 134 | - // Compress and save to file | |
| 135 | - QtUtils::writeFile(model, data, -1); | |
| 136 | + compressedWrite.close(); | |
| 136 | 137 | } |
| 137 | 138 | |
| 138 | 139 | void load(const QString &model) |
| 139 | 140 | { |
| 140 | - // Load from file and decompress | |
| 141 | - QByteArray data; | |
| 142 | - QtUtils::readFile(model, data, true); | |
| 143 | - | |
| 144 | - // Create stream | |
| 145 | - QDataStream in(&data, QFile::ReadOnly); | |
| 141 | + QtUtils::BlockCompression compressedRead; | |
| 142 | + QFile inFile(model); | |
| 143 | + compressedRead.setBasis(&inFile); | |
| 144 | + QDataStream in(&compressedRead); | |
| 145 | + compressedRead.open(QFile::ReadOnly); | |
| 146 | 146 | |
| 147 | 147 | // Load algorithm |
| 148 | 148 | transform = QSharedPointer<Transform>(Transform::deserialize(in)); | ... | ... |
openbr/core/qtutils.cpp
| ... | ... | @@ -500,6 +500,131 @@ QString getAbsolutePath(const QString &filename) |
| 500 | 500 | return QFileInfo(filename).absoluteFilePath(); |
| 501 | 501 | } |
| 502 | 502 | |
| 503 | +BlockCompression::BlockCompression(QIODevice *_basis) | |
| 504 | +{ | |
| 505 | + blockSize = 100000000; | |
| 506 | + setBasis(_basis); | |
| 507 | +} | |
| 508 | + | |
| 509 | +BlockCompression::BlockCompression() { blockSize = 100000000; }; | |
| 510 | + | |
| 511 | + | |
| 512 | +bool BlockCompression::open(QIODevice::OpenMode mode) | |
| 513 | +{ | |
| 514 | + this->setOpenMode(mode); | |
| 515 | + bool res = basis->open(mode); | |
| 516 | + | |
| 517 | + if (!res) | |
| 518 | + return false; | |
| 519 | + | |
| 520 | + blockReader.setDevice(basis); | |
| 521 | + blockWriter.setDevice(basis); | |
| 522 | + | |
| 523 | + if (mode & QIODevice::WriteOnly) { | |
| 524 | + precompressedBlockWriter = new QBuffer; | |
| 525 | + precompressedBlockWriter->open(QIODevice::ReadWrite); | |
| 526 | + } | |
| 527 | + else if (mode & QIODevice::ReadOnly) { | |
| 528 | + QByteArray compressedBlock; | |
| 529 | + blockReader >> compressedBlock; | |
| 530 | + | |
| 531 | + decompressedBlock = qUncompress(compressedBlock); | |
| 532 | + decompressedBlockReader.setBuffer(&decompressedBlock); | |
| 533 | + decompressedBlockReader.open(QIODevice::ReadOnly); | |
| 534 | + } | |
| 535 | + | |
| 536 | + return true; | |
| 537 | +} | |
| 538 | + | |
| 539 | +void BlockCompression::close() | |
| 540 | +{ | |
| 541 | + // flush output buffer | |
| 542 | + if ((openMode() & QIODevice::WriteOnly) && precompressedBlockWriter) { | |
| 543 | + QByteArray compressedBlock = qCompress(precompressedBlockWriter->buffer(), -1); | |
| 544 | + blockWriter << compressedBlock; | |
| 545 | + } | |
| 546 | + basis->close(); | |
| 547 | +} | |
| 548 | + | |
| 549 | +void BlockCompression::setBasis(QIODevice *_basis) | |
| 550 | +{ | |
| 551 | + basis = _basis; | |
| 552 | + blockReader.setDevice(basis); | |
| 553 | + blockWriter.setDevice(basis); | |
| 554 | +} | |
| 555 | + | |
| 556 | +// read from current decompressed block, if out of space, read and decompress another | |
| 557 | +// block from basis | |
| 558 | +qint64 BlockCompression::readData(char *data, qint64 remaining) | |
| 559 | +{ | |
| 560 | + qint64 read = 0; | |
| 561 | + while (remaining > 0) { | |
| 562 | + qint64 single_read = decompressedBlockReader.read(data, remaining); | |
| 563 | + if (single_read == -1) | |
| 564 | + qFatal("miss read"); | |
| 565 | + | |
| 566 | + remaining -= single_read; | |
| 567 | + read += single_read; | |
| 568 | + data += single_read; | |
| 569 | + | |
| 570 | + // need a new block | |
| 571 | + if (remaining > 0) { | |
| 572 | + QByteArray compressedBlock; | |
| 573 | + blockReader >> compressedBlock; | |
| 574 | + if (compressedBlock.size() == 0) { | |
| 575 | + return read; | |
| 576 | + } | |
| 577 | + decompressedBlock = qUncompress(compressedBlock); | |
| 578 | + | |
| 579 | + decompressedBlockReader.close(); | |
| 580 | + decompressedBlockReader.setBuffer(&decompressedBlock); | |
| 581 | + decompressedBlockReader.open(QIODevice::ReadOnly); | |
| 582 | + } | |
| 583 | + } | |
| 584 | + return blockReader.atEnd() && !basis->isReadable() ? -1 : read; | |
| 585 | +} | |
| 586 | + | |
| 587 | +bool BlockCompression::isSequential() const | |
| 588 | +{ | |
| 589 | + return true; | |
| 590 | +} | |
| 591 | + | |
| 592 | +qint64 BlockCompression::writeData(const char *data, qint64 remaining) | |
| 593 | +{ | |
| 594 | + qint64 written = 0; | |
| 595 | + | |
| 596 | + while (remaining > 0) { | |
| 597 | + // how much more can be put in this buffer? | |
| 598 | + qint64 capacity = blockSize - precompressedBlockWriter->pos(); | |
| 599 | + | |
| 600 | + // don't try to write beyond capacity | |
| 601 | + qint64 write_size = qMin(capacity, remaining); | |
| 602 | + | |
| 603 | + qint64 singleWrite = precompressedBlockWriter->write(data, write_size); | |
| 604 | + // ignore the error case here, we consdier basis's failure mode the real | |
| 605 | + // end case | |
| 606 | + if (singleWrite == -1) | |
| 607 | + singleWrite = 0; | |
| 608 | + | |
| 609 | + remaining -= singleWrite; | |
| 610 | + data += singleWrite; | |
| 611 | + written += singleWrite; | |
| 612 | + | |
| 613 | + if (remaining > 0) { | |
| 614 | + QByteArray compressedBlock = qCompress(precompressedBlockWriter->buffer(), -1); | |
| 615 | + | |
| 616 | + if (compressedBlock.size() != 0) | |
| 617 | + blockWriter << compressedBlock; | |
| 618 | + | |
| 619 | + delete precompressedBlockWriter; | |
| 620 | + precompressedBlockWriter = new QBuffer; | |
| 621 | + precompressedBlockWriter->open(QIODevice::ReadWrite); | |
| 622 | + } | |
| 623 | + } | |
| 624 | + return basis->isWritable() ? written : -1; | |
| 625 | +} | |
| 626 | + | |
| 627 | + | |
| 503 | 628 | |
| 504 | 629 | } // namespace QtUtils |
| 505 | 630 | ... | ... |
openbr/core/qtutils.h
| ... | ... | @@ -17,6 +17,7 @@ |
| 17 | 17 | #ifndef QTUTILS_QTUTILS_H |
| 18 | 18 | #define QTUTILS_QTUTILS_H |
| 19 | 19 | |
| 20 | +#include <QBuffer> | |
| 20 | 21 | #include <QByteArray> |
| 21 | 22 | #include <QDir> |
| 22 | 23 | #include <QFile> |
| ... | ... | @@ -93,6 +94,38 @@ namespace QtUtils |
| 93 | 94 | |
| 94 | 95 | /**** Rect Utilities ****/ |
| 95 | 96 | float overlap(const QRectF &r, const QRectF &s); |
| 97 | + | |
| 98 | + | |
| 99 | + class BlockCompression : public QIODevice | |
| 100 | + { | |
| 101 | + public: | |
| 102 | + BlockCompression(QIODevice *_basis); | |
| 103 | + BlockCompression(); | |
| 104 | + int blockSize; | |
| 105 | + QIODevice *basis; | |
| 106 | + | |
| 107 | + bool open(QIODevice::OpenMode mode); | |
| 108 | + | |
| 109 | + void close(); | |
| 110 | + | |
| 111 | + void setBasis(QIODevice *_basis); | |
| 112 | + | |
| 113 | + QDataStream blockReader; | |
| 114 | + QByteArray decompressedBlock; | |
| 115 | + QBuffer decompressedBlockReader; | |
| 116 | + | |
| 117 | + // read from current decompressed block, if out of space, read and decompress another | |
| 118 | + // block from basis | |
| 119 | + qint64 readData(char *data, qint64 remaining); | |
| 120 | + | |
| 121 | + bool isSequential() const; | |
| 122 | + | |
| 123 | + // write to a QByteArray, when max block sized is reached, compress and write | |
| 124 | + // it to basis | |
| 125 | + QBuffer * precompressedBlockWriter; | |
| 126 | + QDataStream blockWriter; | |
| 127 | + qint64 writeData(const char *data, qint64 remaining); | |
| 128 | + }; | |
| 96 | 129 | } |
| 97 | 130 | |
| 98 | 131 | #endif // QTUTILS_QTUTILS_H | ... | ... |
openbr/plugins/algorithms.cpp
| ... | ... | @@ -31,15 +31,15 @@ class AlgorithmsInitializer : public Initializer |
| 31 | 31 | void initialize() const |
| 32 | 32 | { |
| 33 | 33 | // Face |
| 34 | - Globals->abbreviations.insert("FaceRecognition", "FaceDetection+Expand+<FaceRecognitionRegistration>+Expand+<FaceRecognitionExtraction>+<FaceRecognitionEmbedding>+<FaceRecognitionQuantization>+SetMetadata(AlgorithmID,-1):MatchProbability(ByteL1)"); | |
| 35 | - Globals->abbreviations.insert("GenderClassification", "FaceDetection+Expand+<FaceClassificationRegistration>+Expand+<FaceClassificationExtraction>+<GenderClassifier>+Discard"); | |
| 36 | - Globals->abbreviations.insert("AgeRegression", "FaceDetection+Expand+<FaceClassificationRegistration>+Expand+<FaceClassificationExtraction>+<AgeRegressor>+Discard"); | |
| 34 | + Globals->abbreviations.insert("FaceRecognition", "FaceDetection+FaceRecognitionRegistration+<FaceRecognitionExtraction>+<FaceRecognitionEmbedding>+<FaceRecognitionQuantization>+SetMetadata(AlgorithmID,-1):Unit(ByteL1)"); | |
| 35 | + Globals->abbreviations.insert("GenderClassification", "FaceDetection+Expand+FaceClassificationRegistration+Expand+<FaceClassificationExtraction>+<GenderClassifier>+Discard"); | |
| 36 | + Globals->abbreviations.insert("AgeRegression", "FaceDetection+Expand+FaceClassificationRegistration+Expand+<FaceClassificationExtraction>+<AgeRegressor>+Discard"); | |
| 37 | 37 | Globals->abbreviations.insert("FaceQuality", "Open+Expand+Cascade(FrontalFace)+ASEFEyes+Affine(64,64,0.25,0.35)+ImageQuality+Cvt(Gray)+DFFS+Discard"); |
| 38 | 38 | Globals->abbreviations.insert("MedianFace", "Open+Expand+Cascade(FrontalFace)+ASEFEyes+Affine(256,256,0.37,0.45)+Center(Median)"); |
| 39 | 39 | Globals->abbreviations.insert("BlurredFaceDetection", "Open+LimitSize(1024)+SkinMask/(Cvt(Gray)+GradientMask)+And+Morph(Erode,16)+LargestConvexArea"); |
| 40 | 40 | Globals->abbreviations.insert("DrawFaceDetection", "Open+Cascade(FrontalFace)+Expand+ASEFEyes+Draw(inPlace=true)"); |
| 41 | 41 | Globals->abbreviations.insert("ShowFaceDetection", "DrawFaceDetection+Contract+First+Show+Discard"); |
| 42 | - Globals->abbreviations.insert("DownloadFaceRecognition", "Download+Open+ROI+Expand+Cvt(Gray)+Cascade(FrontalFace)+Expand+<FaceRecognitionRegistration>+Expand+<FaceRecognitionExtraction>+<FaceRecognitionEmbedding>+<FaceRecognitionQuantization>+SetMetadata(AlgorithmID,-1):MatchProbability(ByteL1)"); | |
| 42 | + Globals->abbreviations.insert("DownloadFaceRecognition", "Download+Open+ROI+Cvt(Gray)+Cascade(FrontalFace)+FaceRecognitionRegistration+<FaceRecognitionExtraction>+<FaceRecognitionEmbedding>+<FaceRecognitionQuantization>+SetMetadata(AlgorithmID,-1):Unit(ByteL1)"); | |
| 43 | 43 | Globals->abbreviations.insert("OpenBR", "FaceRecognition"); |
| 44 | 44 | Globals->abbreviations.insert("GenderEstimation", "GenderClassification"); |
| 45 | 45 | Globals->abbreviations.insert("AgeEstimation", "AgeRegression"); |
| ... | ... | @@ -50,7 +50,7 @@ class AlgorithmsInitializer : public Initializer |
| 50 | 50 | // Video |
| 51 | 51 | Globals->abbreviations.insert("DisplayVideo", "FPSLimit(30)+Show(false,[FrameNumber])+Discard"); |
| 52 | 52 | Globals->abbreviations.insert("PerFrameDetection", "SaveMat(original)+Cvt(Gray)+Cascade(FrontalFace)+ASEFEyes+RestoreMat(original)+Draw(inPlace=true)+Show(false,[FrameNumber])+Discard"); |
| 53 | - Globals->abbreviations.insert("AgeGenderDemo", "SaveMat(original)+Cvt(Gray)+Cascade(FrontalFace)+Expand+<FaceClassificationRegistration>+<FaceClassificationExtraction>+<AgeRegressor>/<GenderClassifier>+Discard+RestoreMat(original)+Draw(inPlace=true)+DrawPropertiesPoint([Age,Gender],Affine_0,inPlace=true)+SaveMat(original)+Discard+Contract+RestoreMat(original)+FPSCalc+Show(false,[AvgFPS,Age,Gender])+Discard"); | |
| 53 | + Globals->abbreviations.insert("AgeGenderDemo", "SaveMat(original)+Cvt(Gray)+Cascade(FrontalFace)+Expand+FaceClassificationRegistration+<FaceClassificationExtraction>+<AgeRegressor>/<GenderClassifier>+Discard+RestoreMat(original)+Draw(inPlace=true)+DrawPropertiesPoint([Age,Gender],Affine_0,inPlace=true)+SaveMat(original)+Discard+Contract+RestoreMat(original)+FPSCalc+Show(false,[AvgFPS,Age,Gender])+Discard"); | |
| 54 | 54 | Globals->abbreviations.insert("ShowOpticalFlowField", "SaveMat(original)+AggregateFrames(2)+OpticalFlow(useMagnitude=false)+Grid(100,100)+DrawOpticalFlow+FPSLimit(30)+Show(false)+Discard"); |
| 55 | 55 | Globals->abbreviations.insert("ShowOpticalFlowMagnitude", "AggregateFrames(2)+OpticalFlow+Normalize(Range,false,0,255)+Cvt(Color)+Draw+FPSLimit(30)+Show(false)+Discard"); |
| 56 | 56 | Globals->abbreviations.insert("ShowMotionSegmentation", "DropFrames(5)+AggregateFrames(2)+OpticalFlow+CvtUChar+WatershedSegmentation+DrawSegmentation+Draw+FPSLimit(30)+Show(false)+Discard"); |
| ... | ... | @@ -92,11 +92,11 @@ class AlgorithmsInitializer : public Initializer |
| 92 | 92 | Globals->abbreviations.insert("DenseHOG", "Gradient+RectRegions(8,8,6,6)+Bin(0,360,8)+Hist(8)"); |
| 93 | 93 | Globals->abbreviations.insert("DenseSIFT", "(Grid(10,10)+SIFTDescriptor(12)+ByRow)"); |
| 94 | 94 | Globals->abbreviations.insert("DenseSIFT2", "(Grid(5,5)+SIFTDescriptor(12)+ByRow)"); |
| 95 | - Globals->abbreviations.insert("FaceRecognitionRegistration", "(ASEFEyes+Affine(88,88,0.25,0.35)+DownsampleTraining(FTE(DFFS),instances=1))"); | |
| 95 | + Globals->abbreviations.insert("FaceRecognitionRegistration", "ASEFEyes+Affine(88,88,0.25,0.35)"); | |
| 96 | 96 | Globals->abbreviations.insert("FaceRecognitionExtraction", "(Mask+DenseSIFT/DenseLBP+DownsampleTraining(PCA(0.95),instances=1)+Normalize(L2)+Cat)"); |
| 97 | 97 | Globals->abbreviations.insert("FaceRecognitionEmbedding", "(Dup(12)+RndSubspace(0.05,1)+DownsampleTraining(LDA(0.98),instances=-2)+Cat+DownsampleTraining(PCA(768),instances=1))"); |
| 98 | 98 | Globals->abbreviations.insert("FaceRecognitionQuantization", "(Normalize(L1)+Quantize)"); |
| 99 | - Globals->abbreviations.insert("FaceClassificationRegistration", "(ASEFEyes+Affine(56,72,0.33,0.45)+FTE(DFFS))"); | |
| 99 | + Globals->abbreviations.insert("FaceClassificationRegistration", "ASEFEyes+Affine(56,72,0.33,0.45)"); | |
| 100 | 100 | Globals->abbreviations.insert("FaceClassificationExtraction", "((Grid(7,7)+SIFTDescriptor(8)+ByRow)/DenseLBP+DownsampleTraining(PCA(0.95),instances=-1, inputVariable=Gender)+Cat)"); |
| 101 | 101 | Globals->abbreviations.insert("AgeRegressor", "DownsampleTraining(Center(Range),instances=-1, inputVariable=Age)+DownsampleTraining(SVM(RBF,EPS_SVR,inputVariable=Age),instances=100, inputVariable=Age)"); |
| 102 | 102 | Globals->abbreviations.insert("GenderClassifier", "DownsampleTraining(Center(Range),instances=-1, inputVariable=Gender)+DownsampleTraining(SVM(RBF,C_SVC,inputVariable=Gender),instances=4000, inputVariable=Gender)"); | ... | ... |
openbr/plugins/cascade.cpp
| ... | ... | @@ -252,6 +252,8 @@ class CascadeTransform : public MetaTransform |
| 252 | 252 | void init() |
| 253 | 253 | { |
| 254 | 254 | cascadeResource.setResourceMaker(new CascadeResourceMaker(model)); |
| 255 | + if (model == "Ear" || model == "Eye" || model == "FrontalFace" || model == "ProfileFace") | |
| 256 | + this->trainable = false; | |
| 255 | 257 | } |
| 256 | 258 | |
| 257 | 259 | // Train transform | ... | ... |
openbr/plugins/meta.cpp
| ... | ... | @@ -17,6 +17,8 @@ |
| 17 | 17 | #include <QFutureSynchronizer> |
| 18 | 18 | #include <QRegularExpression> |
| 19 | 19 | #include <QtConcurrentRun> |
| 20 | +#include <qbuffer.h> | |
| 21 | + | |
| 20 | 22 | #include "openbr_internal.h" |
| 21 | 23 | #include "openbr/core/common.h" |
| 22 | 24 | #include "openbr/core/opencvutils.h" |
| ... | ... | @@ -94,17 +96,15 @@ class PipeTransform : public CompositeTransform |
| 94 | 96 | |
| 95 | 97 | int i = 0; |
| 96 | 98 | while (i < transforms.size()) { |
| 97 | - fprintf(stderr, "\n%s", qPrintable(transforms[i]->objectName())); | |
| 98 | - | |
| 99 | 99 | // Conditional statement covers likely case that first transform is untrainable |
| 100 | 100 | if (transforms[i]->trainable) { |
| 101 | - fprintf(stderr, " training..."); | |
| 101 | + qDebug() << "Training " << transforms[i]->description() << "\n..."; | |
| 102 | 102 | transforms[i]->train(dataLines); |
| 103 | 103 | } |
| 104 | 104 | |
| 105 | 105 | // if the transform is time varying, we can't project it in parallel |
| 106 | 106 | if (transforms[i]->timeVarying()) { |
| 107 | - fprintf(stderr, "\n%s projecting...", qPrintable(transforms[i]->objectName())); | |
| 107 | + qDebug() << "Projecting " << transforms[i]->description() << "\n..."; | |
| 108 | 108 | for (int j=0; j < dataLines.size();j++) { |
| 109 | 109 | TemplateList junk; |
| 110 | 110 | splitFTEs(dataLines[j], junk); |
| ... | ... | @@ -130,7 +130,16 @@ class PipeTransform : public CompositeTransform |
| 130 | 130 | !transforms[nextTrainableTransform]->timeVarying()) |
| 131 | 131 | nextTrainableTransform++; |
| 132 | 132 | |
| 133 | - fprintf(stderr, " projecting..."); | |
| 133 | + // No more trainable transforms? Don't need any more projects then | |
| 134 | + if (nextTrainableTransform == transforms.size()) | |
| 135 | + break; | |
| 136 | + | |
| 137 | + fprintf(stderr, "Projecting %s", qPrintable(transforms[i]->description())); | |
| 138 | + for (int j=i+1; j < nextTrainableTransform; j++) | |
| 139 | + fprintf(stderr,"+%s", qPrintable(transforms[j]->description())); | |
| 140 | + fprintf(stderr, "\n...\n"); | |
| 141 | + fflush(stderr); | |
| 142 | + | |
| 134 | 143 | QFutureSynchronizer<void> futures; |
| 135 | 144 | for (int j=0; j < dataLines.size(); j++) |
| 136 | 145 | futures.addFuture(QtConcurrent::run(this, &PipeTransform::_projectPartial, &dataLines[j], i, nextTrainableTransform)); |
| ... | ... | @@ -510,7 +519,6 @@ class LoadStoreTransform : public MetaTransform |
| 510 | 519 | |
| 511 | 520 | public: |
| 512 | 521 | Transform *transform; |
| 513 | - QString baseName; | |
| 514 | 522 | |
| 515 | 523 | LoadStoreTransform() : transform(NULL) {} |
| 516 | 524 | |
| ... | ... | @@ -540,8 +548,8 @@ private: |
| 540 | 548 | void init() |
| 541 | 549 | { |
| 542 | 550 | if (transform != NULL) return; |
| 543 | - if (fileName.isEmpty()) baseName = QRegExp("^[_a-zA-Z0-9]+$").exactMatch(transformString) ? transformString : QtUtils::shortTextHash(transformString); | |
| 544 | - else baseName = fileName; | |
| 551 | + if (fileName.isEmpty()) fileName = QRegExp("^[_a-zA-Z0-9]+$").exactMatch(transformString) ? transformString : QtUtils::shortTextHash(transformString); | |
| 552 | + | |
| 545 | 553 | if (!tryLoad()) |
| 546 | 554 | transform = make(transformString); |
| 547 | 555 | else |
| ... | ... | @@ -553,19 +561,28 @@ private: |
| 553 | 561 | return transform->timeVarying(); |
| 554 | 562 | } |
| 555 | 563 | |
| 556 | - void train(const TemplateList &data) | |
| 564 | + void train(const QList<TemplateList> &data) | |
| 557 | 565 | { |
| 558 | 566 | if (QFileInfo(getFileName()).exists()) |
| 559 | 567 | return; |
| 560 | 568 | |
| 561 | 569 | transform->train(data); |
| 562 | 570 | |
| 563 | - qDebug("Storing %s", qPrintable(baseName)); | |
| 564 | - QByteArray byteArray; | |
| 565 | - QDataStream stream(&byteArray, QFile::WriteOnly); | |
| 566 | - stream << transform->description(); | |
| 571 | + qDebug("Storing %s", qPrintable(fileName)); | |
| 572 | + QtUtils::BlockCompression compressedOut; | |
| 573 | + QFile fout(fileName); | |
| 574 | + QtUtils::touchDir(fout); | |
| 575 | + compressedOut.setBasis(&fout); | |
| 576 | + | |
| 577 | + QDataStream stream(&compressedOut); | |
| 578 | + QString desc = transform->description(); | |
| 579 | + | |
| 580 | + if (!compressedOut.open(QFile::WriteOnly)) | |
| 581 | + qFatal("Failed to open %s for writing.", qPrintable(file)); | |
| 582 | + | |
| 583 | + stream << desc; | |
| 567 | 584 | transform->store(stream); |
| 568 | - QtUtils::writeFile(baseName, byteArray, -1); | |
| 585 | + compressedOut.close(); | |
| 569 | 586 | } |
| 570 | 587 | |
| 571 | 588 | void project(const Template &src, Template &dst) const |
| ... | ... | @@ -595,8 +612,8 @@ private: |
| 595 | 612 | |
| 596 | 613 | QString getFileName() const |
| 597 | 614 | { |
| 598 | - if (QFileInfo(baseName).exists()) return baseName; | |
| 599 | - const QString file = Globals->sdkPath + "/share/openbr/models/transforms/" + baseName; | |
| 615 | + if (QFileInfo(fileName).exists()) return fileName; | |
| 616 | + const QString file = Globals->sdkPath + "/share/openbr/models/transforms/" + fileName; | |
| 600 | 617 | return QFileInfo(file).exists() ? file : QString(); |
| 601 | 618 | } |
| 602 | 619 | |
| ... | ... | @@ -606,12 +623,19 @@ private: |
| 606 | 623 | if (file.isEmpty()) return false; |
| 607 | 624 | |
| 608 | 625 | qDebug("Loading %s", qPrintable(file)); |
| 609 | - QByteArray data; | |
| 610 | - QtUtils::readFile(file, data, true); | |
| 611 | - QDataStream stream(&data, QFile::ReadOnly); | |
| 626 | + QFile fin(file); | |
| 627 | + QtUtils::BlockCompression reader(&fin); | |
| 628 | + if (!reader.open(QIODevice::ReadOnly)) { | |
| 629 | + if (QFileInfo(file).exists()) qFatal("Unable to open %s for reading. Check file permissions.", qPrintable(file)); | |
| 630 | + else qFatal("Unable to open %s for reading. File does not exist.", qPrintable(file)); | |
| 631 | + } | |
| 632 | + | |
| 633 | + QDataStream stream(&reader); | |
| 612 | 634 | stream >> transformString; |
| 635 | + | |
| 613 | 636 | transform = Transform::make(transformString); |
| 614 | 637 | transform->load(stream); |
| 638 | + | |
| 615 | 639 | return true; |
| 616 | 640 | } |
| 617 | 641 | }; | ... | ... |
openbr/plugins/quality.cpp
| ... | ... | @@ -77,6 +77,12 @@ class ImpostorUniquenessMeasureTransform : public Transform |
| 77 | 77 | |
| 78 | 78 | BR_REGISTER(Transform, ImpostorUniquenessMeasureTransform) |
| 79 | 79 | |
| 80 | + | |
| 81 | +float KDEPointer(const QList<float> *scores, double x, double h) | |
| 82 | +{ | |
| 83 | + return Common::KernelDensityEstimation(*scores, x, h); | |
| 84 | +} | |
| 85 | + | |
| 80 | 86 | /* Kernel Density Estimator */ |
| 81 | 87 | struct KDE |
| 82 | 88 | { |
| ... | ... | @@ -85,20 +91,35 @@ struct KDE |
| 85 | 91 | QList<float> bins; |
| 86 | 92 | |
| 87 | 93 | KDE() : min(0), max(1), mean(0), stddev(1) {} |
| 88 | - KDE(const QList<float> &scores) | |
| 94 | + | |
| 95 | + KDE(const QList<float> &scores, bool trainKDE) | |
| 89 | 96 | { |
| 90 | 97 | Common::MinMax(scores, &min, &max); |
| 91 | 98 | Common::MeanStdDev(scores, &mean, &stddev); |
| 99 | + | |
| 100 | + if (!trainKDE) | |
| 101 | + return; | |
| 102 | + | |
| 92 | 103 | double h = Common::KernelDensityBandwidth(scores); |
| 93 | 104 | const int size = 255; |
| 94 | 105 | bins.reserve(size); |
| 95 | - for (int i=0; i<size; i++) | |
| 96 | - bins.append(Common::KernelDensityEstimation(scores, min + (max-min)*i/(size-1), h)); | |
| 106 | + | |
| 107 | + QFutureSynchronizer<float> futures; | |
| 108 | + | |
| 109 | + for (int i=0; i < size; i++) | |
| 110 | + futures.addFuture(QtConcurrent::run(KDEPointer, &scores, min + (max-min)*i/(size-1), h)); | |
| 111 | + futures.waitForFinished(); | |
| 112 | + | |
| 113 | + foreach(const QFuture<float> & future, futures.futures()) | |
| 114 | + bins.append(future.result()); | |
| 97 | 115 | } |
| 98 | 116 | |
| 99 | 117 | float operator()(float score, bool gaussian = true) const |
| 100 | 118 | { |
| 101 | 119 | if (gaussian) return 1/(stddev*sqrt(2*CV_PI))*exp(-0.5*pow((score-mean)/stddev, 2)); |
| 120 | + if (bins.empty()) | |
| 121 | + return -std::numeric_limits<float>::max(); | |
| 122 | + | |
| 102 | 123 | if (score <= min) return bins.first(); |
| 103 | 124 | if (score >= max) return bins.last(); |
| 104 | 125 | const float x = (score-min)/(max-min)*bins.size(); |
| ... | ... | @@ -123,8 +144,8 @@ struct MP |
| 123 | 144 | { |
| 124 | 145 | KDE genuine, impostor; |
| 125 | 146 | MP() {} |
| 126 | - MP(const QList<float> &genuineScores, const QList<float> &impostorScores) | |
| 127 | - : genuine(genuineScores), impostor(impostorScores) {} | |
| 147 | + MP(const QList<float> &genuineScores, const QList<float> &impostorScores, bool trainKDE) | |
| 148 | + : genuine(genuineScores, trainKDE), impostor(impostorScores, trainKDE) {} | |
| 128 | 149 | float operator()(float score, bool gaussian = true) const |
| 129 | 150 | { |
| 130 | 151 | const float g = genuine(score, gaussian); |
| ... | ... | @@ -165,7 +186,7 @@ class MatchProbabilityDistance : public Distance |
| 165 | 186 | const QList<int> labels = src.indexProperty(inputVariable); |
| 166 | 187 | QScopedPointer<MatrixOutput> matrixOutput(MatrixOutput::make(FileList(src.size()), FileList(src.size()))); |
| 167 | 188 | distance->compare(src, src, matrixOutput.data()); |
| 168 | - | |
| 189 | + | |
| 169 | 190 | QList<float> genuineScores, impostorScores; |
| 170 | 191 | genuineScores.reserve(labels.size()); |
| 171 | 192 | impostorScores.reserve(labels.size()*labels.size()); |
| ... | ... | @@ -178,8 +199,8 @@ class MatchProbabilityDistance : public Distance |
| 178 | 199 | else impostorScores.append(score); |
| 179 | 200 | } |
| 180 | 201 | } |
| 181 | - | |
| 182 | - mp = MP(genuineScores, impostorScores); | |
| 202 | + | |
| 203 | + mp = MP(genuineScores, impostorScores, !gaussian); | |
| 183 | 204 | } |
| 184 | 205 | |
| 185 | 206 | float compare(const Template &target, const Template &query) const | ... | ... |