Commit fe840eb4701968fede4cd0d0dbbd276f6f3b2328

Authored by Josh Klontz
1 parent 032988c3

work towards jit-ing preallocate function

Showing 1 changed file with 45 additions and 4 deletions
sdk/plugins/llvm.cpp
... ... @@ -257,6 +257,7 @@ class UnaryKernel : public UntrainableMetaTransform
257 257 public:
258 258 UnaryKernel() : kernel(NULL), hash(0) {}
259 259 virtual int preallocate(const jit_matrix &src, jit_matrix &dst) const = 0; /*!< Preallocate destintation matrix based on source matrix. */
  260 + virtual Value *buildPreallocate(const MatrixBuilder &src, const MatrixBuilder &dst) const { (void) src; (void) dst; return MatrixBuilder::constant(0); }
260 261 virtual void build(const MatrixBuilder &src, const MatrixBuilder &dst, PHINode *i) const = 0; /*!< Build the kernel. */
261 262  
262 263 void apply(const jit_matrix &src, jit_matrix &dst) const
... ... @@ -267,17 +268,50 @@ public:
267 268 }
268 269  
269 270 private:
270   - QString mangledName(const jit_matrix &src) const
  271 + QString mangledName() const
271 272 {
272 273 static QHash<QString, int> argsLUT;
273 274 const QString args = arguments().join(",");
274 275 if (!argsLUT.contains(args)) argsLUT.insert(args, argsLUT.size());
275 276 int uid = argsLUT.value(args);
276   - return "jitcv_" + name().remove("Transform") + (args.isEmpty() ? QString() : QString::number(uid)) + "_" + MatrixToString(src);
  277 + return "jitcv_" + name().remove("Transform") + (args.isEmpty() ? QString() : QString::number(uid));
  278 + }
  279 +
  280 + QString mangledName(const jit_matrix &src) const
  281 + {
  282 + return mangledName() + "_" + MatrixToString(src);
277 283 }
278 284  
279 285 Function *compile(const jit_matrix &m) const
280 286 {
  287 + Constant *c = TheModule->getOrInsertFunction(qPrintable(mangledName()),
  288 + Type::getVoidTy(getGlobalContext()),
  289 + PointerType::getUnqual(TheMatrixStruct),
  290 + PointerType::getUnqual(TheMatrixStruct),
  291 + NULL);
  292 +
  293 + Function *function = cast<Function>(c);
  294 + function->setCallingConv(CallingConv::C);
  295 +
  296 + Function::arg_iterator args = function->arg_begin();
  297 + Value *src = args++;
  298 + src->setName("src");
  299 + Value *dst = args++;
  300 + dst->setName("dst");
  301 +
  302 + BasicBlock *entry = BasicBlock::Create(getGlobalContext(), "entry", function);
  303 + IRBuilder<> builder(entry);
  304 +
  305 + Function *kernel = compileKernel(m);
  306 + builder.CreateCall3(kernel, src, dst, buildPreallocate(MatrixBuilder(m, src, &builder, function, "src"), MatrixBuilder(m, dst, &builder, function, "dst")));
  307 +
  308 + builder.CreateRetVoid();
  309 +
  310 + return kernel;
  311 + }
  312 +
  313 + Function *compileKernel(const jit_matrix &m) const
  314 + {
281 315 Constant *c = TheModule->getOrInsertFunction(qPrintable(mangledName(m)),
282 316 Type::getVoidTy(getGlobalContext()),
283 317 PointerType::getUnqual(TheMatrixStruct),
... ... @@ -595,6 +629,13 @@ class sumTransform : public UnaryKernel
595 629 return dst.elements();
596 630 }
597 631  
  632 + Value *buildPreallocate(const MatrixBuilder &src, const MatrixBuilder &dst) const
  633 + {
  634 + (void) src;
  635 + (void) dst;
  636 + return MatrixBuilder::constant(0);
  637 + }
  638 +
598 639 void build(const MatrixBuilder &src, const MatrixBuilder &dst, PHINode *i) const
599 640 {
600 641 Value *c, *x, *y, *t;
... ... @@ -855,7 +896,7 @@ class LLVMInitializer : public Initializer
855 896 TheFunctionPassManager->add(createDeadInstEliminationPass());
856 897  
857 898 TheExtraFunctionPassManager = new FunctionPassManager(TheModule);
858   - TheExtraFunctionPassManager->add(createPrintFunctionPass("--------------------------------------------------------------------------------", &errs()));
  899 +// TheExtraFunctionPassManager->add(createPrintFunctionPass("--------------------------------------------------------------------------------", &errs()));
859 900 // TheExtraFunctionPassManager->add(createLoopUnrollPass(INT_MAX,8));
860 901  
861 902 TheMatrixStruct = StructType::create("Matrix",
... ... @@ -867,7 +908,7 @@ class LLVMInitializer : public Initializer
867 908 Type::getInt16Ty(getGlobalContext()), // hash
868 909 NULL);
869 910  
870   - QSharedPointer<Transform> kernel(Transform::make("abs", NULL));
  911 + QSharedPointer<Transform> kernel(Transform::make("sum", NULL));
871 912  
872 913 Template src, dst;
873 914 src.m() = (Mat_<qint8>(2,2) << -1, -2, 3, 4);
... ...