Commit ae79a4ed0ee837e8c1c14c568cbbed736bf56c81

Authored by Josh Klontz
1 parent 7a6cef73

cleaned up MatrixBuilder class

Showing 1 changed file with 74 additions and 55 deletions
sdk/plugins/llvm.cpp
@@ -103,16 +103,15 @@ struct MatrixBuilder @@ -103,16 +103,15 @@ struct MatrixBuilder
103 static Constant *constant(double value) { return ConstantFP::get(Type::getDoubleTy(getGlobalContext()), value == 0 ? -0.0 : value); } 103 static Constant *constant(double value) { return ConstantFP::get(Type::getDoubleTy(getGlobalContext()), value == 0 ? -0.0 : value); }
104 static Constant *zero() { return constant(0); } 104 static Constant *zero() { return constant(0); }
105 static Constant *one() { return constant(1); } 105 static Constant *one() { return constant(1); }
106 -  
107 Constant *autoConstant(double value) const { return m->isFloating() ? ((m->bits() == 64) ? constant(value) : constant(float(value))) : constant(int(value), m->bits()); } 106 Constant *autoConstant(double value) const { return m->isFloating() ? ((m->bits() == 64) ? constant(value) : constant(float(value))) : constant(int(value), m->bits()); }
108 AllocaInst *autoAlloca(double value, const Twine &name = "") const { AllocaInst *alloca = b->CreateAlloca(ty(), 0, name); b->CreateStore(autoConstant(value), alloca); return alloca; } 107 AllocaInst *autoAlloca(double value, const Twine &name = "") const { AllocaInst *alloca = b->CreateAlloca(ty(), 0, name); b->CreateStore(autoConstant(value), alloca); return alloca; }
109 108
110 - Value *getData(bool cast = true) const { LoadInst *data = b->CreateLoad(b->CreateStructGEP(v, 0), name+"_data"); return cast ? b->CreatePointerCast(data, ptrTy()) : data; }  
111 - Value *getChannels() const { return m->singleChannel() ? static_cast<Value*>(one()) : static_cast<Value*>(b->CreateLoad(b->CreateStructGEP(v, 1), name+"_channels")); }  
112 - Value *getColumns() const { return m->singleColumn() ? static_cast<Value*>(one()) : static_cast<Value*>(b->CreateLoad(b->CreateStructGEP(v, 2), name+"_columns")); }  
113 - Value *getRows() const { return m->singleRow() ? static_cast<Value*>(one()) : static_cast<Value*>(b->CreateLoad(b->CreateStructGEP(v, 3), name+"_rows")); }  
114 - Value *getFrames() const { return m->singleFrame() ? static_cast<Value*>(one()) : static_cast<Value*>(b->CreateLoad(b->CreateStructGEP(v, 4), name+"_frames")); }  
115 - Value *getHash() const { return b->CreateLoad(b->CreateStructGEP(v, 5), name+"_hash"); } 109 + Value *data(bool cast = true) const { LoadInst *data = b->CreateLoad(b->CreateStructGEP(v, 0), name+"_data"); return cast ? b->CreatePointerCast(data, ptrTy()) : data; }
  110 + Value *channels() const { return m->singleChannel() ? static_cast<Value*>(one()) : static_cast<Value*>(b->CreateLoad(b->CreateStructGEP(v, 1), name+"_channels")); }
  111 + Value *columns() const { return m->singleColumn() ? static_cast<Value*>(one()) : static_cast<Value*>(b->CreateLoad(b->CreateStructGEP(v, 2), name+"_columns")); }
  112 + Value *rows() const { return m->singleRow() ? static_cast<Value*>(one()) : static_cast<Value*>(b->CreateLoad(b->CreateStructGEP(v, 3), name+"_rows")); }
  113 + Value *frames() const { return m->singleFrame() ? static_cast<Value*>(one()) : static_cast<Value*>(b->CreateLoad(b->CreateStructGEP(v, 4), name+"_frames")); }
  114 + Value *hash() const { return b->CreateLoad(b->CreateStructGEP(v, 5), name+"_hash"); }
116 115
117 void setData(Value *value) const { b->CreateStore(value, b->CreateStructGEP(v, 0)); } 116 void setData(Value *value) const { b->CreateStore(value, b->CreateStructGEP(v, 0)); }
118 void setChannels(Value *value) const { b->CreateStore(value, b->CreateStructGEP(v, 1)); } 117 void setChannels(Value *value) const { b->CreateStore(value, b->CreateStructGEP(v, 1)); }
@@ -121,18 +120,18 @@ struct MatrixBuilder @@ -121,18 +120,18 @@ struct MatrixBuilder
121 void setFrames(Value *value) const { b->CreateStore(value, b->CreateStructGEP(v, 4)); } 120 void setFrames(Value *value) const { b->CreateStore(value, b->CreateStructGEP(v, 4)); }
122 void setHash(Value *value) const { b->CreateStore(value, b->CreateStructGEP(v, 5)); } 121 void setHash(Value *value) const { b->CreateStore(value, b->CreateStructGEP(v, 5)); }
123 122
124 - void copyHeaderCode(const MatrixBuilder &other) const {  
125 - setChannels(other.getChannels());  
126 - setColumns(other.getColumns());  
127 - setRows(other.getRows());  
128 - setFrames(other.getFrames());  
129 - setHash(other.getHash()); 123 + void copyHeader(const MatrixBuilder &other) const {
  124 + setChannels(other.channels());
  125 + setColumns(other.columns());
  126 + setRows(other.rows());
  127 + setFrames(other.frames());
  128 + setHash(other.hash());
130 } 129 }
131 130
132 - void allocateCode() const {  
133 - Function *malloc = TheModule->getFunction("malloc"); 131 + void allocate() const {
  132 + static Function *malloc = TheModule->getFunction("malloc");
134 if (!malloc) { 133 if (!malloc) {
135 - PointerType *mallocReturn = Type::getInt8PtrTy(getGlobalContext()); 134 + Type *mallocReturn = Type::getInt8PtrTy(getGlobalContext());
136 std::vector<Type*> mallocParams; 135 std::vector<Type*> mallocParams;
137 mallocParams.push_back(Type::getInt32Ty(getGlobalContext())); 136 mallocParams.push_back(Type::getInt32Ty(getGlobalContext()));
138 FunctionType* mallocType = FunctionType::get(mallocReturn, mallocParams, false); 137 FunctionType* mallocType = FunctionType::get(mallocReturn, mallocParams, false);
@@ -141,36 +140,53 @@ struct MatrixBuilder @@ -141,36 +140,53 @@ struct MatrixBuilder
141 } 140 }
142 141
143 std::vector<Value*> mallocArgs; 142 std::vector<Value*> mallocArgs;
144 - mallocArgs.push_back(bytesCode()); 143 + mallocArgs.push_back(bytes());
145 setData(b->CreateCall(malloc, mallocArgs)); 144 setData(b->CreateCall(malloc, mallocArgs));
146 } 145 }
147 146
148 - Value *get(int mask) const { return b->CreateAnd(getHash(), constant(mask, 16)); }  
149 - void set(int value, int mask) const { setHash(b->CreateOr(b->CreateAnd(getHash(), constant(~mask, 16)), b->CreateAnd(constant(value, 16), constant(mask, 16)))); }  
150 - void setBit(bool on, int mask) const { on ? setHash(b->CreateOr(getHash(), constant(mask, 16))) : setHash(b->CreateAnd(getHash(), constant(~mask, 16))); }  
151 -  
152 - Value *bitsCode() const { return get(Matrix::Bits); }  
153 - void setBitsCode(int bits) const { set(bits, Matrix::Bits); }  
154 - Value *isFloatingCode() const { return get(Matrix::Floating); }  
155 - void setFloatingCode(bool isFloating) const { if (isFloating) setSignedCode(true); setBit(isFloating, Matrix::Floating); }  
156 - Value *isSignedCode() const { return get(Matrix::Signed); }  
157 - void setSignedCode(bool isSigned) const { setBit(isSigned, Matrix::Signed); }  
158 - Value *typeCode() const { return get(Matrix::Bits + Matrix::Floating + Matrix::Signed); }  
159 - void setTypeCode(int type) const { set(type, Matrix::Bits + Matrix::Floating + Matrix::Signed); }  
160 - Value *singleChannelCode() const { return get(Matrix::SingleChannel); }  
161 - void setSingleChannelCode(bool singleChannel) const { setBit(singleChannel, Matrix::SingleChannel); }  
162 - Value *singleColumnCode() const { return get(Matrix::SingleColumn); }  
163 - void setSingleColumnCode(bool singleColumn) { setBit(singleColumn, Matrix::SingleColumn); }  
164 - Value *singleRowCode() const { return get(Matrix::SingleRow); }  
165 - void setSingleRowCode(bool singleRow) const { setBit(singleRow, Matrix::SingleRow); }  
166 - Value *singleFrameCode() const { return get(Matrix::SingleFrame); }  
167 - void setSingleFrameCode(bool singleFrame) const { setBit(singleFrame, Matrix::SingleFrame); }  
168 - Value *elementsCode() const { return b->CreateMul(b->CreateMul(b->CreateMul(getChannels(), getColumns()), getRows()), getFrames()); }  
169 - Value *bytesCode() const { return b->CreateMul(b->CreateUDiv(b->CreateCast(Instruction::ZExt, bitsCode(), Type::getInt32Ty(getGlobalContext())), constant(8, 32)), elementsCode()); }  
170 -  
171 - Value *columnStep() const { Value *columnStep = getChannels(); columnStep->setName(name+"_cStep"); return columnStep; }  
172 - Value *rowStep() const { return b->CreateMul(getColumns(), columnStep(), name+"_rStep"); }  
173 - Value *frameStep() const { return b->CreateMul(getRows(), rowStep(), name+"_tStep"); } 147 + void deallocate() const {
  148 + static Function *free = TheModule->getFunction("free");
  149 + if (!free) {
  150 + Type *freeReturn = Type::getVoidTy(getGlobalContext());
  151 + std::vector<Type*> freeParams;
  152 + freeParams.push_back(Type::getInt8PtrTy(getGlobalContext()));
  153 + FunctionType* freeType = FunctionType::get(freeReturn, freeParams, false);
  154 + free = Function::Create(freeType, GlobalValue::ExternalLinkage, "free", TheModule);
  155 + free->setCallingConv(CallingConv::C);
  156 + }
  157 +
  158 + std::vector<Value*> freeArgs;
  159 + freeArgs.push_back(b->CreateStructGEP(v, 0));
  160 + b->CreateCall(free, freeArgs);
  161 + setData(ConstantPointerNull::get(Type::getInt8PtrTy(getGlobalContext())));
  162 + }
  163 +
  164 + Value *get(int mask) const { return b->CreateAnd(hash(), constant(mask, 16)); }
  165 + void set(int value, int mask) const { setHash(b->CreateOr(b->CreateAnd(hash(), constant(~mask, 16)), b->CreateAnd(constant(value, 16), constant(mask, 16)))); }
  166 + void setBit(bool on, int mask) const { on ? setHash(b->CreateOr(hash(), constant(mask, 16))) : setHash(b->CreateAnd(hash(), constant(~mask, 16))); }
  167 +
  168 + Value *bits() const { return get(Matrix::Bits); }
  169 + void setBits(int bits) const { set(bits, Matrix::Bits); }
  170 + Value *isFloating() const { return get(Matrix::Floating); }
  171 + void setFloating(bool isFloating) const { if (isFloating) setSigned(true); setBit(isFloating, Matrix::Floating); }
  172 + Value *isSigned() const { return get(Matrix::Signed); }
  173 + void setSigned(bool isSigned) const { setBit(isSigned, Matrix::Signed); }
  174 + Value *type() const { return get(Matrix::Bits + Matrix::Floating + Matrix::Signed); }
  175 + void setType(int type) const { set(type, Matrix::Bits + Matrix::Floating + Matrix::Signed); }
  176 + Value *singleChannel() const { return get(Matrix::SingleChannel); }
  177 + void setSingleChannel(bool singleChannel) const { setBit(singleChannel, Matrix::SingleChannel); }
  178 + Value *singleColumn() const { return get(Matrix::SingleColumn); }
  179 + void setSingleColumn(bool singleColumn) { setBit(singleColumn, Matrix::SingleColumn); }
  180 + Value *singleRow() const { return get(Matrix::SingleRow); }
  181 + void setSingleRow(bool singleRow) const { setBit(singleRow, Matrix::SingleRow); }
  182 + Value *singleFrame() const { return get(Matrix::SingleFrame); }
  183 + void setSingleFrame(bool singleFrame) const { setBit(singleFrame, Matrix::SingleFrame); }
  184 + Value *elements() const { return b->CreateMul(b->CreateMul(b->CreateMul(channels(), columns()), rows()), frames()); }
  185 + Value *bytes() const { return b->CreateMul(b->CreateUDiv(b->CreateCast(Instruction::ZExt, bits(), Type::getInt32Ty(getGlobalContext())), constant(8, 32)), elements()); }
  186 +
  187 + Value *columnStep() const { Value *columnStep = channels(); columnStep->setName(name+"_cStep"); return columnStep; }
  188 + Value *rowStep() const { return b->CreateMul(columns(), columnStep(), name+"_rStep"); }
  189 + Value *frameStep() const { return b->CreateMul(rows(), rowStep(), name+"_tStep"); }
174 Value *aliasColumnStep(const MatrixBuilder &other) const { return (m->channels == other.m->channels) ? other.columnStep() : columnStep(); } 190 Value *aliasColumnStep(const MatrixBuilder &other) const { return (m->channels == other.m->channels) ? other.columnStep() : columnStep(); }
175 Value *aliasRowStep(const MatrixBuilder &other) const { return (m->columns == other.m->columns) ? other.rowStep() : rowStep(); } 191 Value *aliasRowStep(const MatrixBuilder &other) const { return (m->columns == other.m->columns) ? other.rowStep() : rowStep(); }
176 Value *aliasFrameStep(const MatrixBuilder &other) const { return (m->rows == other.m->rows) ? other.frameStep() : frameStep(); } 192 Value *aliasFrameStep(const MatrixBuilder &other) const { return (m->rows == other.m->rows) ? other.frameStep() : frameStep(); }
@@ -183,7 +199,9 @@ struct MatrixBuilder @@ -183,7 +199,9 @@ struct MatrixBuilder
183 Value *aliasIndex(const MatrixBuilder &other, Value *c, Value *x, Value *y) const { return m->singleRow() ? aliasIndex(other, c, x) : b->CreateAdd(b->CreateMul(y, aliasRowStep(other)), aliasIndex(other, c, x)); } 199 Value *aliasIndex(const MatrixBuilder &other, Value *c, Value *x, Value *y) const { return m->singleRow() ? aliasIndex(other, c, x) : b->CreateAdd(b->CreateMul(y, aliasRowStep(other)), aliasIndex(other, c, x)); }
184 Value *aliasIndex(const MatrixBuilder &other, Value *c, Value *x, Value *y, Value *f) const { return m->singleFrame() ? aliasIndex(other, c, x, y) : b->CreateAdd(b->CreateMul(f, aliasFrameStep(other)), aliasIndex(other, c, x, y)); } 200 Value *aliasIndex(const MatrixBuilder &other, Value *c, Value *x, Value *y, Value *f) const { return m->singleFrame() ? aliasIndex(other, c, x, y) : b->CreateAdd(b->CreateMul(f, aliasFrameStep(other)), aliasIndex(other, c, x, y)); }
185 201
186 - void deindex(Value *i, Value **c) const { *c = m->singleChannel() ? constant(0) : i; } 202 + void deindex(Value *i, Value **c) const {
  203 + *c = m->singleChannel() ? constant(0) : i;
  204 + }
187 void deindex(Value *i, Value **c, Value **x) const { 205 void deindex(Value *i, Value **c, Value **x) const {
188 Value *rem; 206 Value *rem;
189 if (m->singleColumn()) { 207 if (m->singleColumn()) {
@@ -221,8 +239,9 @@ struct MatrixBuilder @@ -221,8 +239,9 @@ struct MatrixBuilder
221 deindex(rem, c, x, y); 239 deindex(rem, c, x, y);
222 } 240 }
223 241
224 - LoadInst *load(Value *i) const { return b->CreateLoad(b->CreateGEP(getData(), i)); }  
225 - StoreInst *store(Value *i, Value *value) const { return b->CreateStore(value, b->CreateGEP(getData(), i)); } 242 + LoadInst *load(Value *i) const { return b->CreateLoad(b->CreateGEP(data(), i)); }
  243 + StoreInst *store(Value *i, Value *value) const { return b->CreateStore(value, b->CreateGEP(data(), i)); }
  244 +
226 Value *cast(Value *i, const MatrixBuilder &dst) const { return (m->type() == dst.m->type()) ? i : b->CreateCast(CastInst::getCastOpcode(i, m->isSigned(), dst.ty(), dst.m->isSigned()), i, dst.ty()); } 245 Value *cast(Value *i, const MatrixBuilder &dst) const { return (m->type() == dst.m->type()) ? i : b->CreateCast(CastInst::getCastOpcode(i, m->isSigned(), dst.ty(), dst.m->isSigned()), i, dst.ty()); }
227 Value *add(Value *i, Value *j, const Twine &name = "") const { return m->isFloating() ? b->CreateFAdd(i, j, name) : b->CreateAdd(i, j, name); } 246 Value *add(Value *i, Value *j, const Twine &name = "") const { return m->isFloating() ? b->CreateFAdd(i, j, name) : b->CreateAdd(i, j, name); }
228 Value *multiply(Value *i, Value *j, const Twine &name = "") const { return m->isFloating() ? b->CreateFMul(i, j, name) : b->CreateMul(i, j, name); } 247 Value *multiply(Value *i, Value *j, const Twine &name = "") const { return m->isFloating() ? b->CreateFMul(i, j, name) : b->CreateMul(i, j, name); }
@@ -396,12 +415,12 @@ private: @@ -396,12 +415,12 @@ private:
396 415
397 BasicBlock *getKernel = BasicBlock::Create(getGlobalContext(), "get_kernel", function); 416 BasicBlock *getKernel = BasicBlock::Create(getGlobalContext(), "get_kernel", function);
398 BasicBlock *preallocate = BasicBlock::Create(getGlobalContext(), "preallocate", function); 417 BasicBlock *preallocate = BasicBlock::Create(getGlobalContext(), "preallocate", function);
399 - Value *hashTest = builder.CreateICmpNE(mb.getHash(), builder.CreateLoad(kernelHash), "hash_fail_test"); 418 + Value *hashTest = builder.CreateICmpNE(mb.hash(), builder.CreateLoad(kernelHash), "hash_fail_test");
400 builder.CreateCondBr(hashTest, getKernel, preallocate); 419 builder.CreateCondBr(hashTest, getKernel, preallocate);
401 420
402 builder.SetInsertPoint(getKernel); 421 builder.SetInsertPoint(getKernel);
403 builder.CreateStore(kernel, kernelFunction); 422 builder.CreateStore(kernel, kernelFunction);
404 - builder.CreateStore(mb.getHash(), kernelHash); 423 + builder.CreateStore(mb.hash(), kernelHash);
405 builder.CreateBr(preallocate); 424 builder.CreateBr(preallocate);
406 builder.SetInsertPoint(preallocate); 425 builder.SetInsertPoint(preallocate);
407 Value *kernelSize = buildPreallocate(mb, nb); 426 Value *kernelSize = buildPreallocate(mb, nb);
@@ -409,7 +428,7 @@ private: @@ -409,7 +428,7 @@ private:
409 BasicBlock *allocate = BasicBlock::Create(getGlobalContext(), "allocate", function); 428 BasicBlock *allocate = BasicBlock::Create(getGlobalContext(), "allocate", function);
410 builder.CreateBr(allocate); 429 builder.CreateBr(allocate);
411 builder.SetInsertPoint(allocate); 430 builder.SetInsertPoint(allocate);
412 - nb.allocateCode(); 431 + nb.allocate();
413 432
414 BasicBlock *callKernel = BasicBlock::Create(getGlobalContext(), "call_kernel", function); 433 BasicBlock *callKernel = BasicBlock::Create(getGlobalContext(), "call_kernel", function);
415 builder.CreateBr(callKernel); 434 builder.CreateBr(callKernel);
@@ -581,8 +600,8 @@ public: @@ -581,8 +600,8 @@ public:
581 600
582 virtual Value *buildPreallocate(const MatrixBuilder &src, const MatrixBuilder &dst) const 601 virtual Value *buildPreallocate(const MatrixBuilder &src, const MatrixBuilder &dst) const
583 { 602 {
584 - dst.copyHeaderCode(src);  
585 - return dst.elementsCode(); 603 + dst.copyHeader(src);
  604 + return dst.elements();
586 } 605 }
587 606
588 private: 607 private:
@@ -739,7 +758,7 @@ class sumTransform : public UnaryTransform @@ -739,7 +758,7 @@ class sumTransform : public UnaryTransform
739 758
740 if (frames && !src.m->singleFrame()) { 759 if (frames && !src.m->singleFrame()) {
741 BasicBlock *loop, *exit; 760 BasicBlock *loop, *exit;
742 - src_t = dst.beginLoop(loops.last(), loop, exit, src.getFrames(), "src_t"); 761 + src_t = dst.beginLoop(loops.last(), loop, exit, src.frames(), "src_t");
743 loops.append(loop); 762 loops.append(loop);
744 exits.append(exit); 763 exits.append(exit);
745 } else { 764 } else {
@@ -748,7 +767,7 @@ class sumTransform : public UnaryTransform @@ -748,7 +767,7 @@ class sumTransform : public UnaryTransform
748 767
749 if (rows && !src.m->singleRow()) { 768 if (rows && !src.m->singleRow()) {
750 BasicBlock *loop, *exit; 769 BasicBlock *loop, *exit;
751 - src_y = dst.beginLoop(loops.last(), loop, exit, src.getRows(), "src_y"); 770 + src_y = dst.beginLoop(loops.last(), loop, exit, src.rows(), "src_y");
752 loops.append(loop); 771 loops.append(loop);
753 exits.append(exit); 772 exits.append(exit);
754 } else { 773 } else {
@@ -757,7 +776,7 @@ class sumTransform : public UnaryTransform @@ -757,7 +776,7 @@ class sumTransform : public UnaryTransform
757 776
758 if (columns && !src.m->singleColumn()) { 777 if (columns && !src.m->singleColumn()) {
759 BasicBlock *loop, *exit; 778 BasicBlock *loop, *exit;
760 - src_x = dst.beginLoop(loops.last(), loop, exit, src.getColumns(), "src_x"); 779 + src_x = dst.beginLoop(loops.last(), loop, exit, src.columns(), "src_x");
761 loops.append(loop); 780 loops.append(loop);
762 exits.append(exit); 781 exits.append(exit);
763 } else { 782 } else {
@@ -766,7 +785,7 @@ class sumTransform : public UnaryTransform @@ -766,7 +785,7 @@ class sumTransform : public UnaryTransform
766 785
767 if (channels && !src.m->singleChannel()) { 786 if (channels && !src.m->singleChannel()) {
768 BasicBlock *loop, *exit; 787 BasicBlock *loop, *exit;
769 - src_c = dst.beginLoop(loops.last(), loop, exit, src.getChannels(), "src_c"); 788 + src_c = dst.beginLoop(loops.last(), loop, exit, src.channels(), "src_c");
770 loops.append(loop); 789 loops.append(loop);
771 exits.append(exit); 790 exits.append(exit);
772 } else { 791 } else {