From 51153424db735974af0ae64db50c7fc2c9559af0 Mon Sep 17 00:00:00 2001 From: Johannes Doerfert Date: Thu, 17 Jun 2021 11:23:20 -0500 Subject: [PATCH] [OpenMP] Unified entry point for SPMD & generic kernels in the device RTL In the spirit of TRegions [0], this patch provides a simpler and uniform interface for a kernel to set up the device runtime. The OMPIRBuilder is used for reuse in Flang. A custom state machine will be generated in the follow up patch. The "surplus" threads of the "master warp" will not exit early anymore so we need to use non-aligned barriers. The new runtime will not have an extra warp but also require these non-aligned barriers. [0] https://link.springer.com/chapter/10.1007/978-3-030-28596-8_11 This was in parts extracted from D59319. Reviewed By: ABataev, JonChesterfield Differential Revision: https://reviews.llvm.org/D101976 --- include/llvm/Frontend/OpenMP/OMPIRBuilder.h | 23 +++++++ include/llvm/Frontend/OpenMP/OMPKinds.def | 6 +- lib/Frontend/OpenMP/OMPIRBuilder.cpp | 65 +++++++++++++++++++ lib/Transforms/IPO/OpenMPOpt.cpp | 49 +++++--------- .../OpenMP/replace_globalization.ll | 44 +++++++++---- .../OpenMP/single_threaded_execution.ll | 21 ++++-- 6 files changed, 153 insertions(+), 55 deletions(-) diff --git a/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index 0a249b3e257..a92c3ba381c 100644 --- a/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -779,6 +779,29 @@ public: llvm::ConstantInt *Size, const llvm::Twine &Name = Twine("")); + /// The `omp target` interface + /// + /// For more information about the usage of this interface, + /// \see openmp/libomptarget/deviceRTLs/common/include/target.h + /// + ///{ + + /// Create a runtime call for kmpc_target_init + /// + /// \param Loc The insert and source location description. + /// \param IsSPMD Flag to indicate if the kernel is an SPMD kernel or not. + /// \param RequiresFullRuntime Indicate if a full device runtime is necessary. + InsertPointTy createTargetInit(const LocationDescription &Loc, bool IsSPMD, bool RequiresFullRuntime); + + /// Create a runtime call for kmpc_target_deinit + /// + /// \param Loc The insert and source location description. + /// \param IsSPMD Flag to indicate if the kernel is an SPMD kernel or not. + /// \param RequiresFullRuntime Indicate if a full device runtime is necessary. + void createTargetDeinit(const LocationDescription &Loc, bool IsSPMD, bool RequiresFullRuntime); + + ///} + /// Declarations for LLVM-IR types (simple, array, function and structure) are /// generated below. Their names are defined and used in OpenMPKinds.def. Here /// we provide the declarations, the initializeTypes function will provide the diff --git a/include/llvm/Frontend/OpenMP/OMPKinds.def b/include/llvm/Frontend/OpenMP/OMPKinds.def index 1804cfeef7b..2003f44e34e 100644 --- a/include/llvm/Frontend/OpenMP/OMPKinds.def +++ b/include/llvm/Frontend/OpenMP/OMPKinds.def @@ -409,10 +409,8 @@ __OMP_RTL(__kmpc_task_allow_completion_event, false, VoidPtr, IdentPtr, /* Int */ Int32, /* kmp_task_t */ VoidPtr) /// OpenMP Device runtime functions -__OMP_RTL(__kmpc_kernel_init, false, Void, Int32, Int16) -__OMP_RTL(__kmpc_kernel_deinit, false, Void, Int16) -__OMP_RTL(__kmpc_spmd_kernel_init, false, Void, Int32, Int16) -__OMP_RTL(__kmpc_spmd_kernel_deinit_v2, false, Void, Int16) +__OMP_RTL(__kmpc_target_init, false, Int32, IdentPtr, Int1, Int1, Int1) +__OMP_RTL(__kmpc_target_deinit, false, Void, IdentPtr, Int1, Int1) __OMP_RTL(__kmpc_kernel_prepare_parallel, false, Void, VoidPtr) __OMP_RTL(__kmpc_parallel_51, false, Void, IdentPtr, Int32, Int32, Int32, Int32, VoidPtr, VoidPtr, VoidPtrPtr, SizeTy) diff --git a/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/lib/Frontend/OpenMP/OMPIRBuilder.cpp index 1020de5f30e..60d71805c75 100644 --- a/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -20,6 +20,7 @@ #include "llvm/IR/DebugInfo.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Value.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Error.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" @@ -2191,6 +2192,70 @@ CallInst *OpenMPIRBuilder::createCachedThreadPrivate( return Builder.CreateCall(Fn, Args); } +OpenMPIRBuilder::InsertPointTy +OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD, bool RequiresFullRuntime) { + if (!updateToLocation(Loc)) + return Loc.IP; + + Constant *SrcLocStr = getOrCreateSrcLocStr(Loc); + Value *Ident = getOrCreateIdent(SrcLocStr); + ConstantInt *IsSPMDVal = ConstantInt::getBool(Int32->getContext(), IsSPMD); + ConstantInt *UseGenericStateMachine = + ConstantInt::getBool(Int32->getContext(), !IsSPMD); + ConstantInt *RequiresFullRuntimeVal = ConstantInt::getBool(Int32->getContext(), RequiresFullRuntime); + + Function *Fn = getOrCreateRuntimeFunctionPtr( + omp::RuntimeFunction::OMPRTL___kmpc_target_init); + + CallInst *ThreadKind = + Builder.CreateCall(Fn, {Ident, IsSPMDVal, UseGenericStateMachine, RequiresFullRuntimeVal}); + + Value *ExecUserCode = Builder.CreateICmpEQ( + ThreadKind, ConstantInt::get(ThreadKind->getType(), -1), "exec_user_code"); + + // ThreadKind = __kmpc_target_init(...) + // if (ThreadKind == -1) + // user_code + // else + // return; + + auto *UI = Builder.CreateUnreachable(); + BasicBlock *CheckBB = UI->getParent(); + BasicBlock *UserCodeEntryBB = CheckBB->splitBasicBlock(UI, "user_code.entry"); + + BasicBlock *WorkerExitBB = BasicBlock::Create( + CheckBB->getContext(), "worker.exit", CheckBB->getParent()); + Builder.SetInsertPoint(WorkerExitBB); + Builder.CreateRetVoid(); + + auto *CheckBBTI = CheckBB->getTerminator(); + Builder.SetInsertPoint(CheckBBTI); + Builder.CreateCondBr(ExecUserCode, UI->getParent(), WorkerExitBB); + + CheckBBTI->eraseFromParent(); + UI->eraseFromParent(); + + // Continue in the "user_code" block, see diagram above and in + // openmp/libomptarget/deviceRTLs/common/include/target.h . + return InsertPointTy(UserCodeEntryBB, UserCodeEntryBB->getFirstInsertionPt()); +} + +void OpenMPIRBuilder::createTargetDeinit(const LocationDescription &Loc, + bool IsSPMD, bool RequiresFullRuntime) { + if (!updateToLocation(Loc)) + return; + + Constant *SrcLocStr = getOrCreateSrcLocStr(Loc); + Value *Ident = getOrCreateIdent(SrcLocStr); + ConstantInt *IsSPMDVal = ConstantInt::getBool(Int32->getContext(), IsSPMD); + ConstantInt *RequiresFullRuntimeVal = ConstantInt::getBool(Int32->getContext(), RequiresFullRuntime); + + Function *Fn = getOrCreateRuntimeFunctionPtr( + omp::RuntimeFunction::OMPRTL___kmpc_target_deinit); + + Builder.CreateCall(Fn, {Ident, IsSPMDVal, RequiresFullRuntimeVal}); +} + std::string OpenMPIRBuilder::getNameWithSeparators(ArrayRef Parts, StringRef FirstSeparator, StringRef Separator) { diff --git a/lib/Transforms/IPO/OpenMPOpt.cpp b/lib/Transforms/IPO/OpenMPOpt.cpp index 0127f9da430..b1230b96dd6 100644 --- a/lib/Transforms/IPO/OpenMPOpt.cpp +++ b/lib/Transforms/IPO/OpenMPOpt.cpp @@ -26,9 +26,6 @@ #include "llvm/Frontend/OpenMP/OMPConstants.h" #include "llvm/Frontend/OpenMP/OMPIRBuilder.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/IntrinsicsAMDGPU.h" -#include "llvm/IR/IntrinsicsNVPTX.h" -#include "llvm/IR/PatternMatch.h" #include "llvm/InitializePasses.h" #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/IPO.h" @@ -37,7 +34,6 @@ #include "llvm/Transforms/Utils/CallGraphUpdater.h" #include "llvm/Transforms/Utils/CodeExtractor.h" -using namespace llvm::PatternMatch; using namespace llvm; using namespace omp; @@ -2341,10 +2337,12 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { AllCallSitesKnown)) SingleThreadedBBs.erase(&F->getEntryBlock()); - // Check if the edge into the successor block compares a thread-id function to - // a constant zero. - // TODO: Use AAValueSimplify to simplify and propogate constants. - // TODO: Check more than a single use for thread ID's. + auto &OMPInfoCache = static_cast(A.getInfoCache()); + auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init]; + + // Check if the edge into the successor block compares the __kmpc_target_init + // result with -1. If we are in non-SPMD-mode that signals only the main + // thread will execute the edge. auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) { if (!Edge || !Edge->isConditional()) return false; @@ -2355,31 +2353,20 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality()) return false; - // Temporarily match the pattern generated by clang for teams regions. - // TODO: Remove this once the new runtime is in place. - ConstantInt *One, *NegOne; - CmpInst::Predicate Pred; - auto &&m_ThreadID = m_Intrinsic(); - auto &&m_WarpSize = m_Intrinsic(); - auto &&m_BlockSize = m_Intrinsic(); - if (match(Cmp, m_Cmp(Pred, m_ThreadID, - m_And(m_Sub(m_BlockSize, m_ConstantInt(One)), - m_Xor(m_Sub(m_WarpSize, m_ConstantInt(One)), - m_ConstantInt(NegOne)))))) - if (One->isOne() && NegOne->isMinusOne() && - Pred == CmpInst::Predicate::ICMP_EQ) - return true; - ConstantInt *C = dyn_cast(Cmp->getOperand(1)); - if (!C || !C->isZero()) + if (!C) return false; - if (auto *II = dyn_cast(Cmp->getOperand(0))) - if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x) - return true; - if (auto *II = dyn_cast(Cmp->getOperand(0))) - if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x) - return true; + // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!) + if (C->isAllOnesValue()) { + auto *CB = dyn_cast(Cmp->getOperand(0)); + if (!CB || CB->getCalledFunction() != RFI.Declaration) + return false; + const int InitIsSPMDArgNo = 1; + auto *IsSPMDModeCI = + dyn_cast(CB->getOperand(InitIsSPMDArgNo)); + return IsSPMDModeCI && IsSPMDModeCI->isZero(); + } return false; }; @@ -2394,7 +2381,7 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) { for (auto PredBB = pred_begin(BB), PredEndBB = pred_end(BB); PredBB != PredEndBB; ++PredBB) { if (!IsInitialThreadOnly(dyn_cast((*PredBB)->getTerminator()), - BB)) + BB)) IsInitialThread &= SingleThreadedBBs.contains(*PredBB); } diff --git a/test/Transforms/OpenMP/replace_globalization.ll b/test/Transforms/OpenMP/replace_globalization.ll index cb96fc3832a..06224e6d406 100644 --- a/test/Transforms/OpenMP/replace_globalization.ll +++ b/test/Transforms/OpenMP/replace_globalization.ll @@ -3,10 +3,16 @@ target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64" target triple = "nvptx64" +%struct.ident_t = type { i32, i32, i32, i32, i8* } + @S = external local_unnamed_addr global i8* +@0 = private unnamed_addr constant [113 x i8] c";llvm/test/Transforms/OpenMP/custom_state_machines_remarks.c;__omp_offloading_2a_d80d3d_test_fallback_l11;11;1;;\00", align 1 +@1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 0, i8* getelementptr inbounds ([113 x i8], [113 x i8]* @0, i32 0, i32 0) }, align 8 ; CHECK-REMARKS: remark: replace_globalization.c:5:7: Replaced globalized variable with 16 bytes of shared memory ; CHECK-REMARKS: remark: replace_globalization.c:5:14: Replaced globalized variable with 4 bytes of shared memory +; CHECK-REMARKS-NOT: 6 bytes + ; CHECK: [[SHARED_X:@.+]] = internal addrspace(3) global [16 x i8] undef ; CHECK: [[SHARED_Y:@.+]] = internal addrspace(3) global [4 x i8] undef @@ -25,14 +31,15 @@ entry: define void @bar() { call void @baz() call void @qux() + call void @negative_qux_spmd() ret void } ; CHECK: call void @use.internalized(i8* nofree writeonly addrspacecast (i8 addrspace(3)* getelementptr inbounds ([16 x i8], [16 x i8] addrspace(3)* [[SHARED_X]], i32 0, i32 0) to i8*)) define internal void @baz() { entry: - %tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x() - %cmp = icmp eq i32 %tid, 0 + %call = call i32 @__kmpc_target_init(%struct.ident_t* nonnull @1, i1 false, i1 false, i1 true) + %cmp = icmp eq i32 %call, -1 br i1 %cmp, label %master, label %exit master: %x = call i8* @__kmpc_alloc_shared(i64 16), !dbg !11 @@ -48,20 +55,30 @@ exit: ; CHECK: call void @use.internalized(i8* nofree writeonly addrspacecast (i8 addrspace(3)* getelementptr inbounds ([4 x i8], [4 x i8] addrspace(3)* [[SHARED_Y]], i32 0, i32 0) to i8*)) define internal void @qux() { entry: - %tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x() - %ntid = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x() - %warpsize = call i32 @llvm.nvvm.read.ptx.sreg.warpsize() - %0 = sub nuw i32 %warpsize, 1 - %1 = sub nuw i32 %ntid, 1 - %2 = xor i32 %0, -1 - %master_tid = and i32 %1, %2 - %3 = icmp eq i32 %tid, %master_tid - br i1 %3, label %master, label %exit + %call = call i32 @__kmpc_target_init(%struct.ident_t* nonnull @1, i1 false, i1 true, i1 true) + %0 = icmp eq i32 %call, -1 + br i1 %0, label %master, label %exit master: %y = call i8* @__kmpc_alloc_shared(i64 4), !dbg !12 %y_on_stack = bitcast i8* %y to [4 x i32]* - %4 = bitcast [4 x i32]* %y_on_stack to i8* - call void @use(i8* %4) + %1 = bitcast [4 x i32]* %y_on_stack to i8* + call void @use(i8* %1) + call void @__kmpc_free_shared(i8* %y) + br label %exit +exit: + ret void +} + +define internal void @negative_qux_spmd() { +entry: + %call = call i32 @__kmpc_target_init(%struct.ident_t* nonnull @1, i1 true, i1 true, i1 true) + %0 = icmp eq i32 %call, -1 + br i1 %0, label %master, label %exit +master: + %y = call i8* @__kmpc_alloc_shared(i64 6), !dbg !12 + %y_on_stack = bitcast i8* %y to [6 x i32]* + %1 = bitcast [6 x i32]* %y_on_stack to i8* + call void @use(i8* %1) call void @__kmpc_free_shared(i8* %y) br label %exit exit: @@ -85,6 +102,7 @@ declare i32 @llvm.nvvm.read.ptx.sreg.ntid.x() declare i32 @llvm.nvvm.read.ptx.sreg.warpsize() +declare i32 @__kmpc_target_init(%struct.ident_t*, i1, i1, i1) !llvm.dbg.cu = !{!0} !llvm.module.flags = !{!3, !4, !5, !6} diff --git a/test/Transforms/OpenMP/single_threaded_execution.ll b/test/Transforms/OpenMP/single_threaded_execution.ll index 5fff563d364..ae56477902b 100644 --- a/test/Transforms/OpenMP/single_threaded_execution.ll +++ b/test/Transforms/OpenMP/single_threaded_execution.ll @@ -3,8 +3,13 @@ ; REQUIRES: asserts ; ModuleID = 'single_threaded_exeuction.c' -define weak void @kernel() { - call void @__kmpc_kernel_init(i32 512, i16 1) +%struct.ident_t = type { i32, i32, i32, i32, i8* } + +@0 = private unnamed_addr constant [1 x i8] c"\00", align 1 +@1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 0, i8* getelementptr inbounds ([1 x i8], [1 x i8]* @0, i32 0, i32 0) }, align 8 + +define void @kernel() { + call void @__kmpc_kernel_prepare_parallel(i8* null) call void @nvptx() call void @amdgcn() ret void @@ -19,8 +24,8 @@ define weak void @kernel() { ; Function Attrs: noinline define internal void @nvptx() { entry: - %call = call i32 @llvm.nvvm.read.ptx.sreg.tid.x() - %cmp = icmp eq i32 %call, 0 + %call = call i32 @__kmpc_target_init(%struct.ident_t* nonnull @1, i1 false, i1 false, i1 false) + %cmp = icmp eq i32 %call, -1 br i1 %cmp, label %if.then, label %if.end if.then: @@ -40,8 +45,8 @@ if.end: ; Function Attrs: noinline define internal void @amdgcn() { entry: - %call = call i32 @llvm.amdgcn.workitem.id.x() - %cmp = icmp eq i32 %call, 0 + %call = call i32 @__kmpc_target_init(%struct.ident_t* nonnull @1, i1 false, i1 true, i1 true) + %cmp = icmp eq i32 %call, -1 br i1 %cmp, label %if.then, label %if.end if.then: @@ -87,7 +92,9 @@ declare i32 @llvm.nvvm.read.ptx.sreg.tid.x() declare i32 @llvm.amdgcn.workitem.id.x() -declare void @__kmpc_kernel_init(i32, i16) +declare void @__kmpc_kernel_prepare_parallel(i8*) + +declare i32 @__kmpc_target_init(%struct.ident_t*, i1, i1, i1) attributes #0 = { cold noinline }