From 5eba6ee9692b48b9886a95ba82faa69cc5a7b630 Mon Sep 17 00:00:00 2001
From: Chris Lattner
Date: Sun, 6 Dec 2009 01:57:02 +0000
Subject: [PATCH] Handle forwarding local memsets to loads.

For example, we optimize this:

short x(short *A) {
  memset(A, 1, sizeof(*A)*100);
  return A[42];
}

to 'return 257' instead of doing the load.

llvm-svn: 90695
---
 lib/Transforms/Scalar/GVN.cpp | 179 ++++++++++++++++++++++++++--------
 test/Transforms/GVN/rle.ll    |  37 +++++++
 2 files changed, 173 insertions(+), 43 deletions(-)
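Why the example returns 257: memset fills every byte of A with 0x01, so the
two bytes of A[42] are 0x01 0x01, and loading them as a little-endian i16
gives 0x0101 = 257. As a standalone model (not part of the patch; splatByte
is an illustrative name), this C++ sketch mirrors the byte-splat loop the
patch adds to GVN.cpp below:

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // Standalone model of the value GVN forwards out of memset(P, Byte, N)
    // for a load of LoadSize bytes: Byte splatted across the load width.
    static uint64_t splatByte(uint8_t Byte, unsigned LoadSize) {
      uint64_t Val = Byte;
      for (unsigned NumBytesSet = 1; NumBytesSet != LoadSize; ) {
        if (NumBytesSet * 2 <= LoadSize) {  // Double the set bytes if we can.
          Val |= Val << (NumBytesSet * 8);
          NumBytesSet *= 2;
          continue;
        }
        Val = Byte | (Val << 8);            // Otherwise add one byte at a time.
        ++NumBytesSet;
      }
      return Val;
    }

    int main() {
      short A[100];
      memset(A, 1, sizeof(*A) * 100);
      // Both print 257 (0x0101) on a little-endian host.
      printf("%d %llu\n", A[42], (unsigned long long)splatByte(1, sizeof(short)));
    }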
diff --git a/lib/Transforms/Scalar/GVN.cpp b/lib/Transforms/Scalar/GVN.cpp
index cbe7add582a..d485e07e076 100644
--- a/lib/Transforms/Scalar/GVN.cpp
+++ b/lib/Transforms/Scalar/GVN.cpp
@@ -40,6 +40,7 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/GetElementPtrTypeIterator.h"
+#include "llvm/Support/IRBuilder.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Target/TargetData.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
@@ -987,25 +988,24 @@ static Value *GetBaseWithConstantOffset(Value *Ptr, int64_t &Offset,
 }
 
-/// AnalyzeLoadFromClobberingStore - This function is called when we have a
-/// memdep query of a load that ends up being a clobbering store.  This means
-/// that the store *may* provide bits used by the load but we can't be sure
-/// because the pointers don't mustalias.  Check this case to see if there is
-/// anything more we can do before we give up.  This returns -1 if we have to
-/// give up, or a byte number in the stored value of the piece that feeds the
-/// load.
-static int AnalyzeLoadFromClobberingStore(LoadInst *L, StoreInst *DepSI,
+/// AnalyzeLoadFromClobberingWrite - This function is called when we have a
+/// memdep query of a load that ends up being a clobbering memory write (store,
+/// memset, memcpy, memmove).  This means that the write *may* provide bits used
+/// by the load but we can't be sure because the pointers don't mustalias.
+///
+/// Check this case to see if there is anything more we can do before we give
+/// up.  This returns -1 if we have to give up, or a byte number in the stored
+/// value of the piece that feeds the load.
+static int AnalyzeLoadFromClobberingWrite(LoadInst *L, Value *WritePtr,
+                                          uint64_t WriteSizeInBits,
                                           const TargetData &TD) {
   // If the loaded or stored value is a first class array or struct, don't try
   // to transform them.  We need to be able to bitcast to integer.
-  if (isa<StructType>(L->getType()) || isa<ArrayType>(L->getType()) ||
-      isa<StructType>(DepSI->getOperand(0)->getType()) ||
-      isa<ArrayType>(DepSI->getOperand(0)->getType()))
+  if (isa<StructType>(L->getType()) || isa<ArrayType>(L->getType()))
     return -1;
   
   int64_t StoreOffset = 0, LoadOffset = 0;
-  Value *StoreBase = 
-    GetBaseWithConstantOffset(DepSI->getPointerOperand(), StoreOffset, TD);
+  Value *StoreBase = GetBaseWithConstantOffset(WritePtr, StoreOffset, TD);
   Value *LoadBase = 
     GetBaseWithConstantOffset(L->getPointerOperand(), LoadOffset, TD);
   if (StoreBase != LoadBase)
@@ -1018,8 +1018,8 @@ static int AnalyzeLoadFromClobberingStore(LoadInst *L, StoreInst *DepSI,
 #if 0
     errs() << "STORE/LOAD DEP WITH COMMON POINTER MISSED:\n"
     << "Base       = " << *StoreBase << "\n"
-    << "Store Ptr  = " << *DepSI->getPointerOperand() << "\n"
-    << "Store Offs = " << StoreOffset << " - " << *DepSI << "\n"
+    << "Store Ptr  = " << *WritePtr << "\n"
+    << "Store Offs = " << StoreOffset << "\n"
     << "Load Ptr   = " << *L->getPointerOperand() << "\n"
     << "Load Offs  = " << LoadOffset << " - " << *L << "\n\n";
     errs() << "'" << L->getParent()->getParent()->getName() << "'"
@@ -1033,12 +1033,11 @@ static int AnalyzeLoadFromClobberingStore(LoadInst *L, StoreInst *DepSI,
   // must have gotten confused.
   // FIXME: Investigate cases where this bails out, e.g. rdar://7238614. Then
   // remove this check, as it is duplicated with what we have below.
-  uint64_t StoreSize = TD.getTypeSizeInBits(DepSI->getOperand(0)->getType());
   uint64_t LoadSize = TD.getTypeSizeInBits(L->getType());
   
-  if ((StoreSize & 7) | (LoadSize & 7))
+  if ((WriteSizeInBits & 7) | (LoadSize & 7))
     return -1;
-  StoreSize >>= 3;  // Convert to bytes.
+  uint64_t StoreSize = WriteSizeInBits >> 3;  // Convert to bytes.
   LoadSize >>= 3;
   
   
@@ -1052,8 +1051,8 @@ static int AnalyzeLoadFromClobberingStore(LoadInst *L, StoreInst *DepSI,
 #if 0
     errs() << "STORE LOAD DEP WITH COMMON BASE:\n"
     << "Base       = " << *StoreBase << "\n"
-    << "Store Ptr  = " << *DepSI->getPointerOperand() << "\n"
-    << "Store Offs = " << StoreOffset << " - " << *DepSI << "\n"
+    << "Store Ptr  = " << *WritePtr << "\n"
+    << "Store Offs = " << StoreOffset << "\n"
     << "Load Ptr   = " << *L->getPointerOperand() << "\n"
     << "Load Offs  = " << LoadOffset << " - " << *L << "\n\n";
     errs() << "'" << L->getParent()->getParent()->getName() << "'"
@@ -1075,6 +1074,34 @@ static int AnalyzeLoadFromClobberingStore(LoadInst *L, StoreInst *DepSI,
   return LoadOffset-StoreOffset;
 }
 
+/// AnalyzeLoadFromClobberingStore - This function is called when we have a
+/// memdep query of a load that ends up being a clobbering store.
+static int AnalyzeLoadFromClobberingStore(LoadInst *L, StoreInst *DepSI,
+                                          const TargetData &TD) {
+  // Cannot handle reading from store of first-class aggregate yet.
+  if (isa<StructType>(DepSI->getOperand(0)->getType()) ||
+      isa<ArrayType>(DepSI->getOperand(0)->getType()))
+    return -1;
+
+  Value *StorePtr = DepSI->getPointerOperand();
+  uint64_t StoreSize = TD.getTypeSizeInBits(StorePtr->getType());
+  return AnalyzeLoadFromClobberingWrite(L, StorePtr, StoreSize, TD);
+}
+
+static int AnalyzeLoadFromClobberingMemInst(LoadInst *L, MemIntrinsic *MI,
+                                            const TargetData &TD) {
+  // If the mem operation is a non-constant size, we can't handle it.
+  ConstantInt *SizeCst = dyn_cast<ConstantInt>(MI->getLength());
+  if (SizeCst == 0) return -1;
+  uint64_t MemSizeInBits = SizeCst->getZExtValue()*8;
+
+  if (MI->getIntrinsicID() == Intrinsic::memset)
+    return AnalyzeLoadFromClobberingWrite(L, MI->getDest(), MemSizeInBits, TD);
+
+  // Unhandled memcpy/memmove.
+  return -1;
+}
+
 
 /// GetStoreValueForLoad - This function is called when we have a
 /// memdep query of a load that ends up being a clobbering store.  This means
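Once GetBaseWithConstantOffset has reduced both pointers to the same base plus
constant byte offsets, the covering check above is plain interval arithmetic.
A minimal standalone sketch of that decision (illustrative names; the real
function additionally rejects aggregate types and non-byte-sized values, as
shown above):

    #include <cstdint>
    #include <cstdio>

    // Sketch: the write and the load have been reduced to the same base
    // object plus constant byte offsets.  Returns the byte number of the
    // loaded piece inside the written bytes, or -1 if not fully covered.
    static int offsetWithinWrite(int64_t StoreOffset, uint64_t StoreSize,
                                 int64_t LoadOffset, uint64_t LoadSize) {
      bool Covered =
          StoreOffset <= LoadOffset &&
          LoadOffset + (int64_t)LoadSize <= StoreOffset + (int64_t)StoreSize;
      return Covered ? (int)(LoadOffset - StoreOffset) : -1;
    }

    int main() {
      // memset(A, 1, 200) writes bytes [0,200); the i16 load of A[42] reads
      // bytes [84,86), so the load is fed starting at byte 84 of the memset.
      printf("%d\n", offsetWithinWrite(0, 200, 84, 2));   // 84
      printf("%d\n", offsetWithinWrite(0, 200, 199, 2));  // -1: runs off the end
    }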
@@ -1100,11 +1127,10 @@ static Value *GetStoreValueForLoad(Value *SrcVal, unsigned Offset,
   
   // Shift the bits to the least significant depending on endianness.
   unsigned ShiftAmt;
-  if (TD.isLittleEndian()) {
+  if (TD.isLittleEndian())
     ShiftAmt = Offset*8;
-  } else {
+  else
     ShiftAmt = (StoreSize-LoadSize-Offset)*8;
-  }
   
   if (ShiftAmt)
     SrcVal = BinaryOperator::CreateLShr(SrcVal,
@@ -1117,6 +1143,52 @@ static Value *GetStoreValueForLoad(Value *SrcVal, unsigned Offset,
   return CoerceAvailableValueToLoadType(SrcVal, LoadTy, InsertPt, TD);
 }
 
+/// GetMemInstValueForLoad - This function is called when we have a
+/// memdep query of a load that ends up being a clobbering mem intrinsic.
+static Value *GetMemInstValueForLoad(MemIntrinsic *SrcInst, unsigned Offset,
+                                     const Type *LoadTy, Instruction *InsertPt,
+                                     const TargetData &TD){
+  LLVMContext &Ctx = LoadTy->getContext();
+  uint64_t LoadSize = TD.getTypeSizeInBits(LoadTy)/8;
+
+  IRBuilder<> Builder(InsertPt->getParent(), InsertPt);
+  
+  // We know that this method is only called when the mem transfer fully
+  // provides the bits for the load.
+  if (MemSetInst *MSI = dyn_cast<MemSetInst>(SrcInst)) {
+    // memset(P, 'x', 1234) -> splat('x'), even if x is a variable, and
+    // independently of what the offset is.
+    Value *Val = MSI->getValue();
+    if (LoadSize != 1)
+      Val = Builder.CreateZExt(Val, IntegerType::get(Ctx, LoadSize*8));
+    
+    Value *OneElt = Val;
+    
+    // Splat the value out to the right number of bits.
+    for (unsigned NumBytesSet = 1; NumBytesSet != LoadSize; ) {
+      // If we can double the number of bytes set, do it.
+      if (NumBytesSet*2 <= LoadSize) {
+        Value *ShVal = Builder.CreateShl(Val, NumBytesSet*8);
+        Val = Builder.CreateOr(Val, ShVal);
+        NumBytesSet <<= 1;
+        continue;
+      }
+      
+      // Otherwise insert one byte at a time.
+      Value *ShVal = Builder.CreateShl(Val, 1*8);
+      Val = Builder.CreateOr(OneElt, ShVal);
+      ++NumBytesSet;
+    }
+    
+    return CoerceAvailableValueToLoadType(Val, LoadTy, InsertPt, TD);
+  }
+
+  // ABORT;
+  return 0;
+}
+
+
+
 struct AvailableValueInBlock {
   /// BB - The basic block in question.
   BasicBlock *BB;
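For the plain-store path, GetStoreValueForLoad above shifts the stored integer
so that the bytes overlapping the load land in the least significant position,
with the shift amount depending on endianness, then truncates to the load
width. A standalone model of that arithmetic on host integers (illustrative
names; sizes are in bytes and the load is assumed to fit in 64 bits):

    #include <cstdint>
    #include <cstdio>

    // Model of GetStoreValueForLoad's shift: extract LoadSize bytes at byte
    // Offset inside a StoreSize-byte stored value.
    static uint64_t extractBytes(uint64_t StoredVal, unsigned Offset,
                                 unsigned LoadSize, unsigned StoreSize,
                                 bool LittleEndian) {
      unsigned ShiftAmt =
          LittleEndian ? Offset * 8                             // low bytes feed the load
                       : (StoreSize - LoadSize - Offset) * 8;   // high bytes feed the load
      uint64_t Mask = LoadSize == 8 ? ~0ULL : ((1ULL << (LoadSize * 8)) - 1);
      return (StoredVal >> ShiftAmt) & Mask;  // shift down, then truncate
    }

    int main() {
      // store i32 0x04030201 to P; an i8 load from P+1 sees the byte at
      // offset 1, which differs by endianness:
      printf("0x%llx\n", (unsigned long long)extractBytes(0x04030201, 1, 1, 4, true));   // 0x2
      printf("0x%llx\n", (unsigned long long)extractBytes(0x04030201, 1, 1, 4, false));  // 0x3
    }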
@@ -1251,8 +1323,21 @@ bool GVN::processNonLocalLoad(LoadInst *LI,
         }
       }
     }
+
+#if 0
+    // If the clobbering value is a memset/memcpy/memmove, see if we can
+    // forward a value on from it.
+    if (MemIntrinsic *DepMI = dyn_cast<MemIntrinsic>(Dep.getInst())) {
+      if (TD == 0)
+        TD = getAnalysisIfAvailable<TargetData>();
+      if (TD) {
+        int Offset = AnalyzeLoadFromClobberingMemInst(L, DepMI, *TD);
+        if (Offset != -1)
+          AvailVal = GetMemInstValueForLoad(DepMI, Offset, L->getType(), L,*TD);
+      }
+    }
+#endif
     
-    // FIXME: Handle memset/memcpy.
     UnavailableBlocks.push_back(DepBB);
     continue;
   }
@@ -1526,11 +1611,6 @@ bool GVN::processLoad(LoadInst *L, SmallVectorImpl<Instruction*> &toErase) {
   
   // If the value isn't available, don't do anything!
   if (Dep.isClobber()) {
-    // FIXME: We should handle memset/memcpy/memmove as dependent instructions
-    // to forward the value if available.
-    //if (isa<MemIntrinsic>(Dep.getInst()))
-    //errs() << "LOAD DEPENDS ON MEM: " << *L << "\n" << *Dep.getInst()<<"\n\n";
-    
     // Check to see if we have something like this:
     //   store i32 123, i32* %P
     //   %A = bitcast i32* %P to i8*
@@ -1541,25 +1621,38 @@ bool GVN::processLoad(LoadInst *L, SmallVectorImpl<Instruction*> &toErase) {
     // a common base + constant offset, and if the previous store (or memset)
     // completely covers this load.  This sort of thing can happen in bitfield
     // access code.
+    Value *AvailVal = 0;
     if (StoreInst *DepSI = dyn_cast<StoreInst>(Dep.getInst()))
       if (const TargetData *TD = getAnalysisIfAvailable<TargetData>()) {
         int Offset = AnalyzeLoadFromClobberingStore(L, DepSI, *TD);
-        if (Offset != -1) {
-          Value *AvailVal = GetStoreValueForLoad(DepSI->getOperand(0), Offset,
-                                                 L->getType(), L, *TD);
-          DEBUG(errs() << "GVN COERCED STORE BITS:\n" << *DepSI << '\n'
-                       << *AvailVal << '\n' << *L << "\n\n\n");
-    
-          // Replace the load!
-          L->replaceAllUsesWith(AvailVal);
-          if (isa<PointerType>(AvailVal->getType()))
-            MD->invalidateCachedPointerInfo(AvailVal);
-          toErase.push_back(L);
-          NumGVNLoad++;
-          return true;
-        }
+        if (Offset != -1)
+          AvailVal = GetStoreValueForLoad(DepSI->getOperand(0), Offset,
+                                          L->getType(), L, *TD);
       }
+    
+    // If the clobbering value is a memset/memcpy/memmove, see if we can forward
+    // a value on from it.
+    if (MemIntrinsic *DepMI = dyn_cast<MemIntrinsic>(Dep.getInst())) {
+      if (const TargetData *TD = getAnalysisIfAvailable<TargetData>()) {
+        int Offset = AnalyzeLoadFromClobberingMemInst(L, DepMI, *TD);
+        if (Offset != -1)
+          AvailVal = GetMemInstValueForLoad(DepMI, Offset, L->getType(), L,*TD);
+      }
+    }
+    
+    if (AvailVal) {
+      DEBUG(errs() << "GVN COERCED INST:\n" << *Dep.getInst() << '\n'
+                   << *AvailVal << '\n' << *L << "\n\n\n");
+      
+      // Replace the load!
+      L->replaceAllUsesWith(AvailVal);
+      if (isa<PointerType>(AvailVal->getType()))
+        MD->invalidateCachedPointerInfo(AvailVal);
+      toErase.push_back(L);
+      NumGVNLoad++;
+      return true;
+    }
+    
     DEBUG(
       // fast print dep, using operator<< on instruction would be too slow
       errs() << "GVN: load ";
diff --git a/test/Transforms/GVN/rle.ll b/test/Transforms/GVN/rle.ll
index 71eb194d3a8..01d1ebc1a98 100644
--- a/test/Transforms/GVN/rle.ll
+++ b/test/Transforms/GVN/rle.ll
@@ -131,6 +131,43 @@ define i8* @coerce_mustalias7(i64 %V, i64* %P) {
 ; CHECK: ret i8*
 }
 
+; memset -> i16 forwarding.
+define signext i16 @memset_to_i16_local(i16* %A) nounwind ssp {
+entry:
+  %conv = bitcast i16* %A to i8*
+  tail call void @llvm.memset.i64(i8* %conv, i8 1, i64 200, i32 1)
+  %arrayidx = getelementptr inbounds i16* %A, i64 42
+  %tmp2 = load i16* %arrayidx
+  ret i16 %tmp2
+; CHECK: @memset_to_i16_local
+; CHECK-NOT: load
+; CHECK: ret i16 257
+}
+
+; memset -> float forwarding.
+define float @memset_to_float_local(float* %A, i8 %Val) nounwind ssp {
+entry:
+  %conv = bitcast float* %A to i8*                     ; <i8*> [#uses=1]
+  tail call void @llvm.memset.i64(i8* %conv, i8 %Val, i64 400, i32 1)
+  %arrayidx = getelementptr inbounds float* %A, i64 42 ; <float*> [#uses=1]
+  %tmp2 = load float* %arrayidx                        ; <float> [#uses=1]
+  ret float %tmp2
+; CHECK: @memset_to_float_local
+; CHECK-NOT: load
+; CHECK: zext
+; CHECK-NEXT: shl
+; CHECK-NEXT: or
+; CHECK-NEXT: shl
+; CHECK-NEXT: or
+; CHECK-NEXT: bitcast
+; CHECK-NEXT: ret float
+}
+
+declare void @llvm.memset.i64(i8* nocapture, i8, i64, i32) nounwind
+
+
+
+
 ;; non-local i32/float -> i8 load forwarding.
 define i8 @coerce_mustalias_nonlocal0(i32* %P, i1 %cond) {
   %P2 = bitcast i32* %P to float*
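The CHECK lines in memset_to_float_local trace exactly the splat sequence
GetMemInstValueForLoad emits when the fill byte is a variable: a zext of %Val
from i8 to i32, two shl/or doubling steps, and a final bitcast to float. The
same computation in standalone C++ (illustrative; std::memcpy stands in for
the IR bitcast):

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // What the transformed @memset_to_float_local computes for fill byte Val:
    // splat Val to 32 bits with two doubling steps, then reinterpret as float.
    static float memsetFloatValue(uint8_t Val) {
      uint32_t V = Val;               // zext i8 -> i32
      V |= V << 8;                    // shl 8,  or : 1 byte set -> 2
      V |= V << 16;                   // shl 16, or : 2 bytes set -> 4
      float F;
      std::memcpy(&F, &V, sizeof F);  // the final bitcast i32 -> float
      return F;
    }

    int main() {
      printf("%g\n", memsetFloatValue(1));  // 0x01010101 as float, ~2.36943e-38
    }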