mirror of
https://github.com/RPCS3/llvm-mirror.git
synced 2024-11-22 02:33:06 +01:00
[OpenMP] Unified entry point for SPMD & generic kernels in the device RTL
In the spirit of TRegions [0], this patch provides a simpler and uniform interface for a kernel to set up the device runtime. The OMPIRBuilder is used for reuse in Flang. A custom state machine will be generated in the follow up patch. The "surplus" threads of the "master warp" will not exit early anymore so we need to use non-aligned barriers. The new runtime will not have an extra warp but also require these non-aligned barriers. [0] https://link.springer.com/chapter/10.1007/978-3-030-28596-8_11 This was in parts extracted from D59319. Reviewed By: ABataev, JonChesterfield Differential Revision: https://reviews.llvm.org/D101976
This commit is contained in:
parent
c2a2cf0480
commit
63e4735bba
@ -779,6 +779,29 @@ public:
|
||||
llvm::ConstantInt *Size,
|
||||
const llvm::Twine &Name = Twine(""));
|
||||
|
||||
/// The `omp target` interface
|
||||
///
|
||||
/// For more information about the usage of this interface,
|
||||
/// \see openmp/libomptarget/deviceRTLs/common/include/target.h
|
||||
///
|
||||
///{
|
||||
|
||||
/// Create a runtime call for kmpc_target_init
|
||||
///
|
||||
/// \param Loc The insert and source location description.
|
||||
/// \param IsSPMD Flag to indicate if the kernel is an SPMD kernel or not.
|
||||
/// \param RequiresFullRuntime Indicate if a full device runtime is necessary.
|
||||
InsertPointTy createTargetInit(const LocationDescription &Loc, bool IsSPMD, bool RequiresFullRuntime);
|
||||
|
||||
/// Create a runtime call for kmpc_target_deinit
|
||||
///
|
||||
/// \param Loc The insert and source location description.
|
||||
/// \param IsSPMD Flag to indicate if the kernel is an SPMD kernel or not.
|
||||
/// \param RequiresFullRuntime Indicate if a full device runtime is necessary.
|
||||
void createTargetDeinit(const LocationDescription &Loc, bool IsSPMD, bool RequiresFullRuntime);
|
||||
|
||||
///}
|
||||
|
||||
/// Declarations for LLVM-IR types (simple, array, function and structure) are
|
||||
/// generated below. Their names are defined and used in OpenMPKinds.def. Here
|
||||
/// we provide the declarations, the initializeTypes function will provide the
|
||||
|
@ -409,10 +409,8 @@ __OMP_RTL(__kmpc_task_allow_completion_event, false, VoidPtr, IdentPtr,
|
||||
/* Int */ Int32, /* kmp_task_t */ VoidPtr)
|
||||
|
||||
/// OpenMP Device runtime functions
|
||||
__OMP_RTL(__kmpc_kernel_init, false, Void, Int32, Int16)
|
||||
__OMP_RTL(__kmpc_kernel_deinit, false, Void, Int16)
|
||||
__OMP_RTL(__kmpc_spmd_kernel_init, false, Void, Int32, Int16)
|
||||
__OMP_RTL(__kmpc_spmd_kernel_deinit_v2, false, Void, Int16)
|
||||
__OMP_RTL(__kmpc_target_init, false, Int32, IdentPtr, Int1, Int1, Int1)
|
||||
__OMP_RTL(__kmpc_target_deinit, false, Void, IdentPtr, Int1, Int1)
|
||||
__OMP_RTL(__kmpc_kernel_prepare_parallel, false, Void, VoidPtr)
|
||||
__OMP_RTL(__kmpc_parallel_51, false, Void, IdentPtr, Int32, Int32, Int32, Int32,
|
||||
VoidPtr, VoidPtr, VoidPtrPtr, SizeTy)
|
||||
|
@ -20,6 +20,7 @@
|
||||
#include "llvm/IR/DebugInfo.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/MDBuilder.h"
|
||||
#include "llvm/IR/Value.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/Error.h"
|
||||
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
|
||||
@ -2191,6 +2192,70 @@ CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
|
||||
return Builder.CreateCall(Fn, Args);
|
||||
}
|
||||
|
||||
OpenMPIRBuilder::InsertPointTy
|
||||
OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD, bool RequiresFullRuntime) {
|
||||
if (!updateToLocation(Loc))
|
||||
return Loc.IP;
|
||||
|
||||
Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
|
||||
Value *Ident = getOrCreateIdent(SrcLocStr);
|
||||
ConstantInt *IsSPMDVal = ConstantInt::getBool(Int32->getContext(), IsSPMD);
|
||||
ConstantInt *UseGenericStateMachine =
|
||||
ConstantInt::getBool(Int32->getContext(), !IsSPMD);
|
||||
ConstantInt *RequiresFullRuntimeVal = ConstantInt::getBool(Int32->getContext(), RequiresFullRuntime);
|
||||
|
||||
Function *Fn = getOrCreateRuntimeFunctionPtr(
|
||||
omp::RuntimeFunction::OMPRTL___kmpc_target_init);
|
||||
|
||||
CallInst *ThreadKind =
|
||||
Builder.CreateCall(Fn, {Ident, IsSPMDVal, UseGenericStateMachine, RequiresFullRuntimeVal});
|
||||
|
||||
Value *ExecUserCode = Builder.CreateICmpEQ(
|
||||
ThreadKind, ConstantInt::get(ThreadKind->getType(), -1), "exec_user_code");
|
||||
|
||||
// ThreadKind = __kmpc_target_init(...)
|
||||
// if (ThreadKind == -1)
|
||||
// user_code
|
||||
// else
|
||||
// return;
|
||||
|
||||
auto *UI = Builder.CreateUnreachable();
|
||||
BasicBlock *CheckBB = UI->getParent();
|
||||
BasicBlock *UserCodeEntryBB = CheckBB->splitBasicBlock(UI, "user_code.entry");
|
||||
|
||||
BasicBlock *WorkerExitBB = BasicBlock::Create(
|
||||
CheckBB->getContext(), "worker.exit", CheckBB->getParent());
|
||||
Builder.SetInsertPoint(WorkerExitBB);
|
||||
Builder.CreateRetVoid();
|
||||
|
||||
auto *CheckBBTI = CheckBB->getTerminator();
|
||||
Builder.SetInsertPoint(CheckBBTI);
|
||||
Builder.CreateCondBr(ExecUserCode, UI->getParent(), WorkerExitBB);
|
||||
|
||||
CheckBBTI->eraseFromParent();
|
||||
UI->eraseFromParent();
|
||||
|
||||
// Continue in the "user_code" block, see diagram above and in
|
||||
// openmp/libomptarget/deviceRTLs/common/include/target.h .
|
||||
return InsertPointTy(UserCodeEntryBB, UserCodeEntryBB->getFirstInsertionPt());
|
||||
}
|
||||
|
||||
void OpenMPIRBuilder::createTargetDeinit(const LocationDescription &Loc,
|
||||
bool IsSPMD, bool RequiresFullRuntime) {
|
||||
if (!updateToLocation(Loc))
|
||||
return;
|
||||
|
||||
Constant *SrcLocStr = getOrCreateSrcLocStr(Loc);
|
||||
Value *Ident = getOrCreateIdent(SrcLocStr);
|
||||
ConstantInt *IsSPMDVal = ConstantInt::getBool(Int32->getContext(), IsSPMD);
|
||||
ConstantInt *RequiresFullRuntimeVal = ConstantInt::getBool(Int32->getContext(), RequiresFullRuntime);
|
||||
|
||||
Function *Fn = getOrCreateRuntimeFunctionPtr(
|
||||
omp::RuntimeFunction::OMPRTL___kmpc_target_deinit);
|
||||
|
||||
Builder.CreateCall(Fn, {Ident, IsSPMDVal, RequiresFullRuntimeVal});
|
||||
}
|
||||
|
||||
std::string OpenMPIRBuilder::getNameWithSeparators(ArrayRef<StringRef> Parts,
|
||||
StringRef FirstSeparator,
|
||||
StringRef Separator) {
|
||||
|
@ -26,9 +26,6 @@
|
||||
#include "llvm/Frontend/OpenMP/OMPConstants.h"
|
||||
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
|
||||
#include "llvm/IR/IntrinsicInst.h"
|
||||
#include "llvm/IR/IntrinsicsAMDGPU.h"
|
||||
#include "llvm/IR/IntrinsicsNVPTX.h"
|
||||
#include "llvm/IR/PatternMatch.h"
|
||||
#include "llvm/InitializePasses.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Transforms/IPO.h"
|
||||
@ -37,7 +34,6 @@
|
||||
#include "llvm/Transforms/Utils/CallGraphUpdater.h"
|
||||
#include "llvm/Transforms/Utils/CodeExtractor.h"
|
||||
|
||||
using namespace llvm::PatternMatch;
|
||||
using namespace llvm;
|
||||
using namespace omp;
|
||||
|
||||
@ -2341,10 +2337,12 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
|
||||
AllCallSitesKnown))
|
||||
SingleThreadedBBs.erase(&F->getEntryBlock());
|
||||
|
||||
// Check if the edge into the successor block compares a thread-id function to
|
||||
// a constant zero.
|
||||
// TODO: Use AAValueSimplify to simplify and propogate constants.
|
||||
// TODO: Check more than a single use for thread ID's.
|
||||
auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
|
||||
auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
|
||||
|
||||
// Check if the edge into the successor block compares the __kmpc_target_init
|
||||
// result with -1. If we are in non-SPMD-mode that signals only the main
|
||||
// thread will execute the edge.
|
||||
auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) {
|
||||
if (!Edge || !Edge->isConditional())
|
||||
return false;
|
||||
@ -2355,31 +2353,20 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
|
||||
if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
|
||||
return false;
|
||||
|
||||
// Temporarily match the pattern generated by clang for teams regions.
|
||||
// TODO: Remove this once the new runtime is in place.
|
||||
ConstantInt *One, *NegOne;
|
||||
CmpInst::Predicate Pred;
|
||||
auto &&m_ThreadID = m_Intrinsic<Intrinsic::nvvm_read_ptx_sreg_tid_x>();
|
||||
auto &&m_WarpSize = m_Intrinsic<Intrinsic::nvvm_read_ptx_sreg_warpsize>();
|
||||
auto &&m_BlockSize = m_Intrinsic<Intrinsic::nvvm_read_ptx_sreg_ntid_x>();
|
||||
if (match(Cmp, m_Cmp(Pred, m_ThreadID,
|
||||
m_And(m_Sub(m_BlockSize, m_ConstantInt(One)),
|
||||
m_Xor(m_Sub(m_WarpSize, m_ConstantInt(One)),
|
||||
m_ConstantInt(NegOne))))))
|
||||
if (One->isOne() && NegOne->isMinusOne() &&
|
||||
Pred == CmpInst::Predicate::ICMP_EQ)
|
||||
return true;
|
||||
|
||||
ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
|
||||
if (!C || !C->isZero())
|
||||
if (!C)
|
||||
return false;
|
||||
|
||||
if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
|
||||
if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
|
||||
return true;
|
||||
if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
|
||||
if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
|
||||
return true;
|
||||
// Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
|
||||
if (C->isAllOnesValue()) {
|
||||
auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
|
||||
if (!CB || CB->getCalledFunction() != RFI.Declaration)
|
||||
return false;
|
||||
const int InitIsSPMDArgNo = 1;
|
||||
auto *IsSPMDModeCI =
|
||||
dyn_cast<ConstantInt>(CB->getOperand(InitIsSPMDArgNo));
|
||||
return IsSPMDModeCI && IsSPMDModeCI->isZero();
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
@ -2394,7 +2381,7 @@ ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
|
||||
for (auto PredBB = pred_begin(BB), PredEndBB = pred_end(BB);
|
||||
PredBB != PredEndBB; ++PredBB) {
|
||||
if (!IsInitialThreadOnly(dyn_cast<BranchInst>((*PredBB)->getTerminator()),
|
||||
BB))
|
||||
BB))
|
||||
IsInitialThread &= SingleThreadedBBs.contains(*PredBB);
|
||||
}
|
||||
|
||||
|
@ -3,10 +3,16 @@
|
||||
target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"
|
||||
target triple = "nvptx64"
|
||||
|
||||
%struct.ident_t = type { i32, i32, i32, i32, i8* }
|
||||
|
||||
@S = external local_unnamed_addr global i8*
|
||||
@0 = private unnamed_addr constant [113 x i8] c";llvm/test/Transforms/OpenMP/custom_state_machines_remarks.c;__omp_offloading_2a_d80d3d_test_fallback_l11;11;1;;\00", align 1
|
||||
@1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 0, i8* getelementptr inbounds ([113 x i8], [113 x i8]* @0, i32 0, i32 0) }, align 8
|
||||
|
||||
; CHECK-REMARKS: remark: replace_globalization.c:5:7: Replaced globalized variable with 16 bytes of shared memory
|
||||
; CHECK-REMARKS: remark: replace_globalization.c:5:14: Replaced globalized variable with 4 bytes of shared memory
|
||||
; CHECK-REMARKS-NOT: 6 bytes
|
||||
|
||||
; CHECK: [[SHARED_X:@.+]] = internal addrspace(3) global [16 x i8] undef
|
||||
; CHECK: [[SHARED_Y:@.+]] = internal addrspace(3) global [4 x i8] undef
|
||||
|
||||
@ -25,14 +31,15 @@ entry:
|
||||
define void @bar() {
|
||||
call void @baz()
|
||||
call void @qux()
|
||||
call void @negative_qux_spmd()
|
||||
ret void
|
||||
}
|
||||
|
||||
; CHECK: call void @use.internalized(i8* nofree writeonly addrspacecast (i8 addrspace(3)* getelementptr inbounds ([16 x i8], [16 x i8] addrspace(3)* [[SHARED_X]], i32 0, i32 0) to i8*))
|
||||
define internal void @baz() {
|
||||
entry:
|
||||
%tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
|
||||
%cmp = icmp eq i32 %tid, 0
|
||||
%call = call i32 @__kmpc_target_init(%struct.ident_t* nonnull @1, i1 false, i1 false, i1 true)
|
||||
%cmp = icmp eq i32 %call, -1
|
||||
br i1 %cmp, label %master, label %exit
|
||||
master:
|
||||
%x = call i8* @__kmpc_alloc_shared(i64 16), !dbg !11
|
||||
@ -48,20 +55,30 @@ exit:
|
||||
; CHECK: call void @use.internalized(i8* nofree writeonly addrspacecast (i8 addrspace(3)* getelementptr inbounds ([4 x i8], [4 x i8] addrspace(3)* [[SHARED_Y]], i32 0, i32 0) to i8*))
|
||||
define internal void @qux() {
|
||||
entry:
|
||||
%tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
|
||||
%ntid = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
|
||||
%warpsize = call i32 @llvm.nvvm.read.ptx.sreg.warpsize()
|
||||
%0 = sub nuw i32 %warpsize, 1
|
||||
%1 = sub nuw i32 %ntid, 1
|
||||
%2 = xor i32 %0, -1
|
||||
%master_tid = and i32 %1, %2
|
||||
%3 = icmp eq i32 %tid, %master_tid
|
||||
br i1 %3, label %master, label %exit
|
||||
%call = call i32 @__kmpc_target_init(%struct.ident_t* nonnull @1, i1 false, i1 true, i1 true)
|
||||
%0 = icmp eq i32 %call, -1
|
||||
br i1 %0, label %master, label %exit
|
||||
master:
|
||||
%y = call i8* @__kmpc_alloc_shared(i64 4), !dbg !12
|
||||
%y_on_stack = bitcast i8* %y to [4 x i32]*
|
||||
%4 = bitcast [4 x i32]* %y_on_stack to i8*
|
||||
call void @use(i8* %4)
|
||||
%1 = bitcast [4 x i32]* %y_on_stack to i8*
|
||||
call void @use(i8* %1)
|
||||
call void @__kmpc_free_shared(i8* %y)
|
||||
br label %exit
|
||||
exit:
|
||||
ret void
|
||||
}
|
||||
|
||||
define internal void @negative_qux_spmd() {
|
||||
entry:
|
||||
%call = call i32 @__kmpc_target_init(%struct.ident_t* nonnull @1, i1 true, i1 true, i1 true)
|
||||
%0 = icmp eq i32 %call, -1
|
||||
br i1 %0, label %master, label %exit
|
||||
master:
|
||||
%y = call i8* @__kmpc_alloc_shared(i64 6), !dbg !12
|
||||
%y_on_stack = bitcast i8* %y to [6 x i32]*
|
||||
%1 = bitcast [6 x i32]* %y_on_stack to i8*
|
||||
call void @use(i8* %1)
|
||||
call void @__kmpc_free_shared(i8* %y)
|
||||
br label %exit
|
||||
exit:
|
||||
@ -85,6 +102,7 @@ declare i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
|
||||
|
||||
declare i32 @llvm.nvvm.read.ptx.sreg.warpsize()
|
||||
|
||||
declare i32 @__kmpc_target_init(%struct.ident_t*, i1, i1, i1)
|
||||
|
||||
!llvm.dbg.cu = !{!0}
|
||||
!llvm.module.flags = !{!3, !4, !5, !6}
|
||||
|
@ -3,8 +3,13 @@
|
||||
; REQUIRES: asserts
|
||||
; ModuleID = 'single_threaded_exeuction.c'
|
||||
|
||||
define weak void @kernel() {
|
||||
call void @__kmpc_kernel_init(i32 512, i16 1)
|
||||
%struct.ident_t = type { i32, i32, i32, i32, i8* }
|
||||
|
||||
@0 = private unnamed_addr constant [1 x i8] c"\00", align 1
|
||||
@1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 0, i8* getelementptr inbounds ([1 x i8], [1 x i8]* @0, i32 0, i32 0) }, align 8
|
||||
|
||||
define void @kernel() {
|
||||
call void @__kmpc_kernel_prepare_parallel(i8* null)
|
||||
call void @nvptx()
|
||||
call void @amdgcn()
|
||||
ret void
|
||||
@ -19,8 +24,8 @@ define weak void @kernel() {
|
||||
; Function Attrs: noinline
|
||||
define internal void @nvptx() {
|
||||
entry:
|
||||
%call = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
|
||||
%cmp = icmp eq i32 %call, 0
|
||||
%call = call i32 @__kmpc_target_init(%struct.ident_t* nonnull @1, i1 false, i1 false, i1 false)
|
||||
%cmp = icmp eq i32 %call, -1
|
||||
br i1 %cmp, label %if.then, label %if.end
|
||||
|
||||
if.then:
|
||||
@ -40,8 +45,8 @@ if.end:
|
||||
; Function Attrs: noinline
|
||||
define internal void @amdgcn() {
|
||||
entry:
|
||||
%call = call i32 @llvm.amdgcn.workitem.id.x()
|
||||
%cmp = icmp eq i32 %call, 0
|
||||
%call = call i32 @__kmpc_target_init(%struct.ident_t* nonnull @1, i1 false, i1 true, i1 true)
|
||||
%cmp = icmp eq i32 %call, -1
|
||||
br i1 %cmp, label %if.then, label %if.end
|
||||
|
||||
if.then:
|
||||
@ -87,7 +92,9 @@ declare i32 @llvm.nvvm.read.ptx.sreg.tid.x()
|
||||
|
||||
declare i32 @llvm.amdgcn.workitem.id.x()
|
||||
|
||||
declare void @__kmpc_kernel_init(i32, i16)
|
||||
declare void @__kmpc_kernel_prepare_parallel(i8*)
|
||||
|
||||
declare i32 @__kmpc_target_init(%struct.ident_t*, i1, i1, i1)
|
||||
|
||||
attributes #0 = { cold noinline }
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user