Commit 8bde0803a299150989d7ebe2bc682ff1c0583423

Authored by Josh Klontz
1 parent e9030443

substantial llvm progress

sdk/plugins/llvm.cpp
@@ -99,19 +99,52 @@ struct MatrixBuilder : public jit_matrix @@ -99,19 +99,52 @@ struct MatrixBuilder : public jit_matrix
99 MatrixBuilder(const jit_matrix &matrix, Value *value, IRBuilder<> *builder, Function *function, const Twine &name_) 99 MatrixBuilder(const jit_matrix &matrix, Value *value, IRBuilder<> *builder, Function *function, const Twine &name_)
100 : jit_matrix(matrix), m(value), b(builder), f(function), name(name_) {} 100 : jit_matrix(matrix), m(value), b(builder), f(function), name(name_) {}
101 101
102 - static Value *zero() { return constant(0); }  
103 - static Value *one() { return constant(1); }  
104 - static Value *constant(int value, int bits = 32) { return Constant::getIntegerValue(Type::getInt32Ty(getGlobalContext()), APInt(bits, value)); }  
105 - static Value *constant(float value) { return ConstantFP::get(Type::getFloatTy(getGlobalContext()), value == 0 ? -0.0f : value); }  
106 - static Value *constant(double value) { return ConstantFP::get(Type::getDoubleTy(getGlobalContext()), value == 0 ? -0.0 : value); }  
107 - Value *autoConstant(double value) const { return isFloating() ? ((bits() == 64) ? constant(value) : constant(float(value))) : constant(int(value), bits()); } 102 + static Constant *zero() { return constant(0); }
  103 + static Constant *one() { return constant(1); }
  104 + static Constant *constant(int value, int bits = 32) { return Constant::getIntegerValue(Type::getInt32Ty(getGlobalContext()), APInt(bits, value)); }
  105 + static Constant *constant(float value) { return ConstantFP::get(Type::getFloatTy(getGlobalContext()), value == 0 ? -0.0f : value); }
  106 + static Constant *constant(double value) { return ConstantFP::get(Type::getDoubleTy(getGlobalContext()), value == 0 ? -0.0 : value); }
  107 + Constant *autoConstant(double value) const { return isFloating() ? ((bits() == 64) ? constant(value) : constant(float(value))) : constant(int(value), bits()); }
108 AllocaInst *autoAlloca(double value, const Twine &name = "") const { AllocaInst *alloca = b->CreateAlloca(ty(), 0, name); b->CreateStore(autoConstant(value), alloca); return alloca; } 108 AllocaInst *autoAlloca(double value, const Twine &name = "") const { AllocaInst *alloca = b->CreateAlloca(ty(), 0, name); b->CreateStore(autoConstant(value), alloca); return alloca; }
109 109
110 - Value *getData() const { return b->CreatePointerCast(b->CreateLoad(b->CreateStructGEP(m, 0)), ptrTy(), name+"_data"); }  
111 - Value *getChannels() const { return singleChannel() ? one() : b->CreateLoad(b->CreateStructGEP(m, 1), name+"_channels"); }  
112 - Value *getColumns() const { return singleColumn() ? one() : b->CreateLoad(b->CreateStructGEP(m, 2), name+"_columns"); }  
113 - Value *getRows() const { return singleRow() ? one() : b->CreateLoad(b->CreateStructGEP(m, 3), name+"_rows"); }  
114 - Value *getFrames() const { return singleFrame() ? one() : b->CreateLoad(b->CreateStructGEP(m, 4), name+"_frames"); } 110 + Value *getData(bool cast = true) const { LoadInst *data = b->CreateLoad(b->CreateStructGEP(m, 0), name+"_data"); return cast ? b->CreatePointerCast(data, ptrTy()) : data; }
  111 + Value *getChannels() const { return singleChannel() ? static_cast<Value*>(one()) : static_cast<Value*>(b->CreateLoad(b->CreateStructGEP(m, 1), name+"_channels")); }
  112 + Value *getColumns() const { return singleColumn() ? static_cast<Value*>(one()) : static_cast<Value*>(b->CreateLoad(b->CreateStructGEP(m, 2), name+"_columns")); }
  113 + Value *getRows() const { return singleRow() ? static_cast<Value*>(one()) : static_cast<Value*>(b->CreateLoad(b->CreateStructGEP(m, 3), name+"_rows")); }
  114 + Value *getFrames() const { return singleFrame() ? static_cast<Value*>(one()) : static_cast<Value*>(b->CreateLoad(b->CreateStructGEP(m, 4), name+"_frames")); }
  115 + Value *getHash() const { return b->CreateLoad(b->CreateStructGEP(m, 5), name+"_hash"); }
  116 +
  117 + void setData(Value *value) const { b->CreateStore(value, b->CreateStructGEP(m, 0)); }
  118 + void setChannels(Value *value) const { b->CreateStore(value, b->CreateStructGEP(m, 1)); }
  119 + void setColumns(Value *value) const { b->CreateStore(value, b->CreateStructGEP(m, 2)); }
  120 + void setRows(Value *value) const { b->CreateStore(value, b->CreateStructGEP(m, 3)); }
  121 + void setFrames(Value *value) const { b->CreateStore(value, b->CreateStructGEP(m, 4)); }
  122 + void setHash(Value *value) const { b->CreateStore(value, b->CreateStructGEP(m, 5)); }
  123 +
  124 + void copyHeaderCode(const MatrixBuilder &other) const {
  125 + setChannels(other.getChannels());
  126 + setRows(other.getRows());
  127 + setFrames(other.getFrames());
  128 + setHash(other.getHash());
  129 + }
  130 +
  131 + void allocate() const {
  132 + Function *malloc = TheModule->getFunction("malloc");
  133 + if (!malloc) {
  134 + PointerType *mallocReturn = Type::getInt8PtrTy(getGlobalContext());
  135 + std::vector<Type*> mallocParams;
  136 + mallocParams.push_back(Type::getInt32Ty(getGlobalContext()));
  137 + FunctionType* mallocType = FunctionType::get(mallocReturn, mallocParams, false);
  138 + malloc = Function::Create(mallocType, GlobalValue::ExternalLinkage, "malloc");
  139 + malloc->setCallingConv(CallingConv::C);
  140 + }
  141 +
  142 + std::vector<Value*> mallocArgs;
  143 + mallocArgs.push_back(elementsCode()); // TODO: FIX
  144 + setData(b->CreateCall(malloc, mallocArgs));
  145 + }
  146 +
  147 + Value *elementsCode() const { return b->CreateMul(b->CreateMul(b->CreateMul(getChannels(), getColumns()), getRows()), getFrames()); }
115 148
116 Value *columnStep() const { Value *columnStep = getChannels(); columnStep->setName(name+"_cStep"); return columnStep; } 149 Value *columnStep() const { Value *columnStep = getChannels(); columnStep->setName(name+"_cStep"); return columnStep; }
117 Value *rowStep() const { return b->CreateMul(getColumns(), columnStep(), name+"_rStep"); } 150 Value *rowStep() const { return b->CreateMul(getColumns(), columnStep(), name+"_rStep"); }
@@ -296,6 +329,8 @@ private: @@ -296,6 +329,8 @@ private:
296 329
297 Function *compile(const jit_matrix &m) const 330 Function *compile(const jit_matrix &m) const
298 { 331 {
  332 + Function *kernel = compileKernel(m);
  333 +
299 Constant *c = TheModule->getOrInsertFunction(qPrintable(mangledName()), 334 Constant *c = TheModule->getOrInsertFunction(qPrintable(mangledName()),
300 Type::getVoidTy(getGlobalContext()), 335 Type::getVoidTy(getGlobalContext()),
301 PointerType::getUnqual(TheMatrixStruct), 336 PointerType::getUnqual(TheMatrixStruct),
@@ -313,10 +348,38 @@ private: @@ -313,10 +348,38 @@ private:
313 348
314 BasicBlock *entry = BasicBlock::Create(getGlobalContext(), "entry", function); 349 BasicBlock *entry = BasicBlock::Create(getGlobalContext(), "entry", function);
315 IRBuilder<> builder(entry); 350 IRBuilder<> builder(entry);
316 -  
317 - Function *kernel = compileKernel(m);  
318 - builder.CreateCall3(kernel, src, dst, buildPreallocate(MatrixBuilder(m, src, &builder, function, "src"), MatrixBuilder(m, dst, &builder, function, "dst")));  
319 - 351 + MatrixBuilder mb(m, src, &builder, function, "src");
  352 + MatrixBuilder nb(m, dst, &builder, function, "dst");
  353 +
  354 + std::vector<Type*> kernelArgs;
  355 + kernelArgs.push_back(PointerType::getUnqual(TheMatrixStruct));
  356 + kernelArgs.push_back(PointerType::getUnqual(TheMatrixStruct));
  357 + kernelArgs.push_back(Type::getInt32Ty(getGlobalContext()));
  358 + PointerType *kernelType = PointerType::getUnqual(FunctionType::get(Type::getVoidTy(getGlobalContext()), kernelArgs, false));
  359 + QString kernelFunctionName = mangledName()+"_kernel";
  360 + TheModule->getOrInsertGlobal(qPrintable(kernelFunctionName), kernelType);
  361 + GlobalVariable *kernelFunction = TheModule->getGlobalVariable(qPrintable(kernelFunctionName));
  362 + kernelFunction->setInitializer(ConstantPointerNull::get(kernelType));
  363 +
  364 + QString kernelHashName = mangledName()+"_hash";
  365 + TheModule->getOrInsertGlobal(qPrintable(kernelHashName), Type::getInt16Ty(getGlobalContext()));
  366 + GlobalVariable *kernelHash = TheModule->getGlobalVariable(qPrintable(kernelHashName));
  367 + kernelHash->setInitializer(MatrixBuilder::constant(0, 16));
  368 +
  369 + BasicBlock *getKernel = BasicBlock::Create(getGlobalContext(), "get_kernel", function);
  370 + BasicBlock *preallocate = BasicBlock::Create(getGlobalContext(), "preallocate", function);
  371 + Value *hashTest = builder.CreateICmpNE(mb.getHash(), kernelHash, "hash_fail_test");
  372 + builder.CreateCondBr(hashTest, getKernel, preallocate);
  373 +
  374 + builder.SetInsertPoint(getKernel);
  375 + builder.CreateStore(kernel, kernelFunction);
  376 + builder.CreateStore(mb.getHash(), kernelHash);
  377 + builder.CreateBr(preallocate);
  378 + builder.SetInsertPoint(preallocate);
  379 + Value *kernelSize = buildPreallocate(mb, nb);
  380 + nb.allocate();
  381 +
  382 + builder.CreateCall3(builder.CreateLoad(kernelFunction), src, dst, kernelSize);
320 builder.CreateRetVoid(); 383 builder.CreateRetVoid();
321 384
322 return kernel; 385 return kernel;
@@ -495,6 +558,12 @@ public: @@ -495,6 +558,12 @@ public:
495 return dst.elements(); 558 return dst.elements();
496 } 559 }
497 560
  561 + virtual Value *buildPreallocate(const MatrixBuilder &src, const MatrixBuilder &dst) const
  562 + {
  563 + dst.copyHeaderCode(src);
  564 + return dst.elementsCode();
  565 + }
  566 +
498 private: 567 private:
499 void build(const MatrixBuilder &src, const MatrixBuilder &dst, PHINode *i) const 568 void build(const MatrixBuilder &src, const MatrixBuilder &dst, PHINode *i) const
500 { 569 {
@@ -909,7 +978,7 @@ class LLVMInitializer : public Initializer @@ -909,7 +978,7 @@ class LLVMInitializer : public Initializer
909 Type::getInt16Ty(getGlobalContext()), // hash 978 Type::getInt16Ty(getGlobalContext()), // hash
910 NULL); 979 NULL);
911 980
912 - QSharedPointer<Transform> kernel(Transform::make("sum", NULL)); 981 + QSharedPointer<Transform> kernel(Transform::make("add(1)", NULL));
913 982
914 Template src, dst; 983 Template src, dst;
915 src.m() = (Mat_<qint8>(2,2) << -1, -2, 3, 4); 984 src.m() = (Mat_<qint8>(2,2) << -1, -2, 3, 4);
1 -Subproject commit f6402dc42e513a1850ec73e675f62cbd15b26e31 1 +Subproject commit 8f74266f98653627b04ae9a8e7f83d4dd63b4f5d