diff --git a/lib/Target/NVPTX/CMakeLists.txt b/lib/Target/NVPTX/CMakeLists.txt index 4e35b181129..6b196ef6150 100644 --- a/lib/Target/NVPTX/CMakeLists.txt +++ b/lib/Target/NVPTX/CMakeLists.txt @@ -29,6 +29,7 @@ set(NVPTXCodeGen_sources NVPTXMCExpr.cpp NVPTXReplaceImageHandles.cpp NVPTXImageOptimizer.cpp + NVPTXLowerStructArgs.cpp ) add_llvm_target(NVPTXCodeGen ${NVPTXCodeGen_sources}) diff --git a/lib/Target/NVPTX/NVPTX.h b/lib/Target/NVPTX/NVPTX.h index eaeecdbca81..9f315f6dded 100644 --- a/lib/Target/NVPTX/NVPTX.h +++ b/lib/Target/NVPTX/NVPTX.h @@ -69,6 +69,7 @@ ModulePass *createNVVMReflectPass(const StringMap& Mapping); MachineFunctionPass *createNVPTXPrologEpilogPass(); MachineFunctionPass *createNVPTXReplaceImageHandlesPass(); FunctionPass *createNVPTXImageOptimizerPass(); +FunctionPass *createNVPTXLowerStructArgsPass(); bool isImageOrSamplerVal(const Value *, const Module *); diff --git a/lib/Target/NVPTX/NVPTXLowerStructArgs.cpp b/lib/Target/NVPTX/NVPTXLowerStructArgs.cpp new file mode 100644 index 00000000000..6c1cc80a9fd --- /dev/null +++ b/lib/Target/NVPTX/NVPTXLowerStructArgs.cpp @@ -0,0 +1,150 @@ +//===-- NVPTXLowerStructArgs.cpp - Copy struct args to local memory =====--===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Copy struct args to local memory. This is needed for kernel functions only. +// This is a preparation for handling cases like +// +// kernel void foo(struct A arg, ...) +// { +// struct A *p = &arg; +// ... +// ... = p->filed1 ... (this is no generic address for .param) +// p->filed2 = ... (this is no write access to .param) +// } +// +//===----------------------------------------------------------------------===// + +#include "NVPTX.h" +#include "NVPTXUtilities.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/Pass.h" + +using namespace llvm; + +namespace llvm { +void initializeNVPTXLowerStructArgsPass(PassRegistry &); +} + +class LLVM_LIBRARY_VISIBILITY NVPTXLowerStructArgs : public FunctionPass { + bool runOnFunction(Function &F) override; + + void handleStructPtrArgs(Function &); + void handleParam(Argument *); + +public: + static char ID; // Pass identification, replacement for typeid + NVPTXLowerStructArgs() : FunctionPass(ID) {} + const char *getPassName() const override { + return "Copy structure (byval *) arguments to stack"; + } +}; + +char NVPTXLowerStructArgs::ID = 1; + +INITIALIZE_PASS(NVPTXLowerStructArgs, "nvptx-lower-struct-args", + "Lower structure arguments (NVPTX)", false, false) + +void NVPTXLowerStructArgs::handleParam(Argument *Arg) { + Function *Func = Arg->getParent(); + Instruction *FirstInst = &(Func->getEntryBlock().front()); + const PointerType *PType = dyn_cast(Arg->getType()); + + assert(PType && "Expecting pointer type in handleParam"); + + const Type *StructType = PType->getElementType(); + + AllocaInst *AllocA = + new AllocaInst((Type *)StructType, Arg->getName(), FirstInst); + + /* Set the alignment to alignment of the byval parameter. This is because, + * later load/stores assume that alignment, and we are going to replace + * the use of the byval parameter with this alloca instruction. + */ + AllocA->setAlignment(Func->getParamAlignment(Arg->getArgNo() + 1)); + + Arg->replaceAllUsesWith(AllocA); + + // Get the cvt.gen.to.param intrinsic + const Type *CvtTypes[2] = { + Type::getInt8PtrTy(Func->getParent()->getContext(), ADDRESS_SPACE_PARAM), + Type::getInt8PtrTy(Func->getParent()->getContext(), ADDRESS_SPACE_GENERIC) + }; + Function *CvtFunc = (Function *)Intrinsic::getDeclaration( + Func->getParent(), Intrinsic::nvvm_ptr_gen_to_param, + ArrayRef((Type **)CvtTypes, 2)); + std::vector BC1; + BC1.push_back( + new BitCastInst(Arg, Type::getInt8PtrTy(Func->getParent()->getContext(), + ADDRESS_SPACE_GENERIC), + Arg->getName(), FirstInst)); + CallInst *CallCVT = CallInst::Create(CvtFunc, ArrayRef(BC1), + "cvt_to_param", FirstInst); + + BitCastInst *BitCast = new BitCastInst( + CallCVT, PointerType::get((Type *)StructType, ADDRESS_SPACE_PARAM), + Arg->getName(), FirstInst); + LoadInst *LI = new LoadInst(BitCast, Arg->getName(), FirstInst); + new StoreInst(LI, AllocA, FirstInst); +} + +/// ============================================================================= +/// If the function had a struct ptr arg, say foo(%struct.x *byval %d), then +/// add the following instructions to the first basic block : +/// +/// %temp = alloca %struct.x, align 8 +/// %tt1 = bitcast %struct.x * %d to i8 * +/// %tt2 = llvm.nvvm.cvt.gen.to.param %tt2 +/// %tempd = bitcast i8 addrspace(101) * to %struct.x addrspace(101) * +/// %tv = load %struct.x addrspace(101) * %tempd +/// store %struct.x %tv, %struct.x * %temp, align 8 +/// +/// The above code allocates some space in the stack and copies the incoming +/// struct from param space to local space. +/// Then replace all occurences of %d by %temp. +/// ============================================================================= +void NVPTXLowerStructArgs::handleStructPtrArgs(Function &F) { + const AttributeSet &PAL = F.getAttributes(); + + unsigned Idx = 1; + + for (Argument &Arg : F.args()) { + const Type *Ty = Arg.getType(); + + const PointerType *PTy = dyn_cast(Ty); + + if (PTy) { + if (PAL.hasAttribute(Idx, Attribute::ByVal)) { + // cout << "Has struct ptr args" << std::endl; + handleParam(&Arg); + } + } + Idx++; + } +} + +/// ============================================================================= +/// Main function for this pass. +/// ============================================================================= +bool NVPTXLowerStructArgs::runOnFunction(Function &F) { + // Skip non-kernels. See the comments at the top of this file. + if (!isKernelFunction(F)) + return false; + + handleStructPtrArgs(F); + + return true; +} + +FunctionPass *llvm::createNVPTXLowerStructArgsPass() { + return new NVPTXLowerStructArgs(); +} diff --git a/lib/Target/NVPTX/NVPTXTargetMachine.cpp b/lib/Target/NVPTX/NVPTXTargetMachine.cpp index 069a1b9966f..75df2c5bcae 100644 --- a/lib/Target/NVPTX/NVPTXTargetMachine.cpp +++ b/lib/Target/NVPTX/NVPTXTargetMachine.cpp @@ -50,6 +50,7 @@ void initializeNVVMReflectPass(PassRegistry&); void initializeGenericToNVVMPass(PassRegistry&); void initializeNVPTXAssignValidGlobalNamesPass(PassRegistry&); void initializeNVPTXFavorNonGenericAddrSpacesPass(PassRegistry &); +void initializeNVPTXLowerStructArgsPass(PassRegistry &); } extern "C" void LLVMInitializeNVPTXTarget() { @@ -64,6 +65,7 @@ extern "C" void LLVMInitializeNVPTXTarget() { initializeNVPTXAssignValidGlobalNamesPass(*PassRegistry::getPassRegistry()); initializeNVPTXFavorNonGenericAddrSpacesPass( *PassRegistry::getPassRegistry()); + initializeNVPTXLowerStructArgsPass(*PassRegistry::getPassRegistry()); } NVPTXTargetMachine::NVPTXTargetMachine(const Target &T, StringRef TT, diff --git a/test/CodeGen/NVPTX/bug21465.ll b/test/CodeGen/NVPTX/bug21465.ll new file mode 100644 index 00000000000..157b28c67e5 --- /dev/null +++ b/test/CodeGen/NVPTX/bug21465.ll @@ -0,0 +1,24 @@ +; RUN: opt < %s -nvptx-lower-struct-args -S | FileCheck %s + +target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64" +target triple = "nvptx64-unknown-unknown" + +%struct.S = type { i32, i32 } + +; Function Attrs: nounwind +define void @_Z11TakesStruct1SPi(%struct.S* byval nocapture readonly %input, i32* nocapture %output) #0 { +entry: +; CHECK-LABEL @_Z22TakesStruct1SPi +; CHECK: bitcast %struct.S* %input to i8* +; CHECK: call i8 addrspace(101)* @llvm.nvvm.ptr.gen.to.param.p101i8.p0i8 + %b = getelementptr inbounds %struct.S* %input, i64 0, i32 1 + %0 = load i32* %b, align 4 + store i32 %0, i32* %output, align 4 + ret void +} + +attributes #0 = { nounwind "less-precise-fpmad"="false" "no-frame-pointer-elim"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "stack-protector-buffer-size"="8" "unsafe-fp-math"="false" "use-soft-float"="false" } + +!nvvm.annotations = !{!0} + +!0 = metadata !{void (%struct.S*, i32*)* @_Z11TakesStruct1SPi, metadata !"kernel", i32 1}