Commit fe840eb4701968fede4cd0d0dbbd276f6f3b2328
1 parent
032988c3
work towards jit-ing preallocate function
Showing
1 changed file
with
45 additions
and
4 deletions
sdk/plugins/llvm.cpp
| ... | ... | @@ -257,6 +257,7 @@ class UnaryKernel : public UntrainableMetaTransform |
| 257 | 257 | public: |
| 258 | 258 | UnaryKernel() : kernel(NULL), hash(0) {} |
| 259 | 259 | virtual int preallocate(const jit_matrix &src, jit_matrix &dst) const = 0; /*!< Preallocate destintation matrix based on source matrix. */ |
| 260 | + virtual Value *buildPreallocate(const MatrixBuilder &src, const MatrixBuilder &dst) const { (void) src; (void) dst; return MatrixBuilder::constant(0); } | |
| 260 | 261 | virtual void build(const MatrixBuilder &src, const MatrixBuilder &dst, PHINode *i) const = 0; /*!< Build the kernel. */ |
| 261 | 262 | |
| 262 | 263 | void apply(const jit_matrix &src, jit_matrix &dst) const |
| ... | ... | @@ -267,17 +268,50 @@ public: |
| 267 | 268 | } |
| 268 | 269 | |
| 269 | 270 | private: |
| 270 | - QString mangledName(const jit_matrix &src) const | |
| 271 | + QString mangledName() const | |
| 271 | 272 | { |
| 272 | 273 | static QHash<QString, int> argsLUT; |
| 273 | 274 | const QString args = arguments().join(","); |
| 274 | 275 | if (!argsLUT.contains(args)) argsLUT.insert(args, argsLUT.size()); |
| 275 | 276 | int uid = argsLUT.value(args); |
| 276 | - return "jitcv_" + name().remove("Transform") + (args.isEmpty() ? QString() : QString::number(uid)) + "_" + MatrixToString(src); | |
| 277 | + return "jitcv_" + name().remove("Transform") + (args.isEmpty() ? QString() : QString::number(uid)); | |
| 278 | + } | |
| 279 | + | |
| 280 | + QString mangledName(const jit_matrix &src) const | |
| 281 | + { | |
| 282 | + return mangledName() + "_" + MatrixToString(src); | |
| 277 | 283 | } |
| 278 | 284 | |
| 279 | 285 | Function *compile(const jit_matrix &m) const |
| 280 | 286 | { |
| 287 | + Constant *c = TheModule->getOrInsertFunction(qPrintable(mangledName()), | |
| 288 | + Type::getVoidTy(getGlobalContext()), | |
| 289 | + PointerType::getUnqual(TheMatrixStruct), | |
| 290 | + PointerType::getUnqual(TheMatrixStruct), | |
| 291 | + NULL); | |
| 292 | + | |
| 293 | + Function *function = cast<Function>(c); | |
| 294 | + function->setCallingConv(CallingConv::C); | |
| 295 | + | |
| 296 | + Function::arg_iterator args = function->arg_begin(); | |
| 297 | + Value *src = args++; | |
| 298 | + src->setName("src"); | |
| 299 | + Value *dst = args++; | |
| 300 | + dst->setName("dst"); | |
| 301 | + | |
| 302 | + BasicBlock *entry = BasicBlock::Create(getGlobalContext(), "entry", function); | |
| 303 | + IRBuilder<> builder(entry); | |
| 304 | + | |
| 305 | + Function *kernel = compileKernel(m); | |
| 306 | + builder.CreateCall3(kernel, src, dst, buildPreallocate(MatrixBuilder(m, src, &builder, function, "src"), MatrixBuilder(m, dst, &builder, function, "dst"))); | |
| 307 | + | |
| 308 | + builder.CreateRetVoid(); | |
| 309 | + | |
| 310 | + return kernel; | |
| 311 | + } | |
| 312 | + | |
| 313 | + Function *compileKernel(const jit_matrix &m) const | |
| 314 | + { | |
| 281 | 315 | Constant *c = TheModule->getOrInsertFunction(qPrintable(mangledName(m)), |
| 282 | 316 | Type::getVoidTy(getGlobalContext()), |
| 283 | 317 | PointerType::getUnqual(TheMatrixStruct), |
| ... | ... | @@ -595,6 +629,13 @@ class sumTransform : public UnaryKernel |
| 595 | 629 | return dst.elements(); |
| 596 | 630 | } |
| 597 | 631 | |
| 632 | + Value *buildPreallocate(const MatrixBuilder &src, const MatrixBuilder &dst) const | |
| 633 | + { | |
| 634 | + (void) src; | |
| 635 | + (void) dst; | |
| 636 | + return MatrixBuilder::constant(0); | |
| 637 | + } | |
| 638 | + | |
| 598 | 639 | void build(const MatrixBuilder &src, const MatrixBuilder &dst, PHINode *i) const |
| 599 | 640 | { |
| 600 | 641 | Value *c, *x, *y, *t; |
| ... | ... | @@ -855,7 +896,7 @@ class LLVMInitializer : public Initializer |
| 855 | 896 | TheFunctionPassManager->add(createDeadInstEliminationPass()); |
| 856 | 897 | |
| 857 | 898 | TheExtraFunctionPassManager = new FunctionPassManager(TheModule); |
| 858 | - TheExtraFunctionPassManager->add(createPrintFunctionPass("--------------------------------------------------------------------------------", &errs())); | |
| 899 | +// TheExtraFunctionPassManager->add(createPrintFunctionPass("--------------------------------------------------------------------------------", &errs())); | |
| 859 | 900 | // TheExtraFunctionPassManager->add(createLoopUnrollPass(INT_MAX,8)); |
| 860 | 901 | |
| 861 | 902 | TheMatrixStruct = StructType::create("Matrix", |
| ... | ... | @@ -867,7 +908,7 @@ class LLVMInitializer : public Initializer |
| 867 | 908 | Type::getInt16Ty(getGlobalContext()), // hash |
| 868 | 909 | NULL); |
| 869 | 910 | |
| 870 | - QSharedPointer<Transform> kernel(Transform::make("abs", NULL)); | |
| 911 | + QSharedPointer<Transform> kernel(Transform::make("sum", NULL)); | |
| 871 | 912 | |
| 872 | 913 | Template src, dst; |
| 873 | 914 | src.m() = (Mat_<qint8>(2,2) << -1, -2, 3, 4); | ... | ... |