1
0
mirror of https://github.com/RPCS3/llvm-mirror.git synced 2024-11-22 10:42:39 +01:00

[OpenMP] Folding threadLimit and numThreads when single value in kernels

The device runtime contains several calls to `__kmpc_get_hardware_num_threads_in_block`
and `__kmpc_get_hardware_num_blocks`. If the thread_limit and the num_teams are constant,
these calls can be folded to the constant value.

In this patch we use the already introduced `AAFoldRuntimeCall` and the `NumTeams` and
`NumThreads` kernel attributes (to be introduced in a different patch) to fold these functions.
The code checks all the kernels, and if their attributes match, the functions are folded.

In the future we will explore specializing for multiple values of NumThreads and NumTeams.

Depends on D106390

Reviewed By: jdoerfert, JonChesterfield

Differential Revision: https://reviews.llvm.org/D106033
This commit is contained in:
Jose M Monsalve Diaz 2021-07-27 21:46:39 -04:00 committed by Shilei Tian
parent 66a78bc217
commit 5b7208da36
3 changed files with 196 additions and 37 deletions

View File

@ -206,6 +206,9 @@ __OMP_RTL(__kmpc_omp_reg_task_with_affinity, false, Int32, IdentPtr, Int32,
/* kmp_task_t */ VoidPtr, Int32, /* kmp_task_t */ VoidPtr, Int32,
/* kmp_task_affinity_info_t */ VoidPtr) /* kmp_task_affinity_info_t */ VoidPtr)
__OMP_RTL(__kmpc_get_hardware_num_blocks, false, Int32, )
__OMP_RTL(__kmpc_get_hardware_num_threads_in_block, false, Int32, )
__OMP_RTL(omp_get_thread_num, false, Int32, ) __OMP_RTL(omp_get_thread_num, false, Int32, )
__OMP_RTL(omp_get_num_threads, false, Int32, ) __OMP_RTL(omp_get_num_threads, false, Int32, )
__OMP_RTL(omp_get_max_threads, false, Int32, ) __OMP_RTL(omp_get_max_threads, false, Int32, )
@ -601,6 +604,9 @@ __OMP_RTL_ATTRS(__kmpc_omp_reg_task_with_affinity, DefaultAttrs, AttributeSet(),
ParamAttrs(ReadOnlyPtrAttrs, AttributeSet(), ReadOnlyPtrAttrs, ParamAttrs(ReadOnlyPtrAttrs, AttributeSet(), ReadOnlyPtrAttrs,
AttributeSet(), ReadOnlyPtrAttrs)) AttributeSet(), ReadOnlyPtrAttrs))
__OMP_RTL_ATTRS(__kmpc_get_hardware_num_blocks, GetterAttrs, AttributeSet(), ParamAttrs())
__OMP_RTL_ATTRS(__kmpc_get_hardware_num_threads_in_block, GetterAttrs, AttributeSet(), ParamAttrs())
__OMP_RTL_ATTRS(omp_get_thread_num, GetterAttrs, AttributeSet(), ParamAttrs()) __OMP_RTL_ATTRS(omp_get_thread_num, GetterAttrs, AttributeSet(), ParamAttrs())
__OMP_RTL_ATTRS(omp_get_num_threads, GetterAttrs, AttributeSet(), ParamAttrs()) __OMP_RTL_ATTRS(omp_get_num_threads, GetterAttrs, AttributeSet(), ParamAttrs())
__OMP_RTL_ATTRS(omp_get_max_threads, GetterAttrs, AttributeSet(), ParamAttrs()) __OMP_RTL_ATTRS(omp_get_max_threads, GetterAttrs, AttributeSet(), ParamAttrs())

View File

@ -1833,6 +1833,8 @@ private:
return Changed == ChangeStatus::CHANGED; return Changed == ChangeStatus::CHANGED;
} }
void registerFoldRuntimeCall(RuntimeFunction RF);
/// Populate the Attributor with abstract attribute opportunities in the /// Populate the Attributor with abstract attribute opportunities in the
/// function. /// function.
void registerAAs(bool IsModulePass); void registerAAs(bool IsModulePass);
@ -3506,6 +3508,8 @@ struct AAKernelInfoCallSite : AAKernelInfo {
case OMPRTL___kmpc_is_spmd_exec_mode: case OMPRTL___kmpc_is_spmd_exec_mode:
case OMPRTL___kmpc_for_static_fini: case OMPRTL___kmpc_for_static_fini:
case OMPRTL___kmpc_global_thread_num: case OMPRTL___kmpc_global_thread_num:
case OMPRTL___kmpc_get_hardware_num_threads_in_block:
case OMPRTL___kmpc_get_hardware_num_blocks:
case OMPRTL___kmpc_single: case OMPRTL___kmpc_single:
case OMPRTL___kmpc_end_single: case OMPRTL___kmpc_end_single:
case OMPRTL___kmpc_master: case OMPRTL___kmpc_master:
@ -3710,7 +3714,6 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
ChangeStatus updateImpl(Attributor &A) override { ChangeStatus updateImpl(Attributor &A) override {
ChangeStatus Changed = ChangeStatus::UNCHANGED; ChangeStatus Changed = ChangeStatus::UNCHANGED;
switch (RFKind) { switch (RFKind) {
case OMPRTL___kmpc_is_spmd_exec_mode: case OMPRTL___kmpc_is_spmd_exec_mode:
Changed |= foldIsSPMDExecMode(A); Changed |= foldIsSPMDExecMode(A);
@ -3721,6 +3724,12 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
case OMPRTL___kmpc_parallel_level: case OMPRTL___kmpc_parallel_level:
Changed |= foldParallelLevel(A); Changed |= foldParallelLevel(A);
break; break;
case OMPRTL___kmpc_get_hardware_num_threads_in_block:
Changed = Changed | foldKernelFnAttribute(A, "omp_target_thread_limit");
break;
case OMPRTL___kmpc_get_hardware_num_blocks:
Changed = Changed | foldKernelFnAttribute(A, "omp_target_num_teams");
break;
default: default:
llvm_unreachable("Unhandled OpenMP runtime function!"); llvm_unreachable("Unhandled OpenMP runtime function!");
} }
@ -3892,7 +3901,39 @@ private:
"Expected only non-SPMD kernels!"); "Expected only non-SPMD kernels!");
SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0); SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
} }
return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
: ChangeStatus::CHANGED;
}
ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
// Specialize only if all the calls agree with the attribute constant value
int32_t CurrentAttrValue = -1;
Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
*this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
return indicatePessimisticFixpoint();
// Iterate over the kernels that reach this function
for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
int32_t NextAttrVal = -1;
if (K->hasFnAttribute(Attr))
NextAttrVal =
std::stoi(K->getFnAttribute(Attr).getValueAsString().str());
if (NextAttrVal == -1 ||
(CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
return indicatePessimisticFixpoint();
CurrentAttrValue = NextAttrVal;
}
if (CurrentAttrValue != -1) {
auto &Ctx = getAnchorValue().getContext();
SimplifiedValue =
ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
}
return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
: ChangeStatus::CHANGED; : ChangeStatus::CHANGED;
} }
@ -3908,6 +3949,21 @@ private:
} // namespace } // namespace
/// Register folding callsite
void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
auto &RFI = OMPInfoCache.RFIs[RF];
RFI.foreachUse(SCC, [&](Use &U, Function &F) {
CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
if (!CI)
return false;
A.getOrCreateAAFor<AAFoldRuntimeCall>(
IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
DepClassTy::NONE, /* ForceUpdate */ false,
/* UpdateAfterInit */ false);
return false;
});
}
void OpenMPOpt::registerAAs(bool IsModulePass) { void OpenMPOpt::registerAAs(bool IsModulePass) {
if (SCC.empty()) if (SCC.empty())
@ -3923,43 +3979,12 @@ void OpenMPOpt::registerAAs(bool IsModulePass) {
DepClassTy::NONE, /* ForceUpdate */ false, DepClassTy::NONE, /* ForceUpdate */ false,
/* UpdateAfterInit */ false); /* UpdateAfterInit */ false);
auto &IsMainRFI =
OMPInfoCache.RFIs[OMPRTL___kmpc_is_generic_main_thread_id];
IsMainRFI.foreachUse(SCC, [&](Use &U, Function &F) {
CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &IsMainRFI);
if (!CI)
return false;
A.getOrCreateAAFor<AAFoldRuntimeCall>(
IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
DepClassTy::NONE, /* ForceUpdate */ false,
/* UpdateAfterInit */ false);
return false;
});
auto &IsSPMDRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_is_spmd_exec_mode]; registerFoldRuntimeCall(OMPRTL___kmpc_is_generic_main_thread_id);
IsSPMDRFI.foreachUse(SCC, [&](Use &U, Function &) { registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &IsSPMDRFI); registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
if (!CI) registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
return false; registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
A.getOrCreateAAFor<AAFoldRuntimeCall>(
IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
DepClassTy::NONE, /* ForceUpdate */ false,
/* UpdateAfterInit */ false);
return false;
});
auto &ParallelLevelRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_level];
ParallelLevelRFI.foreachUse(SCC, [&](Use &U, Function &) {
CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &ParallelLevelRFI);
if (!CI)
return false;
A.getOrCreateAAFor<AAFoldRuntimeCall>(
IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
DepClassTy::NONE, /* ForceUpdate */ false,
/* UpdateAfterInit */ false);
return false;
});
} }
// Create CallSite AA for all Getters. // Create CallSite AA for all Getters.

View File

@ -0,0 +1,128 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --function-signature --check-globals
; RUN: opt -S -passes=openmp-opt < %s | FileCheck %s
target triple = "nvptx64"
%struct.ident_t = type { i32, i32, i32, i32, i8* }
@kernel0_exec_mode = weak constant i8 1
@G = external global i32
;.
; CHECK: @[[G:[a-zA-Z0-9_$"\\.-]+]] = external global i32
;.
define weak void @kernel0() #0 {
; CHECK-LABEL: define {{[^@]+}}@kernel0()
; CHECK: #[[ATTR0:[0-9]+]] {
; CHECK-NEXT: [[I:%.*]] = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 true, i1 false, i1 false)
; CHECK-NEXT: call void @helper0()
; CHECK-NEXT: call void @helper1()
; CHECK-NEXT: call void @helper2()
; CHECK-NEXT: call void @__kmpc_target_deinit(%struct.ident_t* null, i1 true, i1 false)
; CHECK-NEXT: ret void
;
%i = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 true, i1 false, i1 false)
call void @helper0()
call void @helper1()
call void @helper2()
call void @__kmpc_target_deinit(%struct.ident_t* null, i1 true, i1 false)
ret void
}
@kernel1_exec_mode = weak constant i8 1
define weak void @kernel1() #0 {
; CHECK-LABEL: define {{[^@]+}}@kernel1()
; CHECK: #[[ATTR0]] {
; CHECK-NEXT: [[I:%.*]] = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 true, i1 false, i1 false)
; CHECK-NEXT: call void @helper1()
; CHECK-NEXT: call void @__kmpc_target_deinit(%struct.ident_t* null, i1 false, i1 false)
; CHECK-NEXT: ret void
;
%i = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 true, i1 false, i1 false)
call void @helper1()
call void @__kmpc_target_deinit(%struct.ident_t* null, i1 false, i1 false)
ret void
}
@kernel2_exec_mode = weak constant i8 1
define weak void @kernel2() #0 {
; CHECK-LABEL: define {{[^@]+}}@kernel2()
; CHECK: #[[ATTR0]] {
; CHECK-NEXT: [[I:%.*]] = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 false, i1 false, i1 false)
; CHECK-NEXT: call void @helper0()
; CHECK-NEXT: call void @helper1()
; CHECK-NEXT: call void @helper2()
; CHECK-NEXT: call void @__kmpc_target_deinit(%struct.ident_t* null, i1 false, i1 false)
; CHECK-NEXT: ret void
;
%i = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 false, i1 false, i1 false)
call void @helper0()
call void @helper1()
call void @helper2()
call void @__kmpc_target_deinit(%struct.ident_t* null, i1 false, i1 false)
ret void
}
define internal void @helper0() {
; CHECK-LABEL: define {{[^@]+}}@helper0() {{#[0-9]+}} {
; CHECK-NEXT: store i32 666, i32* @G, align 4
; CHECK-NEXT: ret void
;
%threadLimit = call i32 @__kmpc_get_hardware_num_threads_in_block()
store i32 %threadLimit, i32* @G
ret void
}
define internal void @helper1() {
; CHECK-LABEL: define {{[^@]+}}@helper1() {{#[0-9]+}} {
; CHECK-NEXT: br label [[F:%.*]]
; CHECK: t:
; CHECK-NEXT: unreachable
; CHECK: f:
; CHECK-NEXT: ret void
;
%threadLimit = call i32 @__kmpc_get_hardware_num_threads_in_block()
%c = icmp eq i32 %threadLimit, 666
br i1 %c, label %f, label %t
t:
call void @helper0()
ret void
f:
ret void
}
define internal void @helper2() {
; CHECK-LABEL: define {{[^@]+}}@helper2() {{#[0-9]+}} {
; CHECK-NEXT: store i32 666, i32* @G
; CHECK-NEXT: ret void
;
%threadLimit = call i32 @__kmpc_get_hardware_num_threads_in_block()
store i32 %threadLimit, i32* @G
ret void
}
declare i32 @__kmpc_get_hardware_num_threads_in_block()
declare i32 @__kmpc_target_init(%struct.ident_t*, i1 zeroext, i1 zeroext, i1 zeroext) #1
declare void @__kmpc_target_deinit(%struct.ident_t* nocapture readnone, i1 zeroext, i1 zeroext) #1
!llvm.module.flags = !{!0, !1}
!nvvm.annotations = !{!2, !3, !4}
attributes #0 = { "omp_target_thread_limit"="666" "omp_target_num_teams"="777"}
!0 = !{i32 7, !"openmp", i32 50}
!1 = !{i32 7, !"openmp-device", i32 50}
!2 = !{void ()* @kernel0, !"kernel", i32 1}
!3 = !{void ()* @kernel1, !"kernel", i32 1}
!4 = !{void ()* @kernel2, !"kernel", i32 1}
;.
; CHECK: attributes #[[ATTR0]] = { "omp_target_num_teams"="777" "omp_target_thread_limit"="666" }
;
; CHECK: [[META0:![0-9]+]] = !{i32 7, !"openmp", i32 50}
; CHECK: [[META1:![0-9]+]] = !{i32 7, !"openmp-device", i32 50}
; CHECK: [[META2:![0-9]+]] = !{void ()* @kernel0, !"kernel", i32 1}
; CHECK: [[META3:![0-9]+]] = !{void ()* @kernel1, !"kernel", i32 1}
; CHECK: [[META4:![0-9]+]] = !{void ()* @kernel2, !"kernel", i32 1}
;.