From 86cba7bf5e26004a583d117666980ca5df975056 Mon Sep 17 00:00:00 2001 From: Johannes Doerfert Date: Thu, 15 Jul 2021 17:40:24 -0500 Subject: [PATCH] [InstSimplify] Expose generic interface for replaced operand simplification Users, especially the Attributor, might replace multiple operands at once. The actual implementation of simplifyWithOpReplaced is able to handle that just fine, the interface was simply not allowing to replace more than one operand at a time. This is exposing a more generic interface without intended changes for existing code. Differential Revision: https://reviews.llvm.org/D106189 --- include/llvm/Analysis/InstructionSimplify.h | 15 +- lib/Analysis/InstructionSimplify.cpp | 160 ++++++++++---------- 2 files changed, 92 insertions(+), 83 deletions(-) diff --git a/include/llvm/Analysis/InstructionSimplify.h b/include/llvm/Analysis/InstructionSimplify.h index 3a3b02601a5..efaf1847276 100644 --- a/include/llvm/Analysis/InstructionSimplify.h +++ b/include/llvm/Analysis/InstructionSimplify.h @@ -145,8 +145,7 @@ struct SimplifyQuery { // Please use the SimplifyQuery versions in new code. /// Given operand for an FNeg, fold the result or return null. -Value *SimplifyFNegInst(Value *Op, FastMathFlags FMF, - const SimplifyQuery &Q); +Value *SimplifyFNegInst(Value *Op, FastMathFlags FMF, const SimplifyQuery &Q); /// Given operands for an Add, fold the result or return null. Value *SimplifyAddInst(Value *LHS, Value *RHS, bool isNSW, bool isNUW, @@ -297,8 +296,8 @@ Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, /// Given operands for a BinaryOperator, fold the result or return null. /// Try to use FastMathFlags when folding the result. -Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, - FastMathFlags FMF, const SimplifyQuery &Q); +Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, FastMathFlags FMF, + const SimplifyQuery &Q); /// Given a callsite, fold the result or return null. Value *SimplifyCall(CallBase *Call, const SimplifyQuery &Q); @@ -312,6 +311,13 @@ Value *SimplifyFreezeInst(Value *Op, const SimplifyQuery &Q); Value *SimplifyInstruction(Instruction *I, const SimplifyQuery &Q, OptimizationRemarkEmitter *ORE = nullptr); +/// Like \p SimplifyInstruction but the operands of \p I are replaced with +/// \p NewOps. Returns a simplified value, or null if none was found. +Value * +SimplifyInstructionWithOperands(Instruction *I, ArrayRef NewOps, + const SimplifyQuery &Q, + OptimizationRemarkEmitter *ORE = nullptr); + /// See if V simplifies when its operand Op is replaced with RepOp. If not, /// return null. /// AllowRefinement specifies whether the simplification can be a refinement @@ -345,4 +351,3 @@ const SimplifyQuery getBestSimplifyQuery(LoopStandardAnalysisResults &, } // end namespace llvm #endif - diff --git a/lib/Analysis/InstructionSimplify.cpp b/lib/Analysis/InstructionSimplify.cpp index 4a310a34917..23083bc8178 100644 --- a/lib/Analysis/InstructionSimplify.cpp +++ b/lib/Analysis/InstructionSimplify.cpp @@ -17,6 +17,8 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/InstructionSimplify.h" + +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" @@ -4567,7 +4569,8 @@ Value *llvm::SimplifyExtractElementInst(Value *Vec, Value *Idx, } /// See if we can fold the given phi. If not, returns null. -static Value *SimplifyPHINode(PHINode *PN, const SimplifyQuery &Q) { +static Value *SimplifyPHINode(PHINode *PN, ArrayRef IncomingValues, + const SimplifyQuery &Q) { // WARNING: no matter how worthwhile it may seem, we can not perform PHI CSE // here, because the PHI we may succeed simplifying to was not // def-reachable from the original PHI! @@ -4576,7 +4579,7 @@ static Value *SimplifyPHINode(PHINode *PN, const SimplifyQuery &Q) { // with the common value. Value *CommonValue = nullptr; bool HasUndefInput = false; - for (Value *Incoming : PN->incoming_values()) { + for (Value *Incoming : IncomingValues) { // If the incoming value is the phi node itself, it can safely be skipped. if (Incoming == PN) continue; if (Q.isUndefValue(Incoming)) { @@ -6040,21 +6043,19 @@ static Constant *ConstructLoadOperandConstant(Value *Op) { return NewOp; } -static Value *SimplifyLoadInst(LoadInst *LI, const SimplifyQuery &Q) { +static Value *SimplifyLoadInst(LoadInst *LI, Value *PtrOp, + const SimplifyQuery &Q) { if (LI->isVolatile()) return nullptr; - if (auto *C = ConstantFoldInstruction(LI, Q.DL)) - return C; + // Try to make the load operand a constant, specifically handle + // invariant.group intrinsics. + auto *PtrOpC = dyn_cast(PtrOp); + if (!PtrOpC) + PtrOpC = ConstructLoadOperandConstant(PtrOp); - // The following only catches more cases than ConstantFoldInstruction() if the - // load operand wasn't a constant. Specifically, invariant.group intrinsics. - if (isa(LI->getPointerOperand())) - return nullptr; - - if (auto *C = dyn_cast_or_null( - ConstructLoadOperandConstant(LI->getPointerOperand()))) - return ConstantFoldLoadFromConstPtr(C, LI->getType(), Q.DL); + if (PtrOpC) + return ConstantFoldLoadFromConstPtr(PtrOpC, LI->getType(), Q.DL); return nullptr; } @@ -6062,161 +6063,149 @@ static Value *SimplifyLoadInst(LoadInst *LI, const SimplifyQuery &Q) { /// See if we can compute a simplified version of this instruction. /// If not, this returns null. -Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ, - OptimizationRemarkEmitter *ORE) { +static Value *simplifyInstructionWithOperands(Instruction *I, + ArrayRef NewOps, + const SimplifyQuery &SQ, + OptimizationRemarkEmitter *ORE) { const SimplifyQuery Q = SQ.CxtI ? SQ : SQ.getWithInstruction(I); - Value *Result; + Value *Result = nullptr; switch (I->getOpcode()) { default: - Result = ConstantFoldInstruction(I, Q.DL, Q.TLI); + if (llvm::all_of(NewOps, [](Value *V) { return isa(V); })) { + SmallVector NewConstOps(NewOps.size()); + transform(NewOps, NewConstOps.begin(), + [](Value *V) { return cast(V); }); + Result = ConstantFoldInstOperands(I, NewConstOps, Q.DL, Q.TLI); + } break; case Instruction::FNeg: - Result = SimplifyFNegInst(I->getOperand(0), I->getFastMathFlags(), Q); + Result = SimplifyFNegInst(NewOps[0], I->getFastMathFlags(), Q); break; case Instruction::FAdd: - Result = SimplifyFAddInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), Q); + Result = SimplifyFAddInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); break; case Instruction::Add: - Result = - SimplifyAddInst(I->getOperand(0), I->getOperand(1), - Q.IIQ.hasNoSignedWrap(cast(I)), - Q.IIQ.hasNoUnsignedWrap(cast(I)), Q); + Result = SimplifyAddInst( + NewOps[0], NewOps[1], Q.IIQ.hasNoSignedWrap(cast(I)), + Q.IIQ.hasNoUnsignedWrap(cast(I)), Q); break; case Instruction::FSub: - Result = SimplifyFSubInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), Q); + Result = SimplifyFSubInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); break; case Instruction::Sub: - Result = - SimplifySubInst(I->getOperand(0), I->getOperand(1), - Q.IIQ.hasNoSignedWrap(cast(I)), - Q.IIQ.hasNoUnsignedWrap(cast(I)), Q); + Result = SimplifySubInst( + NewOps[0], NewOps[1], Q.IIQ.hasNoSignedWrap(cast(I)), + Q.IIQ.hasNoUnsignedWrap(cast(I)), Q); break; case Instruction::FMul: - Result = SimplifyFMulInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), Q); + Result = SimplifyFMulInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); break; case Instruction::Mul: - Result = SimplifyMulInst(I->getOperand(0), I->getOperand(1), Q); + Result = SimplifyMulInst(NewOps[0], NewOps[1], Q); break; case Instruction::SDiv: - Result = SimplifySDivInst(I->getOperand(0), I->getOperand(1), Q); + Result = SimplifySDivInst(NewOps[0], NewOps[1], Q); break; case Instruction::UDiv: - Result = SimplifyUDivInst(I->getOperand(0), I->getOperand(1), Q); + Result = SimplifyUDivInst(NewOps[0], NewOps[1], Q); break; case Instruction::FDiv: - Result = SimplifyFDivInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), Q); + Result = SimplifyFDivInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); break; case Instruction::SRem: - Result = SimplifySRemInst(I->getOperand(0), I->getOperand(1), Q); + Result = SimplifySRemInst(NewOps[0], NewOps[1], Q); break; case Instruction::URem: - Result = SimplifyURemInst(I->getOperand(0), I->getOperand(1), Q); + Result = SimplifyURemInst(NewOps[0], NewOps[1], Q); break; case Instruction::FRem: - Result = SimplifyFRemInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), Q); + Result = SimplifyFRemInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q); break; case Instruction::Shl: - Result = - SimplifyShlInst(I->getOperand(0), I->getOperand(1), - Q.IIQ.hasNoSignedWrap(cast(I)), - Q.IIQ.hasNoUnsignedWrap(cast(I)), Q); + Result = SimplifyShlInst( + NewOps[0], NewOps[1], Q.IIQ.hasNoSignedWrap(cast(I)), + Q.IIQ.hasNoUnsignedWrap(cast(I)), Q); break; case Instruction::LShr: - Result = SimplifyLShrInst(I->getOperand(0), I->getOperand(1), + Result = SimplifyLShrInst(NewOps[0], NewOps[1], Q.IIQ.isExact(cast(I)), Q); break; case Instruction::AShr: - Result = SimplifyAShrInst(I->getOperand(0), I->getOperand(1), + Result = SimplifyAShrInst(NewOps[0], NewOps[1], Q.IIQ.isExact(cast(I)), Q); break; case Instruction::And: - Result = SimplifyAndInst(I->getOperand(0), I->getOperand(1), Q); + Result = SimplifyAndInst(NewOps[0], NewOps[1], Q); break; case Instruction::Or: - Result = SimplifyOrInst(I->getOperand(0), I->getOperand(1), Q); + Result = SimplifyOrInst(NewOps[0], NewOps[1], Q); break; case Instruction::Xor: - Result = SimplifyXorInst(I->getOperand(0), I->getOperand(1), Q); + Result = SimplifyXorInst(NewOps[0], NewOps[1], Q); break; case Instruction::ICmp: - Result = SimplifyICmpInst(cast(I)->getPredicate(), - I->getOperand(0), I->getOperand(1), Q); + Result = SimplifyICmpInst(cast(I)->getPredicate(), NewOps[0], + NewOps[1], Q); break; case Instruction::FCmp: - Result = - SimplifyFCmpInst(cast(I)->getPredicate(), I->getOperand(0), - I->getOperand(1), I->getFastMathFlags(), Q); + Result = SimplifyFCmpInst(cast(I)->getPredicate(), NewOps[0], + NewOps[1], I->getFastMathFlags(), Q); break; case Instruction::Select: - Result = SimplifySelectInst(I->getOperand(0), I->getOperand(1), - I->getOperand(2), Q); + Result = SimplifySelectInst(NewOps[0], NewOps[1], NewOps[2], Q); break; case Instruction::GetElementPtr: { - SmallVector Ops(I->operands()); Result = SimplifyGEPInst(cast(I)->getSourceElementType(), - Ops, Q); + NewOps, Q); break; } case Instruction::InsertValue: { InsertValueInst *IV = cast(I); - Result = SimplifyInsertValueInst(IV->getAggregateOperand(), - IV->getInsertedValueOperand(), - IV->getIndices(), Q); + Result = SimplifyInsertValueInst(NewOps[0], NewOps[1], IV->getIndices(), Q); break; } case Instruction::InsertElement: { - auto *IE = cast(I); - Result = SimplifyInsertElementInst(IE->getOperand(0), IE->getOperand(1), - IE->getOperand(2), Q); + Result = SimplifyInsertElementInst(NewOps[0], NewOps[1], NewOps[2], Q); break; } case Instruction::ExtractValue: { auto *EVI = cast(I); - Result = SimplifyExtractValueInst(EVI->getAggregateOperand(), - EVI->getIndices(), Q); + Result = SimplifyExtractValueInst(NewOps[0], EVI->getIndices(), Q); break; } case Instruction::ExtractElement: { - auto *EEI = cast(I); - Result = SimplifyExtractElementInst(EEI->getVectorOperand(), - EEI->getIndexOperand(), Q); + Result = SimplifyExtractElementInst(NewOps[0], NewOps[1], Q); break; } case Instruction::ShuffleVector: { auto *SVI = cast(I); - Result = - SimplifyShuffleVectorInst(SVI->getOperand(0), SVI->getOperand(1), - SVI->getShuffleMask(), SVI->getType(), Q); + Result = SimplifyShuffleVectorInst( + NewOps[0], NewOps[1], SVI->getShuffleMask(), SVI->getType(), Q); break; } case Instruction::PHI: - Result = SimplifyPHINode(cast(I), Q); + Result = SimplifyPHINode(cast(I), NewOps, Q); break; case Instruction::Call: { + // TODO: Use NewOps Result = SimplifyCall(cast(I), Q); break; } case Instruction::Freeze: - Result = SimplifyFreezeInst(I->getOperand(0), Q); + Result = llvm::SimplifyFreezeInst(NewOps[0], Q); break; #define HANDLE_CAST_INST(num, opc, clas) case Instruction::opc: #include "llvm/IR/Instruction.def" #undef HANDLE_CAST_INST - Result = - SimplifyCastInst(I->getOpcode(), I->getOperand(0), I->getType(), Q); + Result = SimplifyCastInst(I->getOpcode(), NewOps[0], I->getType(), Q); break; case Instruction::Alloca: // No simplifications for Alloca and it can't be constant folded. Result = nullptr; break; case Instruction::Load: - Result = SimplifyLoadInst(cast(I), Q); + Result = SimplifyLoadInst(cast(I), NewOps[0], Q); break; } @@ -6226,6 +6215,21 @@ Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ, return Result == I ? UndefValue::get(I->getType()) : Result; } +Value *llvm::SimplifyInstructionWithOperands(Instruction *I, + ArrayRef NewOps, + const SimplifyQuery &SQ, + OptimizationRemarkEmitter *ORE) { + assert(NewOps.size() == I->getNumOperands() && + "Number of operands should match the instruction!"); + return ::simplifyInstructionWithOperands(I, NewOps, SQ, ORE); +} + +Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ, + OptimizationRemarkEmitter *ORE) { + SmallVector Ops(I->operands()); + return ::simplifyInstructionWithOperands(I, Ops, SQ, ORE); +} + /// Implementation of recursive simplification through an instruction's /// uses. ///