diff options
Diffstat (limited to 'backend/src/llvm/llvm_device_enqueue.cpp')
-rw-r--r-- | backend/src/llvm/llvm_device_enqueue.cpp | 42 |
1 files changed, 22 insertions, 20 deletions
diff --git a/backend/src/llvm/llvm_device_enqueue.cpp b/backend/src/llvm/llvm_device_enqueue.cpp index 9a0fb46f..58aa6817 100644 --- a/backend/src/llvm/llvm_device_enqueue.cpp +++ b/backend/src/llvm/llvm_device_enqueue.cpp @@ -29,6 +29,7 @@ namespace gbe { BitCastInst* bt = dyn_cast<BitCastInst>(I); if (bt == NULL) return NULL; +//bt->dump(); Type* type = bt->getOperand(0)->getType(); if(!type->isPointerTy()) @@ -112,7 +113,8 @@ namespace gbe { ValueToValueMapTy VMap; for (Function::arg_iterator I = Fn->arg_begin(), E = Fn->arg_end(); I != E; ++I) { PointerType *ty = dyn_cast<PointerType>(I->getType()); - if(ty && ty->getAddressSpace() == 0) //Foce set the address space to global + //Foce set the address space to global + if(ty && (ty->getAddressSpace() == 0 || ty->getAddressSpace() == 4)) ty = PointerType::get(ty->getPointerElementType(), 1); ParamTys.push_back(ty); } @@ -252,12 +254,13 @@ namespace gbe { if(gep == NULL) continue; - BitCastInst* fnPointer = dyn_cast<BitCastInst>(gep->getOperand(0)); - if(fnPointer == NULL) + Value *fnPointer = gep->getOperand(0)->stripPointerCasts(); + + if(fnPointer == gep->getOperand(0)) continue; - if(BitCastInst* bt = dyn_cast<BitCastInst>(fnPointer->getOperand(0))) { - std::string fnName = blocks[bt->getOperand(0)]; + if(blocks.find(fnPointer) != blocks.end()) { + std::string fnName = blocks[fnPointer]; Function* f = mod->getFunction(fnName); CallInst *newCI = builder.CreateCall(f, args); CI->replaceAllUsesWith(newCI); @@ -266,7 +269,7 @@ namespace gbe { } //the function is global variable - if(GlobalVariable* gv = dyn_cast<GlobalVariable>(fnPointer->getOperand(0))) { + if(GlobalVariable* gv = dyn_cast<GlobalVariable>(fnPointer)) { Constant *c = gv->getInitializer(); ConstantExpr *expr = dyn_cast<ConstantExpr>(c->getOperand(3)); BitCastInst *bt = dyn_cast<BitCastInst>(expr->getAsInstruction()); @@ -277,7 +280,7 @@ namespace gbe { continue; } - ld = dyn_cast<LoadInst>(fnPointer->getOperand(0)); + ld = dyn_cast<LoadInst>(fnPointer); if(ld == NULL) continue; @@ -304,9 +307,7 @@ namespace gbe { User *theUser = iter->getUser(); #endif if(StoreInst *st = dyn_cast<StoreInst>(theUser)) { - bt = dyn_cast<BitCastInst>(st->getValueOperand()); - if(bt) - v = bt->getOperand(0); + v = st->getValueOperand()->stripPointerCasts(); } } if(blocks.find(v) == blocks.end()) { @@ -339,9 +340,7 @@ namespace gbe { Type *type = CI->getArgOperand(block_index)->getType(); if(type->isIntegerTy()) block_index = 6; - Value *block = CI->getArgOperand(block_index); - while(isa<BitCastInst>(block)) - block = dyn_cast<BitCastInst>(block)->getOperand(0); + Value *block = CI->getArgOperand(block_index)->stripPointerCasts(); LoadInst *ld = dyn_cast<LoadInst>(block); Value *v = NULL; if(ld) { @@ -353,9 +352,7 @@ namespace gbe { User *theUser = iter->getUser(); #endif if(StoreInst *st = dyn_cast<StoreInst>(theUser)) { - BitCastInst *bt = dyn_cast<BitCastInst>(st->getValueOperand()); - if(bt) - v = bt->getOperand(0); + v = st->getValueOperand()->stripPointerCasts(); } } if(blocks.find(v) == blocks.end()) { @@ -378,15 +375,20 @@ namespace gbe { if( fn->isVarArg() ) { //enqueue function with slm, convert to __gen_enqueue_kernel_slm call //store the slm information to a alloca address. - int start = block_index + 1; + int start = block_index + 1 + 1; //the first is count, skip int count = CI->getNumArgOperands() - start; Type *intTy = IntegerType::get(mod->getContext(), 32); + Type *int64Ty = IntegerType::get(mod->getContext(), 64); AllocaInst *AI = builder.CreateAlloca(intTy, ConstantInt::get(intTy, count)); for(uint32_t i = start; i < CI->getNumArgOperands(); i++) { Value *ptr = builder.CreateGEP(AI, ConstantInt::get(intTy, i-start)); - builder.CreateStore(CI->getArgOperand(i), ptr); + Value *argSize = CI->getArgOperand(i); + if (argSize->getType() == int64Ty) { + argSize = builder.CreateTrunc(argSize, intTy); + } + builder.CreateStore(argSize, ptr); } SmallVector<Value*, 16> args(CI->op_begin(), CI->op_begin() + 3); args.push_back(CI->getArgOperand(block_index)); @@ -394,8 +396,8 @@ namespace gbe { args.push_back(AI); std::vector<Type *> ParamTys; - for (Value** I = args.begin(); I != args.end(); ++I) - ParamTys.push_back((*I)->getType()); + for (Value** iter = args.begin(); iter != args.end(); ++iter) + ParamTys.push_back((*iter)->getType()); CallInst* newCI = builder.CreateCall(cast<llvm::Function>(mod->getOrInsertFunction( "__gen_enqueue_kernel_slm", FunctionType::get(intTy, ParamTys, false))), args); CI->replaceAllUsesWith(newCI); |