
[AbstractAttributor] Fold __kmpc_parallel_level if possible

Similar to D105787, this patch tries to fold `__kmpc_parallel_level` if possible.
Note that `__kmpc_parallel_level` doesn't take activeness into consideration:
based on the current `deviceRTLs`, it returns plain nesting levels such as 0, 1, 2,
rather than values such as 0, 129, 130, etc. that also encode activeness.
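
To illustrate that distinction, here is a minimal standalone sketch; the ActiveLevelBit constant and the encoding are assumptions for illustration, not the actual deviceRTLs source:

#include <cstdint>
#include <cstdio>

// Hypothetical encoding (illustrative only): the high bit marks an active
// parallel region.
constexpr uint8_t ActiveLevelBit = 0x80; // 128, assumed for illustration

int main() {
  uint8_t PlainLevel = 1;                             // one enclosing parallel region
  uint8_t EncodedLevel = ActiveLevelBit | PlainLevel; // 129: level 1 and active
  // The current __kmpc_parallel_level reports values like PlainLevel
  // (0, 1, 2, ...), not EncodedLevel (0, 129, 130, ...), so the fold only has
  // to distinguish "reached from SPMD kernels" (level 1) from "reached from
  // generic-mode kernels" (level 0).
  std::printf("plain=%u encoded=%u\n", (unsigned)PlainLevel, (unsigned)EncodedLevel);
  return 0;
}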

Reviewed By: jdoerfert

Differential Revision: https://reviews.llvm.org/D106154
Shilei Tian 2021-07-26 22:45:52 -04:00
parent 6a849a320d
commit 14491b35e0
2 changed files with 279 additions and 1 deletion


@@ -519,6 +519,11 @@ struct KernelInfoState : AbstractState {
/// State to track what kernel entries can reach the associated function.
BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
/// State to indicate if we can track the parallel level of the associated
/// function. We will give up tracking if we encounter an unknown caller or if
/// the caller is __kmpc_parallel_51.
BooleanStateWithSetVector<uint8_t> ParallelLevels;
/// Abstract State interface
///{
@@ -3329,8 +3334,10 @@ struct AAKernelInfoFunction : AAKernelInfo {
CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
SPMDCompatibilityTracker.indicatePessimisticFixpoint();
-  if (!IsKernelEntry)
+  if (!IsKernelEntry) {
     updateReachingKernelEntries(A);
+    updateParallelLevels(A);
+  }
// Callback to check a call instruction.
bool AllSPMDStatesWereFixed = true;
@@ -3386,6 +3393,49 @@ private:
AllCallSitesKnown))
ReachingKernelEntries.indicatePessimisticFixpoint();
}
/// Update info regarding parallel levels.
void updateParallelLevels(Attributor &A) {
auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
auto PredCallSite = [&](AbstractCallSite ACS) {
Function *Caller = ACS.getInstruction()->getFunction();
assert(Caller && "Caller is nullptr");
auto &CAA =
A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
if (CAA.ParallelLevels.isValidState()) {
// Any function that is called by `__kmpc_parallel_51` will not be folded,
// because the parallel level inside that function is updated. Getting this
// right would make the analysis depend on the deviceRTLs implementation, and
// any future change to that implementation could invalidate the analysis.
// As a consequence, we are just conservative here.
if (Caller == Parallel51RFI.Declaration) {
ParallelLevels.indicatePessimisticFixpoint();
return true;
}
ParallelLevels ^= CAA.ParallelLevels;
return true;
}
// We lost track of the caller of the associated function, so we can no
// longer track the parallel level.
ParallelLevels.indicatePessimisticFixpoint();
return true;
};
bool AllCallSitesKnown = true;
if (!A.checkForAllCallSites(PredCallSite, *this,
true /* RequireAllCallSites */,
AllCallSitesKnown))
ParallelLevels.indicatePessimisticFixpoint();
}
};
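
`ParallelLevels ^= CAA.ParallelLevels` above merges the caller's possible parallel levels into the current state. BooleanStateWithSetVector is an OpenMPOpt helper; the standalone sketch below only illustrates the merge semantics assumed here (set union plus a validity bit), with hypothetical names:

#include <cstdint>
#include <set>

// Hypothetical stand-in for BooleanStateWithSetVector<uint8_t>: a validity bit
// plus a set of values; merging unions the sets and gives up if either side
// already gave up.
struct LevelSetState {
  bool Valid = true;
  std::set<uint8_t> Levels;

  void indicatePessimisticFixpoint() {
    Valid = false;
    Levels.clear();
  }

  // Corresponds to `ParallelLevels ^= CAA.ParallelLevels` above: accumulate the
  // caller's possible parallel levels into this state.
  LevelSetState &operator^=(const LevelSetState &RHS) {
    if (!RHS.Valid)
      indicatePessimisticFixpoint();
    if (Valid)
      Levels.insert(RHS.Levels.begin(), RHS.Levels.end());
    return *this;
  }
};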
/// The call site kernel info abstract attribute, basically, what can we say
@@ -3668,6 +3718,9 @@ struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
case OMPRTL___kmpc_is_generic_main_thread_id:
Changed |= foldIsGenericMainThread(A);
break;
case OMPRTL___kmpc_parallel_level:
Changed |= foldParallelLevel(A);
break;
default:
llvm_unreachable("Unhandled OpenMP runtime function!");
}
@@ -3782,6 +3835,68 @@ private:
: ChangeStatus::CHANGED;
}
/// Fold __kmpc_parallel_level into a constant if possible.
ChangeStatus foldParallelLevel(Attributor &A) {
Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
*this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
if (!CallerKernelInfoAA.ParallelLevels.isValidState())
return indicatePessimisticFixpoint();
if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
return indicatePessimisticFixpoint();
if (CallerKernelInfoAA.ReachingKernelEntries.empty()) {
assert(!SimplifiedValue.hasValue() &&
"SimplifiedValue should keep none at this point");
return ChangeStatus::UNCHANGED;
}
unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
DepClassTy::REQUIRED);
if (!AA.SPMDCompatibilityTracker.isValidState())
return indicatePessimisticFixpoint();
if (AA.SPMDCompatibilityTracker.isAssumed()) {
if (AA.SPMDCompatibilityTracker.isAtFixpoint())
++KnownSPMDCount;
else
++AssumedSPMDCount;
} else {
if (AA.SPMDCompatibilityTracker.isAtFixpoint())
++KnownNonSPMDCount;
else
++AssumedNonSPMDCount;
}
}
if ((AssumedSPMDCount + KnownSPMDCount) &&
(AssumedNonSPMDCount + KnownNonSPMDCount))
return indicatePessimisticFixpoint();
auto &Ctx = getAnchorValue().getContext();
// If the caller can only be reached by SPMD kernel entries, the parallel
// level is 1. Similarly, if the caller can only be reached by non-SPMD
// kernel entries, it is 0.
if (AssumedSPMDCount || KnownSPMDCount) {
assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
"Expected only SPMD kernels!");
SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
} else {
assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
"Expected only non-SPMD kernels!");
SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
}
return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
: ChangeStatus::CHANGED;
}
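
Stripped of the Attributor plumbing, the decision made above boils down to the following standalone sketch (struct and function names are hypothetical; the empty-ReachingKernelEntries case is handled by the early return above):

#include <cstdint>
#include <optional>

struct KernelCounts {
  unsigned KnownSPMD = 0, AssumedSPMD = 0;
  unsigned KnownNonSPMD = 0, AssumedNonSPMD = 0;
};

// Returns the constant __kmpc_parallel_level folds to, or std::nullopt if the
// reaching kernels mix SPMD and non-SPMD entries and we must stay pessimistic.
std::optional<uint8_t> foldedParallelLevel(const KernelCounts &C) {
  bool AnySPMD = (C.KnownSPMD + C.AssumedSPMD) != 0;
  bool AnyNonSPMD = (C.KnownNonSPMD + C.AssumedNonSPMD) != 0;
  if (AnySPMD && AnyNonSPMD)
    return std::nullopt; // mixed callers: cannot fold
  return AnySPMD ? 1 : 0; // SPMD kernels imply level 1, generic implies 0
}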
/// An optional value the associated value is assumed to fold to. That is, we
/// assume the associated value (which is a call) can be replaced by this
/// simplified value.
@@ -3832,6 +3947,19 @@ void OpenMPOpt::registerAAs(bool IsModulePass) {
/* UpdateAfterInit */ false);
return false;
});
auto &ParallelLevelRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_level];
ParallelLevelRFI.foreachUse(SCC, [&](Use &U, Function &) {
CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &ParallelLevelRFI);
if (!CI)
return false;
A.getOrCreateAAFor<AAFoldRuntimeCall>(
IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
DepClassTy::NONE, /* ForceUpdate */ false,
/* UpdateAfterInit */ false);
return false;
});
}
// Create CallSite AA for all Getters.


@@ -0,0 +1,150 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --function-signature --check-globals
; RUN: opt -S -passes=openmp-opt < %s | FileCheck %s
target triple = "nvptx64"
%struct.ident_t = type { i32, i32, i32, i32, i8* }
@no_spmd_exec_mode = weak constant i8 1
@spmd_exec_mode = weak constant i8 0
@parallel_exec_mode = weak constant i8 0
@G = external global i8
@llvm.compiler.used = appending global [3 x i8*] [i8* @no_spmd_exec_mode, i8* @spmd_exec_mode, i8* @parallel_exec_mode], section "llvm.metadata"
;.
; CHECK: @[[NO_SPMD_EXEC_MODE:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 1
; CHECK: @[[SPMD_EXEC_MODE:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 0
; CHECK: @[[PARALLEL_EXEC_MODE:[a-zA-Z0-9_$"\\.-]+]] = weak constant i8 0
; CHECK: @[[G:[a-zA-Z0-9_$"\\.-]+]] = external global i8
; CHECK: @[[LLVM_COMPILER_USED:[a-zA-Z0-9_$"\\.-]+]] = appending global [3 x i8*] [i8* @no_spmd_exec_mode, i8* @spmd_exec_mode, i8* @parallel_exec_mode], section "llvm.metadata"
;.
define weak void @none_spmd() {
; CHECK-LABEL: define {{[^@]+}}@none_spmd() {
; CHECK-NEXT: [[I:%.*]] = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 false, i1 false, i1 false)
; CHECK-NEXT: call void @none_spmd_helper()
; CHECK-NEXT: call void @mixed_helper()
; CHECK-NEXT: call void @__kmpc_target_deinit(%struct.ident_t* null, i1 false, i1 false)
; CHECK-NEXT: ret void
;
%i = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 false, i1 false, i1 false)
call void @none_spmd_helper()
call void @mixed_helper()
call void @__kmpc_target_deinit(%struct.ident_t* null, i1 false, i1 false)
ret void
}
define weak void @spmd() {
; CHECK-LABEL: define {{[^@]+}}@spmd() {
; CHECK-NEXT: [[I:%.*]] = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 true, i1 false, i1 false)
; CHECK-NEXT: call void @spmd_helper()
; CHECK-NEXT: call void @mixed_helper()
; CHECK-NEXT: call void @__kmpc_target_deinit(%struct.ident_t* null, i1 true, i1 false)
; CHECK-NEXT: ret void
;
%i = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 true, i1 false, i1 false)
call void @spmd_helper()
call void @mixed_helper()
call void @__kmpc_target_deinit(%struct.ident_t* null, i1 true, i1 false)
ret void
}
define weak void @parallel() {
; CHECK-LABEL: define {{[^@]+}}@parallel() {
; CHECK-NEXT: [[I:%.*]] = call i32 @__kmpc_target_init(%struct.ident_t* align 536870912 null, i1 true, i1 false, i1 false)
; CHECK-NEXT: call void @spmd_helper()
; CHECK-NEXT: call void @__kmpc_parallel_51(%struct.ident_t* noalias noundef align 536870912 null, i32 noundef 0, i32 noundef 0, i32 noundef 0, i32 noundef 0, i8* noalias noundef align 536870912 null, i8* noalias noundef align 536870912 null, i8** noalias noundef align 536870912 null, i64 noundef 0)
; CHECK-NEXT: call void @__kmpc_target_deinit(%struct.ident_t* null, i1 true, i1 false)
; CHECK-NEXT: ret void
;
%i = call i32 @__kmpc_target_init(%struct.ident_t* null, i1 true, i1 false, i1 false)
call void @spmd_helper()
call void @__kmpc_parallel_51(%struct.ident_t* null, i32 0, i32 0, i32 0, i32 0, i8* null, i8* null, i8** null, i64 0)
call void @__kmpc_target_deinit(%struct.ident_t* null, i1 true, i1 false)
ret void
}
define internal void @mixed_helper() {
; CHECK-LABEL: define {{[^@]+}}@mixed_helper() {
; CHECK-NEXT: [[LEVEL:%.*]] = call i8 @__kmpc_parallel_level()
; CHECK-NEXT: store i8 [[LEVEL]], i8* @G, align 1
; CHECK-NEXT: ret void
;
%level = call i8 @__kmpc_parallel_level()
store i8 %level, i8* @G
ret void
}
define internal void @none_spmd_helper() {
; CHECK-LABEL: define {{[^@]+}}@none_spmd_helper() {
; CHECK-NEXT: [[LEVEL12:%.*]] = call i8 @__kmpc_parallel_level()
; CHECK-NEXT: [[C:%.*]] = icmp eq i8 [[LEVEL12]], 0
; CHECK-NEXT: br i1 [[C]], label [[T:%.*]], label [[F:%.*]]
; CHECK: t:
; CHECK-NEXT: call void @foo()
; CHECK-NEXT: ret void
; CHECK: f:
; CHECK-NEXT: call void @bar()
; CHECK-NEXT: ret void
;
%level12 = call i8 @__kmpc_parallel_level()
%c = icmp eq i8 %level12, 0
br i1 %c, label %t, label %f
t:
call void @foo()
ret void
f:
call void @bar()
ret void
}
define internal void @spmd_helper() {
; CHECK-LABEL: define {{[^@]+}}@spmd_helper() {
; CHECK-NEXT: store i8 1, i8* @G, align 1
; CHECK-NEXT: ret void
;
%level = call i8 @__kmpc_parallel_level()
store i8 %level, i8* @G
ret void
}
define internal void @__kmpc_parallel_51(%struct.ident_t*, i32, i32, i32, i32, i8*, i8*, i8**, i64) {
; CHECK-LABEL: define {{[^@]+}}@__kmpc_parallel_51
; CHECK-SAME: (%struct.ident_t* noalias nocapture nofree readnone align 536870912 [[TMP0:%.*]], i32 [[TMP1:%.*]], i32 [[TMP2:%.*]], i32 [[TMP3:%.*]], i32 [[TMP4:%.*]], i8* noalias nocapture nofree readnone align 536870912 [[TMP5:%.*]], i8* noalias nocapture nofree readnone align 536870912 [[TMP6:%.*]], i8** noalias nocapture nofree readnone align 536870912 [[TMP7:%.*]], i64 [[TMP8:%.*]]) {
; CHECK-NEXT: call void @parallel_helper()
; CHECK-NEXT: ret void
;
call void @parallel_helper()
ret void
}
define internal void @parallel_helper() {
; CHECK-LABEL: define {{[^@]+}}@parallel_helper() {
; CHECK-NEXT: [[LEVEL:%.*]] = call i8 @__kmpc_parallel_level()
; CHECK-NEXT: store i8 [[LEVEL]], i8* @G, align 1
; CHECK-NEXT: ret void
;
%level = call i8 @__kmpc_parallel_level()
store i8 %level, i8* @G
ret void
}
declare void @foo()
declare void @bar()
declare i8 @__kmpc_parallel_level()
declare i32 @__kmpc_target_init(%struct.ident_t*, i1 zeroext, i1 zeroext, i1 zeroext) #1
declare void @__kmpc_target_deinit(%struct.ident_t* nocapture readnone, i1 zeroext, i1 zeroext) #1
!llvm.module.flags = !{!0, !1}
!nvvm.annotations = !{!2, !3, !4}
!0 = !{i32 7, !"openmp", i32 50}
!1 = !{i32 7, !"openmp-device", i32 50}
!2 = !{void ()* @none_spmd, !"kernel", i32 1}
!3 = !{void ()* @spmd, !"kernel", i32 1}
!4 = !{void ()* @parallel, !"kernel", i32 1}
;.
; CHECK: [[META0:![0-9]+]] = !{i32 7, !"openmp", i32 50}
; CHECK: [[META1:![0-9]+]] = !{i32 7, !"openmp-device", i32 50}
; CHECK: [[META2:![0-9]+]] = !{void ()* @none_spmd, !"kernel", i32 1}
; CHECK: [[META3:![0-9]+]] = !{void ()* @spmd, !"kernel", i32 1}
; CHECK: [[META4:![0-9]+]] = !{void ()* @parallel, !"kernel", i32 1}
;.