From 5dd5096dfac2fa5adceae6bfee6d27671fec2eef Mon Sep 17 00:00:00 2001 From: Alexey Sotkin Date: Thu, 13 Jul 2017 19:34:07 +0300 Subject: Insert bitcast during translation of llvm.memset (#219) This commit supplements 94ed068c81cac48eed41e8851115520949664e2c "Changing translation of llvm.memset intrinsic (#217)" With arguments bitcasted to i8* OpCopyMemorySized can be correctly translated to llvm.memcpy, i.e. llvm.memcpy.p0i8.p0i8.i32(i8*, i8*, ...) instead of llvm.memcpy.p0i8.p0i8.i32(i8*, [n x i8]*, ...) --- lib/SPIRV/SPIRVReader.cpp | 24 ++++++++++++++---------- lib/SPIRV/SPIRVWriter.cpp | 9 ++++++--- lib/SPIRV/libSPIRV/SPIRVEntry.h | 2 +- test/transcoding/llvm.memset.ll | 4 +++- 4 files changed, 24 insertions(+), 15 deletions(-) diff --git a/lib/SPIRV/SPIRVReader.cpp b/lib/SPIRV/SPIRVReader.cpp index a8d9f08..8bd80a7 100644 --- a/lib/SPIRV/SPIRVReader.cpp +++ b/lib/SPIRV/SPIRVReader.cpp @@ -1668,16 +1668,20 @@ SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F, << "i8"; Value *Src = nullptr; // If we copy from zero-initialized array, we can optimize it to llvm.memset - if (BC->getSource()->isVariable()) { - auto *Init = static_cast(BC->getSource())->getInitializer(); - if (isa(Init)) { - SPIRVType *Ty = static_cast(Init)->getType(); - if (isa(Ty)) { - SPIRVTypeArray *AT = static_cast(Ty); - SrcTy = transType(AT->getArrayElementType()); - assert(SrcTy->isIntegerTy(8)); - Src = ConstantInt::get(SrcTy, 0); - FuncName = "llvm.memset"; + if (BC->getSource()->getOpCode() == OpBitcast) { + SPIRVValue *Source = + static_cast(BC->getSource())->getOperand(0); + if (Source->isVariable()) { + auto *Init = static_cast(Source)->getInitializer(); + if (isa(Init)) { + SPIRVType *Ty = static_cast(Init)->getType(); + if (isa(Ty)) { + SPIRVTypeArray *AT = static_cast(Ty); + SrcTy = transType(AT->getArrayElementType()); + assert(SrcTy->isIntegerTy(8)); + Src = ConstantInt::get(SrcTy, 0); + FuncName = "llvm.memset"; + } } } } diff --git a/lib/SPIRV/SPIRVWriter.cpp b/lib/SPIRV/SPIRVWriter.cpp index 9bc2937..604ab71 100644 --- a/lib/SPIRV/SPIRVWriter.cpp +++ b/lib/SPIRV/SPIRVWriter.cpp @@ -1302,9 +1302,12 @@ LLVMToSPIRV::transIntrinsicInst(IntrinsicInst *II, SPIRVBasicBlock *BB) { SPIRVTypeArray *CompositeTy = static_cast(transType(AT)); SPIRVValue *Init = BM->addNullConstant(CompositeTy); SPIRVType *VarTy = transType(PointerType::get(AT, SPIRV::SPIRAS_Constant)); - SPIRVValue *Source = BM->addVariable(VarTy,/*isConstant*/true, - spv::LinkageTypeInternal, Init, "", - StorageClassUniformConstant, nullptr); + SPIRVValue *Var = BM->addVariable(VarTy,/*isConstant*/true, + spv::LinkageTypeInternal, Init, "", + StorageClassUniformConstant, nullptr); + SPIRVType *SourceTy = transType(PointerType::get(Val->getType(), + SPIRV::SPIRAS_Constant)); + SPIRVValue *Source = BM->addUnaryInst(OpBitcast, SourceTy, Var, BB); SPIRVValue *Target = transValue(MSI->getRawDest(), BB); return BM->addCopyMemorySizedInst(Target, Source, CompositeTy->getLength(), getMemoryAccess(MSI), BB); diff --git a/lib/SPIRV/libSPIRV/SPIRVEntry.h b/lib/SPIRV/libSPIRV/SPIRVEntry.h index 004e3de..8cdb3dd 100644 --- a/lib/SPIRV/libSPIRV/SPIRVEntry.h +++ b/lib/SPIRV/libSPIRV/SPIRVEntry.h @@ -698,7 +698,7 @@ T* bcast(SPIRVEntry *E) { template bool isa(SPIRVEntry *E) { - return E->getOpCode() == OC; + return E ? E->getOpCode() == OC : false; } // ToDo: The following typedef's are place holders for SPIRV entity classes diff --git a/test/transcoding/llvm.memset.ll b/test/transcoding/llvm.memset.ll index 22eb041..b41fa96 100644 --- a/test/transcoding/llvm.memset.ll +++ b/test/transcoding/llvm.memset.ll @@ -20,10 +20,12 @@ ; CHECK-SPIRV: Constant {{[0-9]+}} [[Len:[0-9]+]] 12 ; CHECK-SPIRV: TypePointer [[Int8Ptr:[0-9]+]] 8 [[Int8]] ; CHECK-SPIRV: TypeArray [[Int8x12:[0-9]+]] [[Int8]] [[Len]] +; CHECK-SPIRV: TypePointer [[Int8PtrConst:[0-9]+]] 0 [[Int8]] ; CHECK-SPIRV: ConstantNull [[Int8x12]] [[Init:[0-9]+]] -; CHECK-SPIRV: Variable {{[0-9]+}} [[Source:[0-9]+]] 0 [[Init]] +; CHECK-SPIRV: Variable {{[0-9]+}} [[Val:[0-9]+]] 0 [[Init]] ; CHECK-SPIRV: Bitcast [[Int8Ptr]] [[Target:[0-9]+]] {{[0-9]+}} +; CHECK-SPIRV: Bitcast [[Int8PtrConst]] [[Source:[0-9]+]] [[Val]] ; CHECK-SPIRV: CopyMemorySized [[Target]] [[Source]] [[Len]] 2 4 -- cgit v1.2.3