diff --git a/lib/Analysis/InstructionSimplify.cpp b/lib/Analysis/InstructionSimplify.cpp
index e6baed1779c..c6dabf8870f 100644
--- a/lib/Analysis/InstructionSimplify.cpp
+++ b/lib/Analysis/InstructionSimplify.cpp
@@ -5819,6 +5819,74 @@ Value *llvm::SimplifyFreezeInst(Value *Op0, const SimplifyQuery &Q) {
   return ::SimplifyFreezeInst(Op0, Q);
 }
 
+static Constant *ConstructLoadOperandConstant(Value *Op) {
+  SmallVector<Value *, 4> Worklist;
+  Worklist.push_back(Op);
+  while (true) {
+    Value *CurOp = Worklist.back();
+    if (isa<Constant>(CurOp))
+      break;
+    if (auto *BC = dyn_cast<BitCastOperator>(CurOp)) {
+      Worklist.push_back(BC->getOperand(0));
+    } else if (auto *GEP = dyn_cast<GEPOperator>(CurOp)) {
+      for (unsigned I = 1; I != GEP->getNumOperands(); ++I) {
+        if (!isa<Constant>(GEP->getOperand(I)))
+          return nullptr;
+      }
+      Worklist.push_back(GEP->getOperand(0));
+    } else if (auto *II = dyn_cast<IntrinsicInst>(CurOp)) {
+      if (II->isLaunderOrStripInvariantGroup())
+        Worklist.push_back(II->getOperand(0));
+      else
+        return nullptr;
+    } else {
+      return nullptr;
+    }
+  }
+
+  Constant *NewOp = cast<Constant>(Worklist.pop_back_val());
+  while (!Worklist.empty()) {
+    Value *CurOp = Worklist.pop_back_val();
+    if (isa<BitCastOperator>(CurOp)) {
+      NewOp = ConstantExpr::getBitCast(NewOp, CurOp->getType());
+    } else if (auto *GEP = dyn_cast<GEPOperator>(CurOp)) {
+      SmallVector<Constant *, 4> Idxs;
+      Idxs.reserve(GEP->getNumOperands() - 1);
+      for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I) {
+        Idxs.push_back(cast<Constant>(GEP->getOperand(I)));
+      }
+      NewOp = ConstantExpr::getGetElementPtr(GEP->getSourceElementType(), NewOp,
+                                             Idxs, GEP->isInBounds(),
+                                             GEP->getInRangeIndex());
+    } else {
+      assert(isa<IntrinsicInst>(CurOp) &&
+             cast<IntrinsicInst>(CurOp)->isLaunderOrStripInvariantGroup() &&
+             "expected invariant group intrinsic");
+      NewOp = ConstantExpr::getBitCast(NewOp, CurOp->getType());
+    }
+  }
+  return NewOp;
+}
+
+static Value *SimplifyLoadInst(LoadInst *LI, const SimplifyQuery &Q) {
+  if (LI->isVolatile())
+    return nullptr;
+
+  if (auto *C = ConstantFoldInstruction(LI, Q.DL))
+    return C;
+
+  // The following only catches more cases than ConstantFoldInstruction() if the
+  // load operand wasn't a constant. Specifically, invariant.group intrinsics.
+  if (isa<Constant>(LI->getPointerOperand()))
+    return nullptr;
+
+  if (auto *C = dyn_cast_or_null<Constant>(
+          ConstructLoadOperandConstant(LI->getPointerOperand())))
+    return ConstantFoldLoadFromConstPtr(C, LI->getType(), Q.DL);
+
+  return nullptr;
+}
+
 /// See if we can compute a simplified version of this instruction.
 /// If not, this returns null.
@@ -5975,6 +6043,9 @@ Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ,
     // No simplifications for Alloca and it can't be constant folded.
     Result = nullptr;
     break;
+  case Instruction::Load:
+    Result = SimplifyLoadInst(cast<LoadInst>(I), Q);
+    break;
   }
 
   /// If called on unreachable code, the above logic may report that the
diff --git a/test/Transforms/InstSimplify/invariant.group-load.ll b/test/Transforms/InstSimplify/invariant.group-load.ll
index 72cb36d2eee..f1ee1528e8a 100644
--- a/test/Transforms/InstSimplify/invariant.group-load.ll
+++ b/test/Transforms/InstSimplify/invariant.group-load.ll
@@ -9,11 +9,7 @@ declare i8* @llvm.launder.invariant.group.p0i8(i8* %p)
 
 define i64 @f() {
 ; CHECK-LABEL: @f(
-; CHECK-NEXT:    [[A:%.*]] = call i8* @llvm.strip.invariant.group.p0i8(i8* bitcast ({ i64, i64 }* @A to i8*))
-; CHECK-NEXT:    [[B:%.*]] = getelementptr i8, i8* [[A]], i32 8
-; CHECK-NEXT:    [[C:%.*]] = bitcast i8* [[B]] to i64*
-; CHECK-NEXT:    [[D:%.*]] = load i64, i64* [[C]], align 4
-; CHECK-NEXT:    ret i64 [[D]]
+; CHECK-NEXT:    ret i64 3
 ;
   %p = bitcast { i64, i64 }* @A to i8*
   %a = call i8* @llvm.strip.invariant.group.p0i8(i8* %p)
@@ -25,11 +21,7 @@ define i64 @f() {
 
 define i64 @g() {
 ; CHECK-LABEL: @g(
-; CHECK-NEXT:    [[A:%.*]] = call i8* @llvm.launder.invariant.group.p0i8(i8* bitcast ({ i64, i64 }* @A to i8*))
-; CHECK-NEXT:    [[B:%.*]] = getelementptr i8, i8* [[A]], i32 8
-; CHECK-NEXT:    [[C:%.*]] = bitcast i8* [[B]] to i64*
-; CHECK-NEXT:    [[D:%.*]] = load i64, i64* [[C]], align 4
-; CHECK-NEXT:    ret i64 [[D]]
+; CHECK-NEXT:    ret i64 3
 ;
   %p = bitcast { i64, i64 }* @A to i8*
   %a = call i8* @llvm.launder.invariant.group.p0i8(i8* %p)
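
Not part of the patch itself, but a sketch of the behavior it adds, written in LLVM IR mirroring the tests above. The global @G and the function names @folds and @keeps_volatile are hypothetical stand-ins (the checked-in tests use @A); the second function exercises the early LI->isVolatile() bail-out in SimplifyLoadInst.

; Hypothetical module; @G plays the role of @A in the tests above.
@G = constant { i64, i64 } { i64 2, i64 3 }

declare i8* @llvm.strip.invariant.group.p0i8(i8*)

; Folds: the pointer operand is a bitcast/GEP/invariant.group chain rooted
; at a constant global, so ConstructLoadOperandConstant can rebuild the
; chain as a constant expression and the load folds to i64 3.
define i64 @folds() {
  %p = bitcast { i64, i64 }* @G to i8*
  %a = call i8* @llvm.strip.invariant.group.p0i8(i8* %p)
  %b = getelementptr i8, i8* %a, i32 8
  %c = bitcast i8* %b to i64*
  %d = load i64, i64* %c
  ret i64 %d
}

; Does not fold: SimplifyLoadInst rejects volatile loads before attempting
; any constant folding, so the observable memory access stays.
define i64 @keeps_volatile() {
  %p = bitcast { i64, i64 }* @G to i8*
  %a = call i8* @llvm.strip.invariant.group.p0i8(i8* %p)
  %c = bitcast i8* %a to i64*
  %d = load volatile i64, i64* %c
  ret i64 %d
}

Running opt -instsimplify over such a module should reduce @folds to ret i64 3 while leaving @keeps_volatile unchanged.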