diff --git a/sdk/plugins/llvm.cpp b/sdk/plugins/llvm.cpp index 5d80ffa..18cc585 100644 --- a/sdk/plugins/llvm.cpp +++ b/sdk/plugins/llvm.cpp @@ -99,19 +99,52 @@ struct MatrixBuilder : public jit_matrix MatrixBuilder(const jit_matrix &matrix, Value *value, IRBuilder<> *builder, Function *function, const Twine &name_) : jit_matrix(matrix), m(value), b(builder), f(function), name(name_) {} - static Value *zero() { return constant(0); } - static Value *one() { return constant(1); } - static Value *constant(int value, int bits = 32) { return Constant::getIntegerValue(Type::getInt32Ty(getGlobalContext()), APInt(bits, value)); } - static Value *constant(float value) { return ConstantFP::get(Type::getFloatTy(getGlobalContext()), value == 0 ? -0.0f : value); } - static Value *constant(double value) { return ConstantFP::get(Type::getDoubleTy(getGlobalContext()), value == 0 ? -0.0 : value); } - Value *autoConstant(double value) const { return isFloating() ? ((bits() == 64) ? constant(value) : constant(float(value))) : constant(int(value), bits()); } + static Constant *zero() { return constant(0); } + static Constant *one() { return constant(1); } + static Constant *constant(int value, int bits = 32) { return Constant::getIntegerValue(Type::getInt32Ty(getGlobalContext()), APInt(bits, value)); } + static Constant *constant(float value) { return ConstantFP::get(Type::getFloatTy(getGlobalContext()), value == 0 ? -0.0f : value); } + static Constant *constant(double value) { return ConstantFP::get(Type::getDoubleTy(getGlobalContext()), value == 0 ? -0.0 : value); } + Constant *autoConstant(double value) const { return isFloating() ? ((bits() == 64) ? constant(value) : constant(float(value))) : constant(int(value), bits()); } AllocaInst *autoAlloca(double value, const Twine &name = "") const { AllocaInst *alloca = b->CreateAlloca(ty(), 0, name); b->CreateStore(autoConstant(value), alloca); return alloca; } - Value *getData() const { return b->CreatePointerCast(b->CreateLoad(b->CreateStructGEP(m, 0)), ptrTy(), name+"_data"); } - Value *getChannels() const { return singleChannel() ? one() : b->CreateLoad(b->CreateStructGEP(m, 1), name+"_channels"); } - Value *getColumns() const { return singleColumn() ? one() : b->CreateLoad(b->CreateStructGEP(m, 2), name+"_columns"); } - Value *getRows() const { return singleRow() ? one() : b->CreateLoad(b->CreateStructGEP(m, 3), name+"_rows"); } - Value *getFrames() const { return singleFrame() ? one() : b->CreateLoad(b->CreateStructGEP(m, 4), name+"_frames"); } + Value *getData(bool cast = true) const { LoadInst *data = b->CreateLoad(b->CreateStructGEP(m, 0), name+"_data"); return cast ? b->CreatePointerCast(data, ptrTy()) : data; } + Value *getChannels() const { return singleChannel() ? static_cast(one()) : static_cast(b->CreateLoad(b->CreateStructGEP(m, 1), name+"_channels")); } + Value *getColumns() const { return singleColumn() ? static_cast(one()) : static_cast(b->CreateLoad(b->CreateStructGEP(m, 2), name+"_columns")); } + Value *getRows() const { return singleRow() ? static_cast(one()) : static_cast(b->CreateLoad(b->CreateStructGEP(m, 3), name+"_rows")); } + Value *getFrames() const { return singleFrame() ? static_cast(one()) : static_cast(b->CreateLoad(b->CreateStructGEP(m, 4), name+"_frames")); } + Value *getHash() const { return b->CreateLoad(b->CreateStructGEP(m, 5), name+"_hash"); } + + void setData(Value *value) const { b->CreateStore(value, b->CreateStructGEP(m, 0)); } + void setChannels(Value *value) const { b->CreateStore(value, b->CreateStructGEP(m, 1)); } + void setColumns(Value *value) const { b->CreateStore(value, b->CreateStructGEP(m, 2)); } + void setRows(Value *value) const { b->CreateStore(value, b->CreateStructGEP(m, 3)); } + void setFrames(Value *value) const { b->CreateStore(value, b->CreateStructGEP(m, 4)); } + void setHash(Value *value) const { b->CreateStore(value, b->CreateStructGEP(m, 5)); } + + void copyHeaderCode(const MatrixBuilder &other) const { + setChannels(other.getChannels()); + setRows(other.getRows()); + setFrames(other.getFrames()); + setHash(other.getHash()); + } + + void allocate() const { + Function *malloc = TheModule->getFunction("malloc"); + if (!malloc) { + PointerType *mallocReturn = Type::getInt8PtrTy(getGlobalContext()); + std::vector mallocParams; + mallocParams.push_back(Type::getInt32Ty(getGlobalContext())); + FunctionType* mallocType = FunctionType::get(mallocReturn, mallocParams, false); + malloc = Function::Create(mallocType, GlobalValue::ExternalLinkage, "malloc"); + malloc->setCallingConv(CallingConv::C); + } + + std::vector mallocArgs; + mallocArgs.push_back(elementsCode()); // TODO: FIX + setData(b->CreateCall(malloc, mallocArgs)); + } + + Value *elementsCode() const { return b->CreateMul(b->CreateMul(b->CreateMul(getChannels(), getColumns()), getRows()), getFrames()); } Value *columnStep() const { Value *columnStep = getChannels(); columnStep->setName(name+"_cStep"); return columnStep; } Value *rowStep() const { return b->CreateMul(getColumns(), columnStep(), name+"_rStep"); } @@ -296,6 +329,8 @@ private: Function *compile(const jit_matrix &m) const { + Function *kernel = compileKernel(m); + Constant *c = TheModule->getOrInsertFunction(qPrintable(mangledName()), Type::getVoidTy(getGlobalContext()), PointerType::getUnqual(TheMatrixStruct), @@ -313,10 +348,38 @@ private: BasicBlock *entry = BasicBlock::Create(getGlobalContext(), "entry", function); IRBuilder<> builder(entry); - - Function *kernel = compileKernel(m); - builder.CreateCall3(kernel, src, dst, buildPreallocate(MatrixBuilder(m, src, &builder, function, "src"), MatrixBuilder(m, dst, &builder, function, "dst"))); - + MatrixBuilder mb(m, src, &builder, function, "src"); + MatrixBuilder nb(m, dst, &builder, function, "dst"); + + std::vector kernelArgs; + kernelArgs.push_back(PointerType::getUnqual(TheMatrixStruct)); + kernelArgs.push_back(PointerType::getUnqual(TheMatrixStruct)); + kernelArgs.push_back(Type::getInt32Ty(getGlobalContext())); + PointerType *kernelType = PointerType::getUnqual(FunctionType::get(Type::getVoidTy(getGlobalContext()), kernelArgs, false)); + QString kernelFunctionName = mangledName()+"_kernel"; + TheModule->getOrInsertGlobal(qPrintable(kernelFunctionName), kernelType); + GlobalVariable *kernelFunction = TheModule->getGlobalVariable(qPrintable(kernelFunctionName)); + kernelFunction->setInitializer(ConstantPointerNull::get(kernelType)); + + QString kernelHashName = mangledName()+"_hash"; + TheModule->getOrInsertGlobal(qPrintable(kernelHashName), Type::getInt16Ty(getGlobalContext())); + GlobalVariable *kernelHash = TheModule->getGlobalVariable(qPrintable(kernelHashName)); + kernelHash->setInitializer(MatrixBuilder::constant(0, 16)); + + BasicBlock *getKernel = BasicBlock::Create(getGlobalContext(), "get_kernel", function); + BasicBlock *preallocate = BasicBlock::Create(getGlobalContext(), "preallocate", function); + Value *hashTest = builder.CreateICmpNE(mb.getHash(), kernelHash, "hash_fail_test"); + builder.CreateCondBr(hashTest, getKernel, preallocate); + + builder.SetInsertPoint(getKernel); + builder.CreateStore(kernel, kernelFunction); + builder.CreateStore(mb.getHash(), kernelHash); + builder.CreateBr(preallocate); + builder.SetInsertPoint(preallocate); + Value *kernelSize = buildPreallocate(mb, nb); + nb.allocate(); + + builder.CreateCall3(builder.CreateLoad(kernelFunction), src, dst, kernelSize); builder.CreateRetVoid(); return kernel; @@ -495,6 +558,12 @@ public: return dst.elements(); } + virtual Value *buildPreallocate(const MatrixBuilder &src, const MatrixBuilder &dst) const + { + dst.copyHeaderCode(src); + return dst.elementsCode(); + } + private: void build(const MatrixBuilder &src, const MatrixBuilder &dst, PHINode *i) const { @@ -909,7 +978,7 @@ class LLVMInitializer : public Initializer Type::getInt16Ty(getGlobalContext()), // hash NULL); - QSharedPointer kernel(Transform::make("sum", NULL)); + QSharedPointer kernel(Transform::make("add(1)", NULL)); Template src, dst; src.m() = (Mat_(2,2) << -1, -2, 3, 4); diff --git a/share/openbr/doc b/share/openbr/doc index f6402dc..8f74266 160000 --- a/share/openbr/doc +++ b/share/openbr/doc @@ -1 +1 @@ -Subproject commit f6402dc42e513a1850ec73e675f62cbd15b26e31 +Subproject commit 8f74266f98653627b04ae9a8e7f83d4dd63b4f5d