Commit 8bde0803a299150989d7ebe2bc682ff1c0583423

Authored by Josh Klontz
1 parent e9030443

substantial llvm progress

sdk/plugins/llvm.cpp
... ... @@ -99,19 +99,52 @@ struct MatrixBuilder : public jit_matrix
99 99 MatrixBuilder(const jit_matrix &matrix, Value *value, IRBuilder<> *builder, Function *function, const Twine &name_)
100 100 : jit_matrix(matrix), m(value), b(builder), f(function), name(name_) {}
101 101  
102   - static Value *zero() { return constant(0); }
103   - static Value *one() { return constant(1); }
104   - static Value *constant(int value, int bits = 32) { return Constant::getIntegerValue(Type::getInt32Ty(getGlobalContext()), APInt(bits, value)); }
105   - static Value *constant(float value) { return ConstantFP::get(Type::getFloatTy(getGlobalContext()), value == 0 ? -0.0f : value); }
106   - static Value *constant(double value) { return ConstantFP::get(Type::getDoubleTy(getGlobalContext()), value == 0 ? -0.0 : value); }
107   - Value *autoConstant(double value) const { return isFloating() ? ((bits() == 64) ? constant(value) : constant(float(value))) : constant(int(value), bits()); }
  102 + static Constant *zero() { return constant(0); }
  103 + static Constant *one() { return constant(1); }
  104 + static Constant *constant(int value, int bits = 32) { return Constant::getIntegerValue(Type::getInt32Ty(getGlobalContext()), APInt(bits, value)); }
  105 + static Constant *constant(float value) { return ConstantFP::get(Type::getFloatTy(getGlobalContext()), value == 0 ? -0.0f : value); }
  106 + static Constant *constant(double value) { return ConstantFP::get(Type::getDoubleTy(getGlobalContext()), value == 0 ? -0.0 : value); }
  107 + Constant *autoConstant(double value) const { return isFloating() ? ((bits() == 64) ? constant(value) : constant(float(value))) : constant(int(value), bits()); }
108 108 AllocaInst *autoAlloca(double value, const Twine &name = "") const { AllocaInst *alloca = b->CreateAlloca(ty(), 0, name); b->CreateStore(autoConstant(value), alloca); return alloca; }
109 109  
110   - Value *getData() const { return b->CreatePointerCast(b->CreateLoad(b->CreateStructGEP(m, 0)), ptrTy(), name+"_data"); }
111   - Value *getChannels() const { return singleChannel() ? one() : b->CreateLoad(b->CreateStructGEP(m, 1), name+"_channels"); }
112   - Value *getColumns() const { return singleColumn() ? one() : b->CreateLoad(b->CreateStructGEP(m, 2), name+"_columns"); }
113   - Value *getRows() const { return singleRow() ? one() : b->CreateLoad(b->CreateStructGEP(m, 3), name+"_rows"); }
114   - Value *getFrames() const { return singleFrame() ? one() : b->CreateLoad(b->CreateStructGEP(m, 4), name+"_frames"); }
  110 + Value *getData(bool cast = true) const { LoadInst *data = b->CreateLoad(b->CreateStructGEP(m, 0), name+"_data"); return cast ? b->CreatePointerCast(data, ptrTy()) : data; }
  111 + Value *getChannels() const { return singleChannel() ? static_cast<Value*>(one()) : static_cast<Value*>(b->CreateLoad(b->CreateStructGEP(m, 1), name+"_channels")); }
  112 + Value *getColumns() const { return singleColumn() ? static_cast<Value*>(one()) : static_cast<Value*>(b->CreateLoad(b->CreateStructGEP(m, 2), name+"_columns")); }
  113 + Value *getRows() const { return singleRow() ? static_cast<Value*>(one()) : static_cast<Value*>(b->CreateLoad(b->CreateStructGEP(m, 3), name+"_rows")); }
  114 + Value *getFrames() const { return singleFrame() ? static_cast<Value*>(one()) : static_cast<Value*>(b->CreateLoad(b->CreateStructGEP(m, 4), name+"_frames")); }
  115 + Value *getHash() const { return b->CreateLoad(b->CreateStructGEP(m, 5), name+"_hash"); }
  116 +
  117 + void setData(Value *value) const { b->CreateStore(value, b->CreateStructGEP(m, 0)); }
  118 + void setChannels(Value *value) const { b->CreateStore(value, b->CreateStructGEP(m, 1)); }
  119 + void setColumns(Value *value) const { b->CreateStore(value, b->CreateStructGEP(m, 2)); }
  120 + void setRows(Value *value) const { b->CreateStore(value, b->CreateStructGEP(m, 3)); }
  121 + void setFrames(Value *value) const { b->CreateStore(value, b->CreateStructGEP(m, 4)); }
  122 + void setHash(Value *value) const { b->CreateStore(value, b->CreateStructGEP(m, 5)); }
  123 +
  124 + void copyHeaderCode(const MatrixBuilder &other) const {
  125 + setChannels(other.getChannels());
  126 + setRows(other.getRows());
  127 + setFrames(other.getFrames());
  128 + setHash(other.getHash());
  129 + }
  130 +
  131 + void allocate() const {
  132 + Function *malloc = TheModule->getFunction("malloc");
  133 + if (!malloc) {
  134 + PointerType *mallocReturn = Type::getInt8PtrTy(getGlobalContext());
  135 + std::vector<Type*> mallocParams;
  136 + mallocParams.push_back(Type::getInt32Ty(getGlobalContext()));
  137 + FunctionType* mallocType = FunctionType::get(mallocReturn, mallocParams, false);
  138 + malloc = Function::Create(mallocType, GlobalValue::ExternalLinkage, "malloc");
  139 + malloc->setCallingConv(CallingConv::C);
  140 + }
  141 +
  142 + std::vector<Value*> mallocArgs;
  143 + mallocArgs.push_back(elementsCode()); // TODO: FIX
  144 + setData(b->CreateCall(malloc, mallocArgs));
  145 + }
  146 +
  147 + Value *elementsCode() const { return b->CreateMul(b->CreateMul(b->CreateMul(getChannels(), getColumns()), getRows()), getFrames()); }
115 148  
116 149 Value *columnStep() const { Value *columnStep = getChannels(); columnStep->setName(name+"_cStep"); return columnStep; }
117 150 Value *rowStep() const { return b->CreateMul(getColumns(), columnStep(), name+"_rStep"); }
... ... @@ -296,6 +329,8 @@ private:
296 329  
297 330 Function *compile(const jit_matrix &m) const
298 331 {
  332 + Function *kernel = compileKernel(m);
  333 +
299 334 Constant *c = TheModule->getOrInsertFunction(qPrintable(mangledName()),
300 335 Type::getVoidTy(getGlobalContext()),
301 336 PointerType::getUnqual(TheMatrixStruct),
... ... @@ -313,10 +348,38 @@ private:
313 348  
314 349 BasicBlock *entry = BasicBlock::Create(getGlobalContext(), "entry", function);
315 350 IRBuilder<> builder(entry);
316   -
317   - Function *kernel = compileKernel(m);
318   - builder.CreateCall3(kernel, src, dst, buildPreallocate(MatrixBuilder(m, src, &builder, function, "src"), MatrixBuilder(m, dst, &builder, function, "dst")));
319   -
  351 + MatrixBuilder mb(m, src, &builder, function, "src");
  352 + MatrixBuilder nb(m, dst, &builder, function, "dst");
  353 +
  354 + std::vector<Type*> kernelArgs;
  355 + kernelArgs.push_back(PointerType::getUnqual(TheMatrixStruct));
  356 + kernelArgs.push_back(PointerType::getUnqual(TheMatrixStruct));
  357 + kernelArgs.push_back(Type::getInt32Ty(getGlobalContext()));
  358 + PointerType *kernelType = PointerType::getUnqual(FunctionType::get(Type::getVoidTy(getGlobalContext()), kernelArgs, false));
  359 + QString kernelFunctionName = mangledName()+"_kernel";
  360 + TheModule->getOrInsertGlobal(qPrintable(kernelFunctionName), kernelType);
  361 + GlobalVariable *kernelFunction = TheModule->getGlobalVariable(qPrintable(kernelFunctionName));
  362 + kernelFunction->setInitializer(ConstantPointerNull::get(kernelType));
  363 +
  364 + QString kernelHashName = mangledName()+"_hash";
  365 + TheModule->getOrInsertGlobal(qPrintable(kernelHashName), Type::getInt16Ty(getGlobalContext()));
  366 + GlobalVariable *kernelHash = TheModule->getGlobalVariable(qPrintable(kernelHashName));
  367 + kernelHash->setInitializer(MatrixBuilder::constant(0, 16));
  368 +
  369 + BasicBlock *getKernel = BasicBlock::Create(getGlobalContext(), "get_kernel", function);
  370 + BasicBlock *preallocate = BasicBlock::Create(getGlobalContext(), "preallocate", function);
  371 + Value *hashTest = builder.CreateICmpNE(mb.getHash(), kernelHash, "hash_fail_test");
  372 + builder.CreateCondBr(hashTest, getKernel, preallocate);
  373 +
  374 + builder.SetInsertPoint(getKernel);
  375 + builder.CreateStore(kernel, kernelFunction);
  376 + builder.CreateStore(mb.getHash(), kernelHash);
  377 + builder.CreateBr(preallocate);
  378 + builder.SetInsertPoint(preallocate);
  379 + Value *kernelSize = buildPreallocate(mb, nb);
  380 + nb.allocate();
  381 +
  382 + builder.CreateCall3(builder.CreateLoad(kernelFunction), src, dst, kernelSize);
320 383 builder.CreateRetVoid();
321 384  
322 385 return kernel;
... ... @@ -495,6 +558,12 @@ public:
495 558 return dst.elements();
496 559 }
497 560  
  561 + virtual Value *buildPreallocate(const MatrixBuilder &src, const MatrixBuilder &dst) const
  562 + {
  563 + dst.copyHeaderCode(src);
  564 + return dst.elementsCode();
  565 + }
  566 +
498 567 private:
499 568 void build(const MatrixBuilder &src, const MatrixBuilder &dst, PHINode *i) const
500 569 {
... ... @@ -909,7 +978,7 @@ class LLVMInitializer : public Initializer
909 978 Type::getInt16Ty(getGlobalContext()), // hash
910 979 NULL);
911 980  
912   - QSharedPointer<Transform> kernel(Transform::make("sum", NULL));
  981 + QSharedPointer<Transform> kernel(Transform::make("add(1)", NULL));
913 982  
914 983 Template src, dst;
915 984 src.m() = (Mat_<qint8>(2,2) << -1, -2, 3, 4);
... ...
1   -Subproject commit f6402dc42e513a1850ec73e675f62cbd15b26e31
  1 +Subproject commit 8f74266f98653627b04ae9a8e7f83d4dd63b4f5d
... ...