
[OpenMP] Unified entry point for SPMD & generic kernels in the device RTL

In the spirit of TRegions [0], this patch provides a simpler and uniform
interface for a kernel to set up the device runtime. The OMPIRBuilder is
used so the logic can be reused from Flang. A custom state machine will be
generated in a follow-up patch.

The "surplus" threads of the "master warp" will not exit early anymore
so we need to use non-aligned barriers. The new runtime will not have an
extra warp but also require these non-aligned barriers.

[0] https://link.springer.com/chapter/10.1007/978-3-030-28596-8_11

This was in part extracted from D59319.

Reviewed By: ABataev, JonChesterfield

Differential Revision: https://reviews.llvm.org/D101976
Johannes Doerfert 2021-06-17 11:23:20 -05:00
parent 522a3bbbbc
commit 51153424db
6 changed files with 153 additions and 55 deletions
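
To make the new builder interface concrete, here is a minimal sketch of how a
frontend could bracket a kernel with the two calls introduced by this patch.
The helper buildKernelSkeleton, the module/function setup, and the chosen flag
values are illustrative assumptions; only createTargetInit/createTargetDeinit
and their parameters come from the patch itself.

// Illustrative sketch only: buildKernelSkeleton and its setup are assumed,
// not part of this commit. KernelFn is expected to return void and to
// already contain an entry block.
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"

using namespace llvm;

static void buildKernelSkeleton(Module &M, Function &KernelFn) {
  OpenMPIRBuilder OMPBuilder(M);
  OMPBuilder.initialize();

  IRBuilder<> Builder(&KernelFn.getEntryBlock());
  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DebugLoc()});

  // Example flags: a generic (non-SPMD) kernel that needs the full runtime.
  bool IsSPMD = false;
  bool RequiresFullRuntime = true;

  // Emits the __kmpc_target_init call and the "result == -1" check, and
  // returns an insertion point inside the "user_code.entry" block.
  OpenMPIRBuilder::InsertPointTy BodyIP =
      OMPBuilder.createTargetInit(Loc, IsSPMD, RequiresFullRuntime);
  Builder.restoreIP(BodyIP);

  // ... emit the actual kernel body here ...

  // Emits the matching __kmpc_target_deinit call before the kernel returns.
  OMPBuilder.createTargetDeinit({Builder.saveIP(), DebugLoc()}, IsSPMD,
                                RequiresFullRuntime);
  Builder.CreateRetVoid();
}

The control flow this produces (the init call, the comparison of its result
against -1, and the branch to either the user code or the worker exit) is the
pattern the OpenMPOpt changes below now look for instead of raw thread-id
intrinsics.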

include/llvm/Frontend/OpenMP/OMPIRBuilder.h

@@ -779,6 +779,29 @@ public:
llvm::ConstantInt *Size,
const llvm::Twine &Name = Twine(""));
/// The `omp target` interface
///
/// For more information about the usage of this interface,
/// \see openmp/libomptarget/deviceRTLs/common/include/target.h
///
///{
/// Create a runtime call for kmpc_target_init
///
/// \param Loc The insert and source location description.
/// \param IsSPMD Flag to indicate if the kernel is an SPMD kernel or not.
/// \param RequiresFullRuntime Indicate if a full device runtime is necessary.
InsertPointTy createTargetInit(const LocationDescription &Loc, bool IsSPMD, bool RequiresFullRuntime);
/// Create a runtime call for kmpc_target_deinit
///
/// \param Loc The insert and source location description.
/// \param IsSPMD Flag to indicate if the kernel is an SPMD kernel or not.
/// \param RequiresFullRuntime Indicate if a full device runtime is necessary.
void createTargetDeinit(const LocationDescription &Loc, bool IsSPMD, bool RequiresFullRuntime);
///}
/// Declarations for LLVM-IR types (simple, array, function and structure) are
/// generated below. Their names are defined and used in OpenMPKinds.def. Here
/// we provide the declarations, the initializeTypes function will provide the

include/llvm/Frontend/OpenMP/OMPKinds.def

@@ -409,10 +409,8 @@ __OMP_RTL(__kmpc_task_allow_completion_event, false, VoidPtr, IdentPtr,
/* Int */ Int32, /* kmp_task_t */ VoidPtr)
/// OpenMP Device runtime functions
__OMP_RTL(__kmpc_kernel_init, false, Void, Int32, Int16)
__OMP_RTL(__kmpc_kernel_deinit, false, Void, Int16)
__OMP_RTL(__kmpc_spmd_kernel_init, false, Void, Int32, Int16)
__OMP_RTL(__kmpc_spmd_kernel_deinit_v2, false, Void, Int16)
__OMP_RTL(__kmpc_target_init, false, Int32, IdentPtr, Int1, Int1, Int1)
__OMP_RTL(__kmpc_target_deinit, false, Void, IdentPtr, Int1, Int1)
__OMP_RTL(__kmpc_kernel_prepare_parallel, false, Void, VoidPtr)
__OMP_RTL(__kmpc_parallel_51, false, Void, IdentPtr, Int32, Int32, Int32, Int32,
VoidPtr, VoidPtr, VoidPtrPtr, SizeTy)

lib/Frontend/OpenMP/OMPIRBuilder.cpp

@@ -20,6 +20,7 @@
#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Error.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
@@ -2191,6 +2192,70 @@ CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
return Builder.CreateCall(Fn, Args);
}
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD, bool RequiresFullRuntime) {
if (!updateToLocation(Loc))
return Loc.IP;
Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
Value *Ident = getOrCreateIdent(SrcLocStr);
ConstantInt *IsSPMDVal = ConstantInt::getBool(Int32->getContext(), IsSPMD);
ConstantInt *UseGenericStateMachine =
ConstantInt::getBool(Int32->getContext(), !IsSPMD);
ConstantInt *RequiresFullRuntimeVal = ConstantInt::getBool(Int32->getContext(), RequiresFullRuntime);
Function *Fn = getOrCreateRuntimeFunctionPtr(
omp::RuntimeFunction::OMPRTL___kmpc_target_init);
CallInst *ThreadKind =
Builder.CreateCall(Fn, {Ident, IsSPMDVal, UseGenericStateMachine, RequiresFullRuntimeVal});
Value *ExecUserCode = Builder.CreateICmpEQ(
ThreadKind, ConstantInt::get(ThreadKind->getType(), -1), "exec_user_code");
// ThreadKind = __kmpc_target_init(...)
// if (ThreadKind == -1)
// user_code
// else
// return;
auto *UI = Builder.CreateUnreachable();
BasicBlock *CheckBB = UI->getParent();
BasicBlock *UserCodeEntryBB = CheckBB->splitBasicBlock(UI, "user_code.entry");
BasicBlock *WorkerExitBB = BasicBlock::Create(
CheckBB->getContext(), "worker.exit", CheckBB->getParent());
Builder.SetInsertPoint(WorkerExitBB);
Builder.CreateRetVoid();
auto *CheckBBTI = CheckBB->getTerminator();
Builder.SetInsertPoint(CheckBBTI);
Builder.CreateCondBr(ExecUserCode, UI->getParent(), WorkerExitBB);
CheckBBTI->eraseFromParent();
UI->eraseFromParent();
// Continue in the "user_code" block, see diagram above and in
// openmp/libomptarget/deviceRTLs/common/include/target.h .
return InsertPointTy(UserCodeEntryBB, UserCodeEntryBB->getFirstInsertionPt());
}
void OpenMPIRBuilder::createTargetDeinit(const LocationDescription &Loc,
bool IsSPMD, bool RequiresFullRuntime) {
if (!updateToLocation(Loc))
return;
Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
Value *Ident = getOrCreateIdent(SrcLocStr);
ConstantInt *IsSPMDVal = ConstantInt::getBool(Int32->getContext(), IsSPMD);
ConstantInt *RequiresFullRuntimeVal = ConstantInt::getBool(Int32->getContext(), RequiresFullRuntime);
Function *Fn = getOrCreateRuntimeFunctionPtr(
omp::RuntimeFunction::OMPRTL___kmpc_target_deinit);
Builder.CreateCall(Fn, {Ident, IsSPMDVal, RequiresFullRuntimeVal});
}
std::string OpenMPIRBuilder::getNameWithSeparators(ArrayRef<StringRef> Parts,
StringRef FirstSeparator,
StringRef Separator) {

lib/Transforms/IPO/OpenMPOpt.cpp

@@ -26,9 +26,6 @@
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/IPO.h"
@@ -37,7 +34,6 @@
#include "llvm/Transforms/Utils/CallGraphUpdater.h"
#include "llvm/Transforms/Utils/CodeExtractor.h"
using namespace llvm::PatternMatch;
using namespace llvm;
using namespace omp;
@@ -2341,10 +2337,12 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
AllCallSitesKnown))
SingleThreadedBBs.erase(&F->getEntryBlock());
// Check if the edge into the successor block compares a thread-id function to
// a constant zero.
// TODO: Use AAValueSimplify to simplify and propagate constants.
// TODO: Check more than a single use for thread IDs.
auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
// Check if the edge into the successor block compares the __kmpc_target_init
// result with -1. If we are in non-SPMD-mode that signals only the main
// thread will execute the edge.
auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) {
if (!Edge || !Edge->isConditional())
return false;
@@ -2355,31 +2353,20 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
return false;
// Temporarily match the pattern generated by clang for teams regions.
// TODO: Remove this once the new runtime is in place.
ConstantInt *One, *NegOne;
CmpInst::Predicate Pred;
auto &&m_ThreadID = m_Intrinsic<Intrinsic::nvvm_read_ptx_sreg_tid_x>();
auto &&m_WarpSize = m_Intrinsic<Intrinsic::nvvm_read_ptx_sreg_warpsize>();
auto &&m_BlockSize = m_Intrinsic<Intrinsic::nvvm_read_ptx_sreg_ntid_x>();
if (match(Cmp, m_Cmp(Pred, m_ThreadID,
m_And(m_Sub(m_BlockSize, m_ConstantInt(One)),
m_Xor(m_Sub(m_WarpSize, m_ConstantInt(One)),
m_ConstantInt(NegOne))))))
if (One->isOne() && NegOne->isMinusOne() &&
Pred == CmpInst::Predicate::ICMP_EQ)
return true;
ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
if (!C || !C->isZero())
if (!C)
return false;
if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
return true;
if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
return true;
// Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
if (C->isAllOnesValue()) {
auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
if (!CB || CB->getCalledFunction() != RFI.Declaration)
return false;
const int InitIsSPMDArgNo = 1;
auto *IsSPMDModeCI =
dyn_cast<ConstantInt>(CB->getOperand(InitIsSPMDArgNo));
return IsSPMDModeCI && IsSPMDModeCI->isZero();
}
return false;
};

test/Transforms/OpenMP/replace_globalization.ll

@@ -3,10 +3,16 @@
target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"
target triple = "nvptx64"
%struct.ident_t = type { i32, i32, i32, i32, i8* }
@S = external local_unnamed_addr global i8*
@0 = private unnamed_addr constant [113 x i8] c";llvm/test/Transforms/OpenMP/custom_state_machines_remarks.c;__omp_offloading_2a_d80d3d_test_fallback_l11;11;1;;\00", align 1
@1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 0, i8* getelementptr inbounds ([113 x i8], [113 x i8]* @0, i32 0, i32 0) }, align 8
; CHECK-REMARKS: remark: replace_globalization.c:5:7: Replaced globalized variable with 16 bytes of shared memory
; CHECK-REMARKS: remark: replace_globalization.c:5:14: Replaced globalized variable with 4 bytes of shared memory
; CHECK-REMARKS-NOT: 6 bytes
; CHECK: [[SHARED_X:@.+]] = internal addrspace(3) global [16 x i8] undef
; CHECK: [[SHARED_Y:@.+]] = internal addrspace(3) global [4 x i8] undef
@@ -25,14 +31,15 @@ entry:
define void @bar() {
call void @baz()
call void @qux()
call void @negative_qux_spmd()
ret void
}
; CHECK: call void @use.internalized(i8* nofree writeonly addrspacecast (i8 addrspace(3)* getelementptr inbounds ([16 x i8], [16 x i8] addrspace(3)* [[SHARED_X]], i32 0, i32 0) to i8*))
define internal void @baz() {
entry:
%tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%cmp = icmp eq i32 %tid, 0
%call = call i32 @__kmpc_target_init(%struct.ident_t* nonnull @1, i1 false, i1 false, i1 true)
%cmp = icmp eq i32 %call, -1
br i1 %cmp, label %master, label %exit
master:
%x = call i8* @__kmpc_alloc_shared(i64 16), !dbg !11
@@ -48,20 +55,30 @@ exit:
; CHECK: call void @use.internalized(i8* nofree writeonly addrspacecast (i8 addrspace(3)* getelementptr inbounds ([4 x i8], [4 x i8] addrspace(3)* [[SHARED_Y]], i32 0, i32 0) to i8*))
define internal void @qux() {
entry:
%tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%ntid = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
%warpsize = call i32 @llvm.nvvm.read.ptx.sreg.warpsize()
%0 = sub nuw i32 %warpsize, 1
%1 = sub nuw i32 %ntid, 1
%2 = xor i32 %0, -1
%master_tid = and i32 %1, %2
%3 = icmp eq i32 %tid, %master_tid
br i1 %3, label %master, label %exit
%call = call i32 @__kmpc_target_init(%struct.ident_t* nonnull @1, i1 false, i1 true, i1 true)
%0 = icmp eq i32 %call, -1
br i1 %0, label %master, label %exit
master:
%y = call i8* @__kmpc_alloc_shared(i64 4), !dbg !12
%y_on_stack = bitcast i8* %y to [4 x i32]*
%4 = bitcast [4 x i32]* %y_on_stack to i8*
call void @use(i8* %4)
%1 = bitcast [4 x i32]* %y_on_stack to i8*
call void @use(i8* %1)
call void @__kmpc_free_shared(i8* %y)
br label %exit
exit:
ret void
}
define internal void @negative_qux_spmd() {
entry:
%call = call i32 @__kmpc_target_init(%struct.ident_t* nonnull @1, i1 true, i1 true, i1 true)
%0 = icmp eq i32 %call, -1
br i1 %0, label %master, label %exit
master:
%y = call i8* @__kmpc_alloc_shared(i64 6), !dbg !12
%y_on_stack = bitcast i8* %y to [6 x i32]*
%1 = bitcast [6 x i32]* %y_on_stack to i8*
call void @use(i8* %1)
call void @__kmpc_free_shared(i8* %y)
br label %exit
exit:
@@ -85,6 +102,7 @@ declare i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
declare i32 @llvm.nvvm.read.ptx.sreg.warpsize()
declare i32 @__kmpc_target_init(%struct.ident_t*, i1, i1, i1)
!llvm.dbg.cu = !{!0}
!llvm.module.flags = !{!3, !4, !5, !6}

test/Transforms/OpenMP/single_threaded_execution.ll

@@ -3,8 +3,13 @@
; REQUIRES: asserts
; ModuleID = 'single_threaded_exeuction.c'
define weak void @kernel() {
call void @__kmpc_kernel_init(i32 512, i16 1)
%struct.ident_t = type { i32, i32, i32, i32, i8* }
@0 = private unnamed_addr constant [1 x i8] c"\00", align 1
@1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 0, i8* getelementptr inbounds ([1 x i8], [1 x i8]* @0, i32 0, i32 0) }, align 8
define void @kernel() {
call void @__kmpc_kernel_prepare_parallel(i8* null)
call void @nvptx()
call void @amdgcn()
ret void
@@ -19,8 +24,8 @@ define weak void @kernel() {
; Function Attrs: noinline
define internal void @nvptx() {
entry:
%call = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%cmp = icmp eq i32 %call, 0
%call = call i32 @__kmpc_target_init(%struct.ident_t* nonnull @1, i1 false, i1 false, i1 false)
%cmp = icmp eq i32 %call, -1
br i1 %cmp, label %if.then, label %if.end
if.then:
@@ -40,8 +45,8 @@ if.end:
; Function Attrs: noinline
define internal void @amdgcn() {
entry:
%call = call i32 @llvm.amdgcn.workitem.id.x()
%cmp = icmp eq i32 %call, 0
%call = call i32 @__kmpc_target_init(%struct.ident_t* nonnull @1, i1 false, i1 true, i1 true)
%cmp = icmp eq i32 %call, -1
br i1 %cmp, label %if.then, label %if.end
if.then:
@@ -87,7 +92,9 @@ declare i32 @llvm.nvvm.read.ptx.sreg.tid.x()
declare i32 @llvm.amdgcn.workitem.id.x()
declare void @__kmpc_kernel_init(i32, i16)
declare void @__kmpc_kernel_prepare_parallel(i8*)
declare i32 @__kmpc_target_init(%struct.ident_t*, i1, i1, i1)
attributes #0 = { cold noinline }