Commit ae79a4ed0ee837e8c1c14c568cbbed736bf56c81

Authored by Josh Klontz
1 parent 7a6cef73

cleaned up MatrixBuilder class

Showing 1 changed file with 74 additions and 55 deletions
sdk/plugins/llvm.cpp
... ... @@ -103,16 +103,15 @@ struct MatrixBuilder
103 103 static Constant *constant(double value) { return ConstantFP::get(Type::getDoubleTy(getGlobalContext()), value == 0 ? -0.0 : value); }
104 104 static Constant *zero() { return constant(0); }
105 105 static Constant *one() { return constant(1); }
106   -
107 106 Constant *autoConstant(double value) const { return m->isFloating() ? ((m->bits() == 64) ? constant(value) : constant(float(value))) : constant(int(value), m->bits()); }
108 107 AllocaInst *autoAlloca(double value, const Twine &name = "") const { AllocaInst *alloca = b->CreateAlloca(ty(), 0, name); b->CreateStore(autoConstant(value), alloca); return alloca; }
109 108  
110   - Value *getData(bool cast = true) const { LoadInst *data = b->CreateLoad(b->CreateStructGEP(v, 0), name+"_data"); return cast ? b->CreatePointerCast(data, ptrTy()) : data; }
111   - Value *getChannels() const { return m->singleChannel() ? static_cast<Value*>(one()) : static_cast<Value*>(b->CreateLoad(b->CreateStructGEP(v, 1), name+"_channels")); }
112   - Value *getColumns() const { return m->singleColumn() ? static_cast<Value*>(one()) : static_cast<Value*>(b->CreateLoad(b->CreateStructGEP(v, 2), name+"_columns")); }
113   - Value *getRows() const { return m->singleRow() ? static_cast<Value*>(one()) : static_cast<Value*>(b->CreateLoad(b->CreateStructGEP(v, 3), name+"_rows")); }
114   - Value *getFrames() const { return m->singleFrame() ? static_cast<Value*>(one()) : static_cast<Value*>(b->CreateLoad(b->CreateStructGEP(v, 4), name+"_frames")); }
115   - Value *getHash() const { return b->CreateLoad(b->CreateStructGEP(v, 5), name+"_hash"); }
  109 + Value *data(bool cast = true) const { LoadInst *data = b->CreateLoad(b->CreateStructGEP(v, 0), name+"_data"); return cast ? b->CreatePointerCast(data, ptrTy()) : data; }
  110 + Value *channels() const { return m->singleChannel() ? static_cast<Value*>(one()) : static_cast<Value*>(b->CreateLoad(b->CreateStructGEP(v, 1), name+"_channels")); }
  111 + Value *columns() const { return m->singleColumn() ? static_cast<Value*>(one()) : static_cast<Value*>(b->CreateLoad(b->CreateStructGEP(v, 2), name+"_columns")); }
  112 + Value *rows() const { return m->singleRow() ? static_cast<Value*>(one()) : static_cast<Value*>(b->CreateLoad(b->CreateStructGEP(v, 3), name+"_rows")); }
  113 + Value *frames() const { return m->singleFrame() ? static_cast<Value*>(one()) : static_cast<Value*>(b->CreateLoad(b->CreateStructGEP(v, 4), name+"_frames")); }
  114 + Value *hash() const { return b->CreateLoad(b->CreateStructGEP(v, 5), name+"_hash"); }
116 115  
117 116 void setData(Value *value) const { b->CreateStore(value, b->CreateStructGEP(v, 0)); }
118 117 void setChannels(Value *value) const { b->CreateStore(value, b->CreateStructGEP(v, 1)); }
... ... @@ -121,18 +120,18 @@ struct MatrixBuilder
121 120 void setFrames(Value *value) const { b->CreateStore(value, b->CreateStructGEP(v, 4)); }
122 121 void setHash(Value *value) const { b->CreateStore(value, b->CreateStructGEP(v, 5)); }
123 122  
124   - void copyHeaderCode(const MatrixBuilder &other) const {
125   - setChannels(other.getChannels());
126   - setColumns(other.getColumns());
127   - setRows(other.getRows());
128   - setFrames(other.getFrames());
129   - setHash(other.getHash());
  123 + void copyHeader(const MatrixBuilder &other) const {
  124 + setChannels(other.channels());
  125 + setColumns(other.columns());
  126 + setRows(other.rows());
  127 + setFrames(other.frames());
  128 + setHash(other.hash());
130 129 }
131 130  
132   - void allocateCode() const {
133   - Function *malloc = TheModule->getFunction("malloc");
  131 + void allocate() const {
  132 + static Function *malloc = TheModule->getFunction("malloc");
134 133 if (!malloc) {
135   - PointerType *mallocReturn = Type::getInt8PtrTy(getGlobalContext());
  134 + Type *mallocReturn = Type::getInt8PtrTy(getGlobalContext());
136 135 std::vector<Type*> mallocParams;
137 136 mallocParams.push_back(Type::getInt32Ty(getGlobalContext()));
138 137 FunctionType* mallocType = FunctionType::get(mallocReturn, mallocParams, false);
... ... @@ -141,36 +140,53 @@ struct MatrixBuilder
141 140 }
142 141  
143 142 std::vector<Value*> mallocArgs;
144   - mallocArgs.push_back(bytesCode());
  143 + mallocArgs.push_back(bytes());
145 144 setData(b->CreateCall(malloc, mallocArgs));
146 145 }
147 146  
148   - Value *get(int mask) const { return b->CreateAnd(getHash(), constant(mask, 16)); }
149   - void set(int value, int mask) const { setHash(b->CreateOr(b->CreateAnd(getHash(), constant(~mask, 16)), b->CreateAnd(constant(value, 16), constant(mask, 16)))); }
150   - void setBit(bool on, int mask) const { on ? setHash(b->CreateOr(getHash(), constant(mask, 16))) : setHash(b->CreateAnd(getHash(), constant(~mask, 16))); }
151   -
152   - Value *bitsCode() const { return get(Matrix::Bits); }
153   - void setBitsCode(int bits) const { set(bits, Matrix::Bits); }
154   - Value *isFloatingCode() const { return get(Matrix::Floating); }
155   - void setFloatingCode(bool isFloating) const { if (isFloating) setSignedCode(true); setBit(isFloating, Matrix::Floating); }
156   - Value *isSignedCode() const { return get(Matrix::Signed); }
157   - void setSignedCode(bool isSigned) const { setBit(isSigned, Matrix::Signed); }
158   - Value *typeCode() const { return get(Matrix::Bits + Matrix::Floating + Matrix::Signed); }
159   - void setTypeCode(int type) const { set(type, Matrix::Bits + Matrix::Floating + Matrix::Signed); }
160   - Value *singleChannelCode() const { return get(Matrix::SingleChannel); }
161   - void setSingleChannelCode(bool singleChannel) const { setBit(singleChannel, Matrix::SingleChannel); }
162   - Value *singleColumnCode() const { return get(Matrix::SingleColumn); }
163   - void setSingleColumnCode(bool singleColumn) { setBit(singleColumn, Matrix::SingleColumn); }
164   - Value *singleRowCode() const { return get(Matrix::SingleRow); }
165   - void setSingleRowCode(bool singleRow) const { setBit(singleRow, Matrix::SingleRow); }
166   - Value *singleFrameCode() const { return get(Matrix::SingleFrame); }
167   - void setSingleFrameCode(bool singleFrame) const { setBit(singleFrame, Matrix::SingleFrame); }
168   - Value *elementsCode() const { return b->CreateMul(b->CreateMul(b->CreateMul(getChannels(), getColumns()), getRows()), getFrames()); }
169   - Value *bytesCode() const { return b->CreateMul(b->CreateUDiv(b->CreateCast(Instruction::ZExt, bitsCode(), Type::getInt32Ty(getGlobalContext())), constant(8, 32)), elementsCode()); }
170   -
171   - Value *columnStep() const { Value *columnStep = getChannels(); columnStep->setName(name+"_cStep"); return columnStep; }
172   - Value *rowStep() const { return b->CreateMul(getColumns(), columnStep(), name+"_rStep"); }
173   - Value *frameStep() const { return b->CreateMul(getRows(), rowStep(), name+"_tStep"); }
  147 + void deallocate() const {
  148 + static Function *free = TheModule->getFunction("free");
  149 + if (!free) {
  150 + Type *freeReturn = Type::getVoidTy(getGlobalContext());
  151 + std::vector<Type*> freeParams;
  152 + freeParams.push_back(Type::getInt8PtrTy(getGlobalContext()));
  153 + FunctionType* freeType = FunctionType::get(freeReturn, freeParams, false);
  154 + free = Function::Create(freeType, GlobalValue::ExternalLinkage, "free", TheModule);
  155 + free->setCallingConv(CallingConv::C);
  156 + }
  157 +
  158 + std::vector<Value*> freeArgs;
  159 + freeArgs.push_back(b->CreateStructGEP(v, 0));
  160 + b->CreateCall(free, freeArgs);
  161 + setData(ConstantPointerNull::get(Type::getInt8PtrTy(getGlobalContext())));
  162 + }
  163 +
  164 + Value *get(int mask) const { return b->CreateAnd(hash(), constant(mask, 16)); }
  165 + void set(int value, int mask) const { setHash(b->CreateOr(b->CreateAnd(hash(), constant(~mask, 16)), b->CreateAnd(constant(value, 16), constant(mask, 16)))); }
  166 + void setBit(bool on, int mask) const { on ? setHash(b->CreateOr(hash(), constant(mask, 16))) : setHash(b->CreateAnd(hash(), constant(~mask, 16))); }
  167 +
  168 + Value *bits() const { return get(Matrix::Bits); }
  169 + void setBits(int bits) const { set(bits, Matrix::Bits); }
  170 + Value *isFloating() const { return get(Matrix::Floating); }
  171 + void setFloating(bool isFloating) const { if (isFloating) setSigned(true); setBit(isFloating, Matrix::Floating); }
  172 + Value *isSigned() const { return get(Matrix::Signed); }
  173 + void setSigned(bool isSigned) const { setBit(isSigned, Matrix::Signed); }
  174 + Value *type() const { return get(Matrix::Bits + Matrix::Floating + Matrix::Signed); }
  175 + void setType(int type) const { set(type, Matrix::Bits + Matrix::Floating + Matrix::Signed); }
  176 + Value *singleChannel() const { return get(Matrix::SingleChannel); }
  177 + void setSingleChannel(bool singleChannel) const { setBit(singleChannel, Matrix::SingleChannel); }
  178 + Value *singleColumn() const { return get(Matrix::SingleColumn); }
  179 + void setSingleColumn(bool singleColumn) { setBit(singleColumn, Matrix::SingleColumn); }
  180 + Value *singleRow() const { return get(Matrix::SingleRow); }
  181 + void setSingleRow(bool singleRow) const { setBit(singleRow, Matrix::SingleRow); }
  182 + Value *singleFrame() const { return get(Matrix::SingleFrame); }
  183 + void setSingleFrame(bool singleFrame) const { setBit(singleFrame, Matrix::SingleFrame); }
  184 + Value *elements() const { return b->CreateMul(b->CreateMul(b->CreateMul(channels(), columns()), rows()), frames()); }
  185 + Value *bytes() const { return b->CreateMul(b->CreateUDiv(b->CreateCast(Instruction::ZExt, bits(), Type::getInt32Ty(getGlobalContext())), constant(8, 32)), elements()); }
  186 +
  187 + Value *columnStep() const { Value *columnStep = channels(); columnStep->setName(name+"_cStep"); return columnStep; }
  188 + Value *rowStep() const { return b->CreateMul(columns(), columnStep(), name+"_rStep"); }
  189 + Value *frameStep() const { return b->CreateMul(rows(), rowStep(), name+"_tStep"); }
174 190 Value *aliasColumnStep(const MatrixBuilder &other) const { return (m->channels == other.m->channels) ? other.columnStep() : columnStep(); }
175 191 Value *aliasRowStep(const MatrixBuilder &other) const { return (m->columns == other.m->columns) ? other.rowStep() : rowStep(); }
176 192 Value *aliasFrameStep(const MatrixBuilder &other) const { return (m->rows == other.m->rows) ? other.frameStep() : frameStep(); }
... ... @@ -183,7 +199,9 @@ struct MatrixBuilder
183 199 Value *aliasIndex(const MatrixBuilder &other, Value *c, Value *x, Value *y) const { return m->singleRow() ? aliasIndex(other, c, x) : b->CreateAdd(b->CreateMul(y, aliasRowStep(other)), aliasIndex(other, c, x)); }
184 200 Value *aliasIndex(const MatrixBuilder &other, Value *c, Value *x, Value *y, Value *f) const { return m->singleFrame() ? aliasIndex(other, c, x, y) : b->CreateAdd(b->CreateMul(f, aliasFrameStep(other)), aliasIndex(other, c, x, y)); }
185 201  
186   - void deindex(Value *i, Value **c) const { *c = m->singleChannel() ? constant(0) : i; }
  202 + void deindex(Value *i, Value **c) const {
  203 + *c = m->singleChannel() ? constant(0) : i;
  204 + }
187 205 void deindex(Value *i, Value **c, Value **x) const {
188 206 Value *rem;
189 207 if (m->singleColumn()) {
... ... @@ -221,8 +239,9 @@ struct MatrixBuilder
221 239 deindex(rem, c, x, y);
222 240 }
223 241  
224   - LoadInst *load(Value *i) const { return b->CreateLoad(b->CreateGEP(getData(), i)); }
225   - StoreInst *store(Value *i, Value *value) const { return b->CreateStore(value, b->CreateGEP(getData(), i)); }
  242 + LoadInst *load(Value *i) const { return b->CreateLoad(b->CreateGEP(data(), i)); }
  243 + StoreInst *store(Value *i, Value *value) const { return b->CreateStore(value, b->CreateGEP(data(), i)); }
  244 +
226 245 Value *cast(Value *i, const MatrixBuilder &dst) const { return (m->type() == dst.m->type()) ? i : b->CreateCast(CastInst::getCastOpcode(i, m->isSigned(), dst.ty(), dst.m->isSigned()), i, dst.ty()); }
227 246 Value *add(Value *i, Value *j, const Twine &name = "") const { return m->isFloating() ? b->CreateFAdd(i, j, name) : b->CreateAdd(i, j, name); }
228 247 Value *multiply(Value *i, Value *j, const Twine &name = "") const { return m->isFloating() ? b->CreateFMul(i, j, name) : b->CreateMul(i, j, name); }
... ... @@ -396,12 +415,12 @@ private:
396 415  
397 416 BasicBlock *getKernel = BasicBlock::Create(getGlobalContext(), "get_kernel", function);
398 417 BasicBlock *preallocate = BasicBlock::Create(getGlobalContext(), "preallocate", function);
399   - Value *hashTest = builder.CreateICmpNE(mb.getHash(), builder.CreateLoad(kernelHash), "hash_fail_test");
  418 + Value *hashTest = builder.CreateICmpNE(mb.hash(), builder.CreateLoad(kernelHash), "hash_fail_test");
400 419 builder.CreateCondBr(hashTest, getKernel, preallocate);
401 420  
402 421 builder.SetInsertPoint(getKernel);
403 422 builder.CreateStore(kernel, kernelFunction);
404   - builder.CreateStore(mb.getHash(), kernelHash);
  423 + builder.CreateStore(mb.hash(), kernelHash);
405 424 builder.CreateBr(preallocate);
406 425 builder.SetInsertPoint(preallocate);
407 426 Value *kernelSize = buildPreallocate(mb, nb);
... ... @@ -409,7 +428,7 @@ private:
409 428 BasicBlock *allocate = BasicBlock::Create(getGlobalContext(), "allocate", function);
410 429 builder.CreateBr(allocate);
411 430 builder.SetInsertPoint(allocate);
412   - nb.allocateCode();
  431 + nb.allocate();
413 432  
414 433 BasicBlock *callKernel = BasicBlock::Create(getGlobalContext(), "call_kernel", function);
415 434 builder.CreateBr(callKernel);
... ... @@ -581,8 +600,8 @@ public:
581 600  
582 601 virtual Value *buildPreallocate(const MatrixBuilder &src, const MatrixBuilder &dst) const
583 602 {
584   - dst.copyHeaderCode(src);
585   - return dst.elementsCode();
  603 + dst.copyHeader(src);
  604 + return dst.elements();
586 605 }
587 606  
588 607 private:
... ... @@ -739,7 +758,7 @@ class sumTransform : public UnaryTransform
739 758  
740 759 if (frames && !src.m->singleFrame()) {
741 760 BasicBlock *loop, *exit;
742   - src_t = dst.beginLoop(loops.last(), loop, exit, src.getFrames(), "src_t");
  761 + src_t = dst.beginLoop(loops.last(), loop, exit, src.frames(), "src_t");
743 762 loops.append(loop);
744 763 exits.append(exit);
745 764 } else {
... ... @@ -748,7 +767,7 @@ class sumTransform : public UnaryTransform
748 767  
749 768 if (rows && !src.m->singleRow()) {
750 769 BasicBlock *loop, *exit;
751   - src_y = dst.beginLoop(loops.last(), loop, exit, src.getRows(), "src_y");
  770 + src_y = dst.beginLoop(loops.last(), loop, exit, src.rows(), "src_y");
752 771 loops.append(loop);
753 772 exits.append(exit);
754 773 } else {
... ... @@ -757,7 +776,7 @@ class sumTransform : public UnaryTransform
757 776  
758 777 if (columns && !src.m->singleColumn()) {
759 778 BasicBlock *loop, *exit;
760   - src_x = dst.beginLoop(loops.last(), loop, exit, src.getColumns(), "src_x");
  779 + src_x = dst.beginLoop(loops.last(), loop, exit, src.columns(), "src_x");
761 780 loops.append(loop);
762 781 exits.append(exit);
763 782 } else {
... ... @@ -766,7 +785,7 @@ class sumTransform : public UnaryTransform
766 785  
767 786 if (channels && !src.m->singleChannel()) {
768 787 BasicBlock *loop, *exit;
769   - src_c = dst.beginLoop(loops.last(), loop, exit, src.getChannels(), "src_c");
  788 + src_c = dst.beginLoop(loops.last(), loop, exit, src.channels(), "src_c");
770 789 loops.append(loop);
771 790 exits.append(exit);
772 791 } else {
... ...