From 5521155be5c869b0b760e1dec86c41cdbb7a75c0 Mon Sep 17 00:00:00 2001 From: sguo35 Date: Wed, 25 May 2022 17:40:19 -0700 Subject: [PATCH] Fix register clobbering on aarch64 GHC when mixing tail/non-tail calls By default LLVM doesn't save any regs for GHC on arm64. This means we'll clobber LR on arm64 if we make non-tail calls (e.g. L2 syscall). So we should save LR on non-tail calls, and not assume we won't make non-tail calls. --- lib/Target/AArch64/AArch64CallingConvention.td | 3 +++ lib/Target/AArch64/AArch64FrameLowering.cpp | 14 -------------- lib/Target/AArch64/AArch64ISelLowering.cpp | 2 +- lib/Target/AArch64/AArch64RegisterInfo.cpp | 10 +++++----- lib/Target/AArch64/GISel/AArch64CallLowering.cpp | 1 + 5 files changed, 10 insertions(+), 20 deletions(-) diff --git a/lib/Target/AArch64/AArch64CallingConvention.td b/lib/Target/AArch64/AArch64CallingConvention.td index 4b7ce565eb1..607c5c31c89 100644 --- a/lib/Target/AArch64/AArch64CallingConvention.td +++ b/lib/Target/AArch64/AArch64CallingConvention.td @@ -508,6 +508,9 @@ def CSR_Darwin_AArch64_CXX_TLS_ViaCopy def CSR_Darwin_AArch64_RT_MostRegs : CalleeSavedRegs<(add CSR_Darwin_AArch64_AAPCS, (sequence "X%u", 9, 15))>; +def CSR_AArch64_NoRegs_LR + : CalleeSavedRegs<(add CSR_AArch64_NoRegs, LR)>; + // Variants of the standard calling conventions for shadow call stack. // These all preserve x18 in addition to any other registers. def CSR_AArch64_NoRegs_SCS diff --git a/lib/Target/AArch64/AArch64FrameLowering.cpp b/lib/Target/AArch64/AArch64FrameLowering.cpp index f6a528c0e6f..f8cf2fa63ba 100644 --- a/lib/Target/AArch64/AArch64FrameLowering.cpp +++ b/lib/Target/AArch64/AArch64FrameLowering.cpp @@ -1165,11 +1165,6 @@ void AArch64FrameLowering::emitPrologue(MachineFunction &MF, .setMIFlag(MachineInstr::FrameSetup); } - // All calls are tail calls in GHC calling conv, and functions have no - // prologue/epilogue. 
- if (MF.getFunction().getCallingConv() == CallingConv::GHC) - return; - // Set tagged base pointer to the requested stack slot. // Ideally it should match SP value after prologue. Optional TBPI = AFI->getTaggedBasePointerIndex(); @@ -1677,11 +1672,6 @@ void AArch64FrameLowering::emitEpilogue(MachineFunction &MF, : MFI.getStackSize(); AArch64FunctionInfo *AFI = MF.getInfo(); - // All calls are tail calls in GHC calling conv, and functions have no - // prologue/epilogue. - if (MF.getFunction().getCallingConv() == CallingConv::GHC) - return; - // How much of the stack used by incoming arguments this function is expected // to restore in this particular epilogue. int64_t ArgumentStackToRestore = getArgumentStackToRestore(MF, MBB); @@ -2733,10 +2723,6 @@ bool AArch64FrameLowering::restoreCalleeSavedRegisters( void AArch64FrameLowering::determineCalleeSaves(MachineFunction &MF, BitVector &SavedRegs, RegScavenger *RS) const { - // All calls are tail calls in GHC calling conv, and functions have no - // prologue/epilogue. - if (MF.getFunction().getCallingConv() == CallingConv::GHC) - return; TargetFrameLowering::determineCalleeSaves(MF, SavedRegs, RS); const AArch64RegisterInfo *RegInfo = static_cast( diff --git a/lib/Target/AArch64/AArch64ISelLowering.cpp b/lib/Target/AArch64/AArch64ISelLowering.cpp index f5de1a7db25..64e74f8e9b4 100644 --- a/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -5526,7 +5526,7 @@ SDValue AArch64TargetLowering::LowerCallResult( /// Return true if the calling convention is one that we can guarantee TCO for. static bool canGuaranteeTCO(CallingConv::ID CC, bool GuaranteeTailCalls) { return (CC == CallingConv::Fast && GuaranteeTailCalls) || - CC == CallingConv::Tail || CC == CallingConv::SwiftTail || CC == CallingConv::GHC; + CC == CallingConv::Tail || CC == CallingConv::SwiftTail; } /// Return true if we might ever do TCO for calls with this calling convention. 
diff --git a/lib/Target/AArch64/AArch64RegisterInfo.cpp b/lib/Target/AArch64/AArch64RegisterInfo.cpp index d1b901e58d2..8b3386b2ae5 100644 --- a/lib/Target/AArch64/AArch64RegisterInfo.cpp +++ b/lib/Target/AArch64/AArch64RegisterInfo.cpp @@ -76,9 +76,7 @@ AArch64RegisterInfo::getCalleeSavedRegs(const MachineFunction *MF) const { assert(MF && "Invalid MachineFunction pointer."); if (MF->getFunction().getCallingConv() == CallingConv::GHC) - // GHC set of callee saved regs is empty as all those regs are - // used for passing STG regs around - return CSR_AArch64_NoRegs_SaveList; + return CSR_AArch64_NoRegs_LR_SaveList; if (MF->getFunction().getCallingConv() == CallingConv::AnyReg) return CSR_AArch64_AllRegs_SaveList; @@ -215,8 +213,10 @@ AArch64RegisterInfo::getCallPreservedMask(const MachineFunction &MF, CallingConv::ID CC) const { bool SCS = MF.getFunction().hasFnAttribute(Attribute::ShadowCallStack); if (CC == CallingConv::GHC) - // This is academic because all GHC calls are (supposed to be) tail calls - return SCS ? CSR_AArch64_NoRegs_SCS_RegMask : CSR_AArch64_NoRegs_RegMask; + // By default LLVM doesn't save any regs for GHC. + // This means we'll clobber LR on arm64 if we make non-tail calls (e.g. L2 syscall) + // CSR_AArch64_NoRegs_LR saves LR to fix this + return CSR_AArch64_NoRegs_LR_RegMask; if (CC == CallingConv::AnyReg) return SCS ? CSR_AArch64_AllRegs_SCS_RegMask : CSR_AArch64_AllRegs_RegMask; diff --git a/lib/Target/AArch64/GISel/AArch64CallLowering.cpp b/lib/Target/AArch64/GISel/AArch64CallLowering.cpp index 28b234b180f..8431bce064b 100644 --- a/lib/Target/AArch64/GISel/AArch64CallLowering.cpp +++ b/lib/Target/AArch64/GISel/AArch64CallLowering.cpp @@ -615,6 +615,7 @@ static bool mayTailCallThisCC(CallingConv::ID CC) { case CallingConv::SwiftTail: case CallingConv::Tail: case CallingConv::Fast: + case CallingConv::GHC: return true; default: return false;