mirror of
https://github.com/RPCS3/llvm-mirror.git
synced 2024-11-22 10:42:39 +01:00
[X86] Add x86_amx type for intel AMX.
The x86_amx is used for AMX intrisics. <256 x i32> is bitcast to x86_amx when it is used by AMX intrinsics, and x86_amx is bitcast to <256 x i32> when it is used by load/store instruction. So amx intrinsics only operate on type x86_amx. It can help to separate amx intrinsics from llvm IR instructions (+-*/). Thank Craig for the idea. This patch depend on https://reviews.llvm.org/D87981. Differential Revision: https://reviews.llvm.org/D91927
This commit is contained in:
parent
0aec9ce280
commit
4ef6280b52
@ -160,6 +160,7 @@ typedef enum {
|
||||
LLVMVectorTypeKind, /**< Fixed width SIMD vector type */
|
||||
LLVMMetadataTypeKind, /**< Metadata */
|
||||
LLVMX86_MMXTypeKind, /**< X86 MMX */
|
||||
LLVMX86_AMXTypeKind, /**< X86 AMX */
|
||||
LLVMTokenTypeKind, /**< Tokens */
|
||||
LLVMScalableVectorTypeKind, /**< Scalable SIMD vector type */
|
||||
LLVMBFloatTypeKind /**< 16 bit brain floating point type */
|
||||
@ -1493,6 +1494,11 @@ LLVMTypeRef LLVMLabelTypeInContext(LLVMContextRef C);
|
||||
*/
|
||||
LLVMTypeRef LLVMX86MMXTypeInContext(LLVMContextRef C);
|
||||
|
||||
/**
|
||||
* Create a X86 AMX type in a context.
|
||||
*/
|
||||
LLVMTypeRef LLVMX86AMXTypeInContext(LLVMContextRef C);
|
||||
|
||||
/**
|
||||
* Create a token type in a context.
|
||||
*/
|
||||
@ -1510,6 +1516,7 @@ LLVMTypeRef LLVMMetadataTypeInContext(LLVMContextRef C);
|
||||
LLVMTypeRef LLVMVoidType(void);
|
||||
LLVMTypeRef LLVMLabelType(void);
|
||||
LLVMTypeRef LLVMX86MMXType(void);
|
||||
LLVMTypeRef LLVMX86AMXType(void);
|
||||
|
||||
/**
|
||||
* @}
|
||||
|
@ -168,7 +168,8 @@ enum TypeCodes {
|
||||
|
||||
TYPE_CODE_TOKEN = 22, // TOKEN
|
||||
|
||||
TYPE_CODE_BFLOAT = 23 // BRAIN FLOATING POINT
|
||||
TYPE_CODE_BFLOAT = 23, // BRAIN FLOATING POINT
|
||||
TYPE_CODE_X86_AMX = 24 // X86 AMX
|
||||
};
|
||||
|
||||
enum OperandBundleTagCode {
|
||||
|
@ -196,6 +196,7 @@ def untyped: ValueType<8 , 160>; // Produces an untyped value
|
||||
def exnref : ValueType<0 , 161>; // WebAssembly's exnref type
|
||||
def funcref : ValueType<0 , 162>; // WebAssembly's funcref type
|
||||
def externref : ValueType<0 , 163>; // WebAssembly's externref type
|
||||
def x86amx : ValueType<8192, 164>; // X86 AMX value
|
||||
|
||||
|
||||
def token : ValueType<0 , 248>; // TokenTy
|
||||
|
@ -690,6 +690,8 @@ inline TypeSize DataLayout::getTypeSizeInBits(Type *Ty) const {
|
||||
case Type::PPC_FP128TyID:
|
||||
case Type::FP128TyID:
|
||||
return TypeSize::Fixed(128);
|
||||
case Type::X86_AMXTyID:
|
||||
return TypeSize::Fixed(8192);
|
||||
// In memory objects this is always aligned to a higher boundary, but
|
||||
// only 80 bits contain information.
|
||||
case Type::X86_FP80TyID:
|
||||
|
@ -125,7 +125,8 @@ namespace Intrinsic {
|
||||
VecElementArgument,
|
||||
Subdivide2Argument,
|
||||
Subdivide4Argument,
|
||||
VecOfBitcastsToInt
|
||||
VecOfBitcastsToInt,
|
||||
AMX
|
||||
} Kind;
|
||||
|
||||
union {
|
||||
|
@ -255,6 +255,8 @@ def llvm_token_ty : LLVMType<token>; // token
|
||||
def llvm_x86mmx_ty : LLVMType<x86mmx>;
|
||||
def llvm_ptrx86mmx_ty : LLVMPointerType<llvm_x86mmx_ty>; // <1 x i64>*
|
||||
|
||||
def llvm_x86amx_ty : LLVMType<x86amx>;
|
||||
|
||||
def llvm_v2i1_ty : LLVMType<v2i1>; // 2 x i1
|
||||
def llvm_v4i1_ty : LLVMType<v4i1>; // 4 x i1
|
||||
def llvm_v8i1_ty : LLVMType<v8i1>; // 8 x i1
|
||||
|
@ -5041,6 +5041,22 @@ let TargetPrefix = "x86" in {
|
||||
Intrinsic<[], [llvm_i8_ty, llvm_i8_ty, llvm_i8_ty],
|
||||
[ImmArg<ArgIndex<0>>, ImmArg<ArgIndex<1>>,
|
||||
ImmArg<ArgIndex<2>>]>;
|
||||
// AMX - internal intrinsics
|
||||
def int_x86_tileloadd64_internal :
|
||||
GCCBuiltin<"__builtin_ia32_tileloadd64_internal">,
|
||||
Intrinsic<[llvm_x86amx_ty],
|
||||
[llvm_i16_ty, llvm_i16_ty, llvm_ptr_ty, llvm_i64_ty],
|
||||
[]>;
|
||||
def int_x86_tdpbssd_internal :
|
||||
GCCBuiltin<"__builtin_ia32_tdpbssd_internal">,
|
||||
Intrinsic<[llvm_x86amx_ty],
|
||||
[llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
|
||||
llvm_x86amx_ty, llvm_x86amx_ty,
|
||||
llvm_x86amx_ty], []>;
|
||||
def int_x86_tilestored64_internal :
|
||||
GCCBuiltin<"__builtin_ia32_tilestored64_internal">,
|
||||
Intrinsic<[], [llvm_i16_ty, llvm_i16_ty, llvm_ptr_ty,
|
||||
llvm_i64_ty, llvm_x86amx_ty], []>;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
@ -5055,20 +5071,4 @@ let TargetPrefix = "x86" in {
|
||||
Intrinsic<[llvm_i8_ty], [], []>;
|
||||
def int_x86_senduipi : GCCBuiltin<"__builtin_ia32_senduipi">,
|
||||
Intrinsic<[], [llvm_i64_ty], []>;
|
||||
// AMX - internal intrinsics
|
||||
def int_x86_tileloadd64_internal :
|
||||
GCCBuiltin<"__builtin_ia32_tileloadd64_internal">,
|
||||
Intrinsic<[llvm_v256i32_ty],
|
||||
[llvm_i16_ty, llvm_i16_ty, llvm_ptr_ty, llvm_i64_ty],
|
||||
[]>;
|
||||
def int_x86_tdpbssd_internal :
|
||||
GCCBuiltin<"__builtin_ia32_tdpbssd_internal">,
|
||||
Intrinsic<[llvm_v256i32_ty],
|
||||
[llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
|
||||
llvm_v256i32_ty, llvm_v256i32_ty,
|
||||
llvm_v256i32_ty], []>;
|
||||
def int_x86_tilestored64_internal :
|
||||
GCCBuiltin<"__builtin_ia32_tilestored64_internal">,
|
||||
Intrinsic<[], [llvm_i16_ty, llvm_i16_ty, llvm_ptr_ty,
|
||||
llvm_i64_ty, llvm_v256i32_ty], []>;
|
||||
}
|
||||
|
@ -65,6 +65,7 @@ public:
|
||||
LabelTyID, ///< Labels
|
||||
MetadataTyID, ///< Metadata
|
||||
X86_MMXTyID, ///< MMX vectors (64 bits, X86 specific)
|
||||
X86_AMXTyID, ///< AMX vectors (8192 bits, X86 specific)
|
||||
TokenTyID, ///< Tokens
|
||||
|
||||
// Derived types... see DerivedTypes.h file.
|
||||
@ -182,6 +183,9 @@ public:
|
||||
/// Return true if this is X86 MMX.
|
||||
bool isX86_MMXTy() const { return getTypeID() == X86_MMXTyID; }
|
||||
|
||||
/// Return true if this is X86 AMX.
|
||||
bool isX86_AMXTy() const { return getTypeID() == X86_AMXTyID; }
|
||||
|
||||
/// Return true if this is a FP type or a vector of FP.
|
||||
bool isFPOrFPVectorTy() const { return getScalarType()->isFloatingPointTy(); }
|
||||
|
||||
@ -252,7 +256,7 @@ public:
|
||||
/// includes all first-class types except struct and array types.
|
||||
bool isSingleValueType() const {
|
||||
return isFloatingPointTy() || isX86_MMXTy() || isIntegerTy() ||
|
||||
isPointerTy() || isVectorTy();
|
||||
isPointerTy() || isVectorTy() || isX86_AMXTy();
|
||||
}
|
||||
|
||||
/// Return true if the type is an aggregate type. This means it is valid as
|
||||
@ -268,8 +272,8 @@ public:
|
||||
bool isSized(SmallPtrSetImpl<Type*> *Visited = nullptr) const {
|
||||
// If it's a primitive, it is always sized.
|
||||
if (getTypeID() == IntegerTyID || isFloatingPointTy() ||
|
||||
getTypeID() == PointerTyID ||
|
||||
getTypeID() == X86_MMXTyID)
|
||||
getTypeID() == PointerTyID || getTypeID() == X86_MMXTyID ||
|
||||
getTypeID() == X86_AMXTyID)
|
||||
return true;
|
||||
// If it is not something that can have a size (e.g. a function or label),
|
||||
// it doesn't have a size.
|
||||
@ -405,6 +409,7 @@ public:
|
||||
static Type *getFP128Ty(LLVMContext &C);
|
||||
static Type *getPPC_FP128Ty(LLVMContext &C);
|
||||
static Type *getX86_MMXTy(LLVMContext &C);
|
||||
static Type *getX86_AMXTy(LLVMContext &C);
|
||||
static Type *getTokenTy(LLVMContext &C);
|
||||
static IntegerType *getIntNTy(LLVMContext &C, unsigned N);
|
||||
static IntegerType *getInt1Ty(LLVMContext &C);
|
||||
@ -460,6 +465,7 @@ public:
|
||||
static PointerType *getFP128PtrTy(LLVMContext &C, unsigned AS = 0);
|
||||
static PointerType *getPPC_FP128PtrTy(LLVMContext &C, unsigned AS = 0);
|
||||
static PointerType *getX86_MMXPtrTy(LLVMContext &C, unsigned AS = 0);
|
||||
static PointerType *getX86_AMXPtrTy(LLVMContext &C, unsigned AS = 0);
|
||||
static PointerType *getIntNPtrTy(LLVMContext &C, unsigned N, unsigned AS = 0);
|
||||
static PointerType *getInt1PtrTy(LLVMContext &C, unsigned AS = 0);
|
||||
static PointerType *getInt8PtrTy(LLVMContext &C, unsigned AS = 0);
|
||||
|
@ -247,9 +247,10 @@ namespace llvm {
|
||||
exnref = 161, // WebAssembly's exnref type
|
||||
funcref = 162, // WebAssembly's funcref type
|
||||
externref = 163, // WebAssembly's externref type
|
||||
x86amx = 164, // This is an X86 AMX value
|
||||
|
||||
FIRST_VALUETYPE = 1, // This is always the beginning of the list.
|
||||
LAST_VALUETYPE = 164, // This always remains at the end of the list.
|
||||
LAST_VALUETYPE = 165, // This always remains at the end of the list.
|
||||
|
||||
// This is the current maximum for LAST_VALUETYPE.
|
||||
// MVT::MAX_ALLOWED_VALUETYPE is used for asserts and to size bit vectors
|
||||
@ -966,6 +967,7 @@ namespace llvm {
|
||||
case v256i32:
|
||||
case v128i64:
|
||||
case v256f32:
|
||||
case x86amx:
|
||||
case v128f64: return TypeSize::Fixed(8192);
|
||||
case v512i32:
|
||||
case v256i64:
|
||||
|
@ -105,9 +105,9 @@ Constant *FoldBitCast(Constant *C, Type *DestTy, const DataLayout &DL) {
|
||||
"Invalid constantexpr bitcast!");
|
||||
|
||||
// Catch the obvious splat cases.
|
||||
if (C->isNullValue() && !DestTy->isX86_MMXTy())
|
||||
if (C->isNullValue() && !DestTy->isX86_MMXTy() && !DestTy->isX86_AMXTy())
|
||||
return Constant::getNullValue(DestTy);
|
||||
if (C->isAllOnesValue() && !DestTy->isX86_MMXTy() &&
|
||||
if (C->isAllOnesValue() && !DestTy->isX86_MMXTy() && !DestTy->isX86_AMXTy() &&
|
||||
!DestTy->isPtrOrPtrVectorTy()) // Don't get ones for ptr types!
|
||||
return Constant::getAllOnesValue(DestTy);
|
||||
|
||||
@ -358,12 +358,13 @@ Constant *llvm::ConstantFoldLoadThroughBitcast(Constant *C, Type *DestTy,
|
||||
|
||||
// Catch the obvious splat cases (since all-zeros can coerce non-integral
|
||||
// pointers legally).
|
||||
if (C->isNullValue() && !DestTy->isX86_MMXTy())
|
||||
if (C->isNullValue() && !DestTy->isX86_MMXTy() && !DestTy->isX86_AMXTy())
|
||||
return Constant::getNullValue(DestTy);
|
||||
if (C->isAllOnesValue() &&
|
||||
(DestTy->isIntegerTy() || DestTy->isFloatingPointTy() ||
|
||||
DestTy->isVectorTy()) &&
|
||||
!DestTy->isX86_MMXTy() && !DestTy->isPtrOrPtrVectorTy())
|
||||
!DestTy->isX86_AMXTy() && !DestTy->isX86_MMXTy() &&
|
||||
!DestTy->isPtrOrPtrVectorTy())
|
||||
// Get ones when the input is trivial, but
|
||||
// only for supported types inside getAllOnesValue.
|
||||
return Constant::getAllOnesValue(DestTy);
|
||||
@ -575,14 +576,16 @@ Constant *FoldReinterpretLoadFromConstPtr(Constant *C, Type *LoadTy,
|
||||
|
||||
C = FoldBitCast(C, MapTy->getPointerTo(AS), DL);
|
||||
if (Constant *Res = FoldReinterpretLoadFromConstPtr(C, MapTy, DL)) {
|
||||
if (Res->isNullValue() && !LoadTy->isX86_MMXTy())
|
||||
if (Res->isNullValue() && !LoadTy->isX86_MMXTy() &&
|
||||
!LoadTy->isX86_AMXTy())
|
||||
// Materializing a zero can be done trivially without a bitcast
|
||||
return Constant::getNullValue(LoadTy);
|
||||
Type *CastTy = LoadTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(LoadTy) : LoadTy;
|
||||
Res = FoldBitCast(Res, CastTy, DL);
|
||||
if (LoadTy->isPtrOrPtrVectorTy()) {
|
||||
// For vector of pointer, we needed to first convert to a vector of integer, then do vector inttoptr
|
||||
if (Res->isNullValue() && !LoadTy->isX86_MMXTy())
|
||||
if (Res->isNullValue() && !LoadTy->isX86_MMXTy() &&
|
||||
!LoadTy->isX86_AMXTy())
|
||||
return Constant::getNullValue(LoadTy);
|
||||
if (DL.isNonIntegralPointerType(LoadTy->getScalarType()))
|
||||
// Be careful not to replace a load of an addrspace value with an inttoptr here
|
||||
|
@ -840,6 +840,7 @@ lltok::Kind LLLexer::LexIdentifier() {
|
||||
TYPEKEYWORD("label", Type::getLabelTy(Context));
|
||||
TYPEKEYWORD("metadata", Type::getMetadataTy(Context));
|
||||
TYPEKEYWORD("x86_mmx", Type::getX86_MMXTy(Context));
|
||||
TYPEKEYWORD("x86_amx", Type::getX86_AMXTy(Context));
|
||||
TYPEKEYWORD("token", Type::getTokenTy(Context));
|
||||
|
||||
#undef TYPEKEYWORD
|
||||
|
@ -1763,6 +1763,9 @@ Error BitcodeReader::parseTypeTableBody() {
|
||||
case bitc::TYPE_CODE_X86_MMX: // X86_MMX
|
||||
ResultTy = Type::getX86_MMXTy(Context);
|
||||
break;
|
||||
case bitc::TYPE_CODE_X86_AMX: // X86_AMX
|
||||
ResultTy = Type::getX86_AMXTy(Context);
|
||||
break;
|
||||
case bitc::TYPE_CODE_TOKEN: // TOKEN
|
||||
ResultTy = Type::getTokenTy(Context);
|
||||
break;
|
||||
|
@ -913,6 +913,7 @@ void ModuleBitcodeWriter::writeTypeTable() {
|
||||
case Type::LabelTyID: Code = bitc::TYPE_CODE_LABEL; break;
|
||||
case Type::MetadataTyID: Code = bitc::TYPE_CODE_METADATA; break;
|
||||
case Type::X86_MMXTyID: Code = bitc::TYPE_CODE_X86_MMX; break;
|
||||
case Type::X86_AMXTyID: Code = bitc::TYPE_CODE_X86_AMX; break;
|
||||
case Type::TokenTyID: Code = bitc::TYPE_CODE_TOKEN; break;
|
||||
case Type::IntegerTyID:
|
||||
// INTEGER: [width]
|
||||
|
@ -164,6 +164,7 @@ std::string EVT::getEVTString() const {
|
||||
case MVT::Other: return "ch";
|
||||
case MVT::Glue: return "glue";
|
||||
case MVT::x86mmx: return "x86mmx";
|
||||
case MVT::x86amx: return "x86amx";
|
||||
case MVT::Metadata: return "Metadata";
|
||||
case MVT::Untyped: return "Untyped";
|
||||
case MVT::exnref: return "exnref";
|
||||
@ -195,6 +196,7 @@ Type *EVT::getTypeForEVT(LLVMContext &Context) const {
|
||||
case MVT::f128: return Type::getFP128Ty(Context);
|
||||
case MVT::ppcf128: return Type::getPPC_FP128Ty(Context);
|
||||
case MVT::x86mmx: return Type::getX86_MMXTy(Context);
|
||||
case MVT::x86amx: return Type::getX86_AMXTy(Context);
|
||||
case MVT::v1i1:
|
||||
return FixedVectorType::get(Type::getInt1Ty(Context), 1);
|
||||
case MVT::v2i1:
|
||||
@ -501,6 +503,7 @@ MVT MVT::getVT(Type *Ty, bool HandleUnknown){
|
||||
case Type::DoubleTyID: return MVT(MVT::f64);
|
||||
case Type::X86_FP80TyID: return MVT(MVT::f80);
|
||||
case Type::X86_MMXTyID: return MVT(MVT::x86mmx);
|
||||
case Type::X86_AMXTyID: return MVT(MVT::x86amx);
|
||||
case Type::FP128TyID: return MVT(MVT::f128);
|
||||
case Type::PPC_FP128TyID: return MVT(MVT::ppcf128);
|
||||
case Type::PointerTyID: return MVT(MVT::iPTR);
|
||||
|
@ -609,6 +609,7 @@ void TypePrinting::print(Type *Ty, raw_ostream &OS) {
|
||||
case Type::LabelTyID: OS << "label"; return;
|
||||
case Type::MetadataTyID: OS << "metadata"; return;
|
||||
case Type::X86_MMXTyID: OS << "x86_mmx"; return;
|
||||
case Type::X86_AMXTyID: OS << "x86_amx"; return;
|
||||
case Type::TokenTyID: OS << "token"; return;
|
||||
case Type::IntegerTyID:
|
||||
OS << 'i' << cast<IntegerType>(Ty)->getBitWidth();
|
||||
|
@ -535,7 +535,7 @@ Constant *llvm::ConstantFoldCastInstruction(unsigned opc, Constant *V,
|
||||
return UndefValue::get(DestTy);
|
||||
}
|
||||
|
||||
if (V->isNullValue() && !DestTy->isX86_MMXTy() &&
|
||||
if (V->isNullValue() && !DestTy->isX86_MMXTy() && !DestTy->isX86_AMXTy() &&
|
||||
opc != Instruction::AddrSpaceCast)
|
||||
return Constant::getNullValue(DestTy);
|
||||
|
||||
|
@ -512,6 +512,8 @@ LLVMTypeKind LLVMGetTypeKind(LLVMTypeRef Ty) {
|
||||
return LLVMVectorTypeKind;
|
||||
case Type::X86_MMXTyID:
|
||||
return LLVMX86_MMXTypeKind;
|
||||
case Type::X86_AMXTyID:
|
||||
return LLVMX86_AMXTypeKind;
|
||||
case Type::TokenTyID:
|
||||
return LLVMTokenTypeKind;
|
||||
case Type::ScalableVectorTyID:
|
||||
@ -623,6 +625,9 @@ LLVMTypeRef LLVMPPCFP128TypeInContext(LLVMContextRef C) {
|
||||
LLVMTypeRef LLVMX86MMXTypeInContext(LLVMContextRef C) {
|
||||
return (LLVMTypeRef) Type::getX86_MMXTy(*unwrap(C));
|
||||
}
|
||||
LLVMTypeRef LLVMX86AMXTypeInContext(LLVMContextRef C) {
|
||||
return (LLVMTypeRef) Type::getX86_AMXTy(*unwrap(C));
|
||||
}
|
||||
|
||||
LLVMTypeRef LLVMHalfType(void) {
|
||||
return LLVMHalfTypeInContext(LLVMGetGlobalContext());
|
||||
@ -648,6 +653,9 @@ LLVMTypeRef LLVMPPCFP128Type(void) {
|
||||
LLVMTypeRef LLVMX86MMXType(void) {
|
||||
return LLVMX86MMXTypeInContext(LLVMGetGlobalContext());
|
||||
}
|
||||
LLVMTypeRef LLVMX86AMXType(void) {
|
||||
return LLVMX86AMXTypeInContext(LLVMGetGlobalContext());
|
||||
}
|
||||
|
||||
/*--.. Operations on function types ........................................--*/
|
||||
|
||||
|
@ -810,6 +810,8 @@ Align DataLayout::getAlignment(Type *Ty, bool abi_or_pref) const {
|
||||
Alignment = PowerOf2Ceil(Alignment);
|
||||
return Align(Alignment);
|
||||
}
|
||||
case Type::X86_AMXTyID:
|
||||
return Align(64);
|
||||
default:
|
||||
llvm_unreachable("Bad type for getAlignment!!!");
|
||||
}
|
||||
|
@ -764,6 +764,7 @@ static std::string getMangledTypeStr(Type* Ty) {
|
||||
case Type::FP128TyID: Result += "f128"; break;
|
||||
case Type::PPC_FP128TyID: Result += "ppcf128"; break;
|
||||
case Type::X86_MMXTyID: Result += "x86mmx"; break;
|
||||
case Type::X86_AMXTyID: Result += "x86amx"; break;
|
||||
case Type::IntegerTyID:
|
||||
Result += "i" + utostr(cast<IntegerType>(Ty)->getBitWidth());
|
||||
break;
|
||||
@ -848,7 +849,8 @@ enum IIT_Info {
|
||||
IIT_V128 = 47,
|
||||
IIT_BF16 = 48,
|
||||
IIT_STRUCT9 = 49,
|
||||
IIT_V256 = 50
|
||||
IIT_V256 = 50,
|
||||
IIT_AMX = 51
|
||||
};
|
||||
|
||||
static void DecodeIITType(unsigned &NextElt, ArrayRef<unsigned char> Infos,
|
||||
@ -871,6 +873,9 @@ static void DecodeIITType(unsigned &NextElt, ArrayRef<unsigned char> Infos,
|
||||
case IIT_MMX:
|
||||
OutputTable.push_back(IITDescriptor::get(IITDescriptor::MMX, 0));
|
||||
return;
|
||||
case IIT_AMX:
|
||||
OutputTable.push_back(IITDescriptor::get(IITDescriptor::AMX, 0));
|
||||
return;
|
||||
case IIT_TOKEN:
|
||||
OutputTable.push_back(IITDescriptor::get(IITDescriptor::Token, 0));
|
||||
return;
|
||||
@ -1108,6 +1113,7 @@ static Type *DecodeFixedType(ArrayRef<Intrinsic::IITDescriptor> &Infos,
|
||||
case IITDescriptor::Void: return Type::getVoidTy(Context);
|
||||
case IITDescriptor::VarArg: return Type::getVoidTy(Context);
|
||||
case IITDescriptor::MMX: return Type::getX86_MMXTy(Context);
|
||||
case IITDescriptor::AMX: return Type::getX86_AMXTy(Context);
|
||||
case IITDescriptor::Token: return Type::getTokenTy(Context);
|
||||
case IITDescriptor::Metadata: return Type::getMetadataTy(Context);
|
||||
case IITDescriptor::Half: return Type::getHalfTy(Context);
|
||||
@ -1287,6 +1293,7 @@ static bool matchIntrinsicType(
|
||||
case IITDescriptor::Void: return !Ty->isVoidTy();
|
||||
case IITDescriptor::VarArg: return true;
|
||||
case IITDescriptor::MMX: return !Ty->isX86_MMXTy();
|
||||
case IITDescriptor::AMX: return !Ty->isX86_AMXTy();
|
||||
case IITDescriptor::Token: return !Ty->isTokenTy();
|
||||
case IITDescriptor::Metadata: return !Ty->isMetadataTy();
|
||||
case IITDescriptor::Half: return !Ty->isHalfTy();
|
||||
|
@ -35,6 +35,7 @@ LLVMContextImpl::LLVMContextImpl(LLVMContext &C)
|
||||
FP128Ty(C, Type::FP128TyID),
|
||||
PPC_FP128Ty(C, Type::PPC_FP128TyID),
|
||||
X86_MMXTy(C, Type::X86_MMXTyID),
|
||||
X86_AMXTy(C, Type::X86_AMXTyID),
|
||||
Int1Ty(C, 1),
|
||||
Int8Ty(C, 8),
|
||||
Int16Ty(C, 16),
|
||||
|
@ -1418,7 +1418,7 @@ public:
|
||||
// Basic type instances.
|
||||
Type VoidTy, LabelTy, HalfTy, BFloatTy, FloatTy, DoubleTy, MetadataTy,
|
||||
TokenTy;
|
||||
Type X86_FP80Ty, FP128Ty, PPC_FP128Ty, X86_MMXTy;
|
||||
Type X86_FP80Ty, FP128Ty, PPC_FP128Ty, X86_MMXTy, X86_AMXTy;
|
||||
IntegerType Int1Ty, Int8Ty, Int16Ty, Int32Ty, Int64Ty, Int128Ty;
|
||||
|
||||
BumpPtrAllocator Alloc;
|
||||
|
@ -49,6 +49,7 @@ Type *Type::getPrimitiveType(LLVMContext &C, TypeID IDNumber) {
|
||||
case LabelTyID : return getLabelTy(C);
|
||||
case MetadataTyID : return getMetadataTy(C);
|
||||
case X86_MMXTyID : return getX86_MMXTy(C);
|
||||
case X86_AMXTyID : return getX86_AMXTy(C);
|
||||
case TokenTyID : return getTokenTy(C);
|
||||
default:
|
||||
return nullptr;
|
||||
@ -81,6 +82,14 @@ bool Type::canLosslesslyBitCastTo(Type *Ty) const {
|
||||
Ty->getPrimitiveSizeInBits().getFixedSize() == 64)
|
||||
return true;
|
||||
|
||||
// 8192-bit fixed width vector types can be losslessly converted to x86amx.
|
||||
if (((isa<FixedVectorType>(this)) && Ty->isX86_AMXTy()) &&
|
||||
getPrimitiveSizeInBits().getFixedSize() == 8192)
|
||||
return true;
|
||||
if ((isX86_AMXTy() && isa<FixedVectorType>(Ty)) &&
|
||||
Ty->getPrimitiveSizeInBits().getFixedSize() == 8192)
|
||||
return true;
|
||||
|
||||
// At this point we have only various mismatches of the first class types
|
||||
// remaining and ptr->ptr. Just select the lossless conversions. Everything
|
||||
// else is not lossless. Conservatively assume we can't losslessly convert
|
||||
@ -120,6 +129,7 @@ TypeSize Type::getPrimitiveSizeInBits() const {
|
||||
case Type::FP128TyID: return TypeSize::Fixed(128);
|
||||
case Type::PPC_FP128TyID: return TypeSize::Fixed(128);
|
||||
case Type::X86_MMXTyID: return TypeSize::Fixed(64);
|
||||
case Type::X86_AMXTyID: return TypeSize::Fixed(8192);
|
||||
case Type::IntegerTyID:
|
||||
return TypeSize::Fixed(cast<IntegerType>(this)->getBitWidth());
|
||||
case Type::FixedVectorTyID:
|
||||
@ -179,6 +189,7 @@ Type *Type::getX86_FP80Ty(LLVMContext &C) { return &C.pImpl->X86_FP80Ty; }
|
||||
Type *Type::getFP128Ty(LLVMContext &C) { return &C.pImpl->FP128Ty; }
|
||||
Type *Type::getPPC_FP128Ty(LLVMContext &C) { return &C.pImpl->PPC_FP128Ty; }
|
||||
Type *Type::getX86_MMXTy(LLVMContext &C) { return &C.pImpl->X86_MMXTy; }
|
||||
Type *Type::getX86_AMXTy(LLVMContext &C) { return &C.pImpl->X86_AMXTy; }
|
||||
|
||||
IntegerType *Type::getInt1Ty(LLVMContext &C) { return &C.pImpl->Int1Ty; }
|
||||
IntegerType *Type::getInt8Ty(LLVMContext &C) { return &C.pImpl->Int8Ty; }
|
||||
@ -223,6 +234,10 @@ PointerType *Type::getX86_MMXPtrTy(LLVMContext &C, unsigned AS) {
|
||||
return getX86_MMXTy(C)->getPointerTo(AS);
|
||||
}
|
||||
|
||||
PointerType *Type::getX86_AMXPtrTy(LLVMContext &C, unsigned AS) {
|
||||
return getX86_AMXTy(C)->getPointerTo(AS);
|
||||
}
|
||||
|
||||
PointerType *Type::getIntNPtrTy(LLVMContext &C, unsigned N, unsigned AS) {
|
||||
return getIntNTy(C, N)->getPointerTo(AS);
|
||||
}
|
||||
|
@ -4618,7 +4618,7 @@ void X86DAGToDAGISel::Select(SDNode *Node) {
|
||||
Segment,
|
||||
CFG,
|
||||
Chain};
|
||||
CNode = CurDAG->getMachineNode(Opc, dl, {MVT::v256i32, MVT::Other}, Ops);
|
||||
CNode = CurDAG->getMachineNode(Opc, dl, {MVT::x86amx, MVT::Other}, Ops);
|
||||
ReplaceNode(Node, CNode);
|
||||
return;
|
||||
}
|
||||
@ -4637,7 +4637,7 @@ void X86DAGToDAGISel::Select(SDNode *Node) {
|
||||
CFG,
|
||||
Chain};
|
||||
MachineSDNode *CNode =
|
||||
CurDAG->getMachineNode(Opc, dl, {MVT::v256i32, MVT::Other}, Ops);
|
||||
CurDAG->getMachineNode(Opc, dl, {MVT::x86amx, MVT::Other}, Ops);
|
||||
ReplaceNode(Node, CNode);
|
||||
return;
|
||||
}
|
||||
|
@ -1898,7 +1898,7 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
|
||||
}
|
||||
|
||||
if (Subtarget.hasAMXTILE()) {
|
||||
addRegisterClass(MVT::v256i32, &X86::TILERegClass);
|
||||
addRegisterClass(MVT::x86amx, &X86::TILERegClass);
|
||||
}
|
||||
|
||||
// We want to custom lower some of our intrinsics.
|
||||
@ -5346,11 +5346,6 @@ bool X86TargetLowering::canMergeStoresTo(unsigned AddressSpace, EVT MemVT,
|
||||
if (MemVT.getSizeInBits() > Subtarget.getPreferVectorWidth())
|
||||
return false;
|
||||
|
||||
// Don't merge to x86 amx tile, as we only map MVT::v256i32
|
||||
// to x86 amx tile on amx intrinsics.
|
||||
if (MemVT == MVT::v256i32)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -6,20 +6,20 @@
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
/// \file Pass to transform <256 x i32>
|
||||
/// <256 x i32> is mapped to AMX tile register on X86, AMX instruction set only
|
||||
/// provides simple operation on tile register. The basic elementwise operation
|
||||
/// is not supported by AMX. Since we define the AMX tile as vector <256 x i32>
|
||||
/// \file Pass to transform <256 x i32> load/store
|
||||
/// <256 x i32> is bitcasted to x86_amx on X86, and AMX instruction set only
|
||||
/// provides simple operation on x86_amx. The basic elementwise operation
|
||||
/// is not supported by AMX. Since x86_amx is bitcasted from vector <256 x i32>
|
||||
/// and only AMX intrinsics can operate on the type, we need transform
|
||||
/// load/store <256 x i32> instruction to AMX load/store. Besides, we split
|
||||
/// <256 x i32> to 2 <128 x i32> if the vector is not used or defined by AMX
|
||||
/// intrinsics, so that in instruction selection it can be lowered to proper
|
||||
/// size which HW can support.
|
||||
/// load/store <256 x i32> instruction to AMX load/store. If the bitcast can
|
||||
/// not be combined with load/store, we transform the bitcast to amx load/store
|
||||
/// and <256 x i32> store/load.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
#include "X86.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/PostOrderIterator.h"
|
||||
#include "llvm/ADT/SmallSet.h"
|
||||
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
|
||||
#include "llvm/Analysis/TargetTransformInfo.h"
|
||||
#include "llvm/CodeGen/Passes.h"
|
||||
@ -30,145 +30,44 @@
|
||||
#include "llvm/IR/Instructions.h"
|
||||
#include "llvm/IR/IntrinsicInst.h"
|
||||
#include "llvm/IR/IntrinsicsX86.h"
|
||||
#include "llvm/IR/PatternMatch.h"
|
||||
#include "llvm/InitializePasses.h"
|
||||
#include "llvm/Pass.h"
|
||||
|
||||
using namespace llvm;
|
||||
using namespace PatternMatch;
|
||||
|
||||
#define DEBUG_TYPE "lower-amx-type"
|
||||
|
||||
namespace {
|
||||
class X86LowerAMXType {
|
||||
Function &Func;
|
||||
const DataLayout &DL;
|
||||
DenseSet<Instruction *> LDSet;
|
||||
DenseSet<Instruction *> STSet;
|
||||
DenseMap<Value *, std::pair<LoadInst *, LoadInst *>> LoadMap;
|
||||
static AllocaInst *CreateAllocaInst(IRBuilder<> &Builder, BasicBlock *BB) {
|
||||
Function &F = *BB->getParent();
|
||||
Module *M = BB->getModule();
|
||||
const DataLayout &DL = M->getDataLayout();
|
||||
|
||||
public:
|
||||
X86LowerAMXType(Function &F) : Func(F), DL(F.getParent()->getDataLayout()) {}
|
||||
bool visit();
|
||||
bool visitLD();
|
||||
bool visitST();
|
||||
void splitST(Instruction *Inst);
|
||||
void splitLD(Instruction *Inst);
|
||||
};
|
||||
|
||||
// Split v256i32 load/store to 2 v128i32, so that ISel can
|
||||
// lower it to proper vector size.
|
||||
void X86LowerAMXType::splitST(Instruction *Inst) {
|
||||
StoreInst *ST = dyn_cast<StoreInst>(Inst);
|
||||
IRBuilder<> Builder(ST);
|
||||
Type *V256I32Ty = VectorType::get(Builder.getInt32Ty(), 256, false);
|
||||
LLVMContext &Ctx = Builder.getContext();
|
||||
Type *Ty = ST->getValueOperand()->getType();
|
||||
EVT VT = EVT::getEVT(Ty);
|
||||
EVT HalfVT = VT.getHalfNumVectorElementsVT(Ctx);
|
||||
Type *HalfTy = HalfVT.getTypeForEVT(Ctx);
|
||||
|
||||
LoadInst *Lo, *Hi;
|
||||
std::tie(Lo, Hi) = LoadMap[ST->getValueOperand()];
|
||||
Value *Ptr = ST->getPointerOperand();
|
||||
PointerType *HalfPtrTy = HalfTy->getPointerTo(ST->getPointerAddressSpace());
|
||||
Value *HalfPtr = Builder.CreateBitCast(Ptr, HalfPtrTy);
|
||||
// The HW require the alignment for AMX tile is 64, but front-end generate
|
||||
// code for the vector alignment which is the vector size.
|
||||
uint64_t HalfTySize = HalfTy->getPrimitiveSizeInBits().getFixedSize() / 8;
|
||||
Align Alignment = std::min(Lo->getAlign(), Align(HalfTySize));
|
||||
Builder.CreateAlignedStore(Lo, HalfPtr, Alignment, ST->isVolatile());
|
||||
|
||||
HalfPtr = Builder.CreateGEP(HalfTy, HalfPtr, Builder.getInt32(1));
|
||||
Builder.CreateAlignedStore(Hi, HalfPtr, Alignment, ST->isVolatile());
|
||||
auto AllocaAlignment = DL.getPrefTypeAlign(Type::getX86_AMXTy(Ctx));
|
||||
unsigned AllocaAS = DL.getAllocaAddrSpace();
|
||||
AllocaInst *AllocaRes =
|
||||
new AllocaInst(V256I32Ty, AllocaAS, "", &F.getEntryBlock().front());
|
||||
AllocaRes->setAlignment(AllocaAlignment);
|
||||
return AllocaRes;
|
||||
}
|
||||
|
||||
bool X86LowerAMXType::visitST() {
|
||||
if (STSet.empty())
|
||||
return false;
|
||||
for (auto *Inst : STSet) {
|
||||
Value *Row, *Col;
|
||||
const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst->getOperand(0));
|
||||
if (!II)
|
||||
Row = Col = nullptr;
|
||||
else {
|
||||
static std::pair<Value *, Value *> getShape(IntrinsicInst *II, unsigned OpNo) {
|
||||
Value *Row = nullptr, *Col = nullptr;
|
||||
switch (II->getIntrinsicID()) {
|
||||
default:
|
||||
Row = Col = nullptr;
|
||||
break;
|
||||
llvm_unreachable("Expect amx intrinsics");
|
||||
case Intrinsic::x86_tileloadd64_internal:
|
||||
case Intrinsic::x86_tdpbssd_internal: {
|
||||
case Intrinsic::x86_tilestored64_internal: {
|
||||
Row = II->getArgOperand(0);
|
||||
Col = II->getArgOperand(1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!Row) {
|
||||
splitST(Inst);
|
||||
continue;
|
||||
}
|
||||
IRBuilder<> Builder(Inst);
|
||||
LLVMContext &Ctx = Builder.getContext();
|
||||
// Use the maximun column as stride. It must be the same with load stride.
|
||||
Value *Stride = Builder.getInt64(64);
|
||||
Value *I8Ptr =
|
||||
Builder.CreateBitCast(Inst->getOperand(1), Type::getInt8PtrTy(Ctx));
|
||||
std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride,
|
||||
Inst->getOperand(0)};
|
||||
|
||||
Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void X86LowerAMXType::splitLD(Instruction *Inst) {
|
||||
LoadInst *LD = dyn_cast<LoadInst>(Inst);
|
||||
IRBuilder<> Builder(LD);
|
||||
LLVMContext &Ctx = Builder.getContext();
|
||||
Type *Ty = LD->getType();
|
||||
EVT VT = EVT::getEVT(Ty);
|
||||
EVT HalfVT = VT.getHalfNumVectorElementsVT(Ctx);
|
||||
Type *HalfTy = HalfVT.getTypeForEVT(Ctx);
|
||||
|
||||
Value *Ptr = LD->getPointerOperand();
|
||||
PointerType *HalfPtrTy = HalfTy->getPointerTo(LD->getPointerAddressSpace());
|
||||
Value *HalfPtr = Builder.CreateBitCast(Ptr, HalfPtrTy);
|
||||
// The HW require the alignment for AMX tile is 64, but front-end generate
|
||||
// code for the vector alignment which is the vector size.
|
||||
uint64_t HalfTySize = HalfTy->getPrimitiveSizeInBits().getFixedSize() / 8;
|
||||
Align Alignment = std::min(LD->getAlign(), Align(HalfTySize));
|
||||
auto *Lo =
|
||||
Builder.CreateAlignedLoad(HalfTy, HalfPtr, Alignment, LD->isVolatile());
|
||||
|
||||
HalfPtr = Builder.CreateGEP(HalfTy, HalfPtr, Builder.getInt32(1));
|
||||
auto *Hi =
|
||||
Builder.CreateAlignedLoad(HalfTy, HalfPtr, Alignment, LD->isVolatile());
|
||||
|
||||
LoadMap[Inst] = std::make_pair(Lo, Hi);
|
||||
}
|
||||
|
||||
bool X86LowerAMXType::visitLD() {
|
||||
if (LDSet.empty())
|
||||
return false;
|
||||
for (auto &Inst : LDSet) {
|
||||
int Count = 0;
|
||||
Value *NewInst = nullptr;
|
||||
// The user should be all AMX intrinsics or all LLVM instruction.
|
||||
// Don't support it is used by both AMX intrinsics and LLVM instructions.
|
||||
for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) {
|
||||
Use &U = *I++;
|
||||
const IntrinsicInst *II = dyn_cast<IntrinsicInst>(U.getUser());
|
||||
if (!II) {
|
||||
Count++;
|
||||
continue;
|
||||
}
|
||||
if (NewInst)
|
||||
continue;
|
||||
Value *Row, *Col;
|
||||
switch (II->getIntrinsicID()) {
|
||||
default:
|
||||
report_fatal_error("Non-AMX intrinsic use tile type.");
|
||||
break;
|
||||
// a * b + c
|
||||
// The shape depends on which operand.
|
||||
case Intrinsic::x86_tdpbssd_internal: {
|
||||
unsigned OpNo = U.getOperandNo();
|
||||
switch (OpNo) {
|
||||
case 3:
|
||||
Row = II->getArgOperand(0);
|
||||
@ -185,76 +84,234 @@ bool X86LowerAMXType::visitLD() {
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Intrinsic::x86_tilestored64_internal: {
|
||||
Row = II->getArgOperand(0);
|
||||
Col = II->getArgOperand(1);
|
||||
break;
|
||||
}
|
||||
|
||||
return std::make_pair(Row, Col);
|
||||
}
|
||||
assert(Count == 0 && "Can NOT mix amx intrinsic and LLVM instruction");
|
||||
// FIXME: The shape def should be ahead of load.
|
||||
IRBuilder<> Builder(Inst);
|
||||
LLVMContext &Ctx = Builder.getContext();
|
||||
|
||||
// %src = load <256 x i32>, <256 x i32>* %addr, align 64
|
||||
// %2 = bitcast <256 x i32> %src to x86_amx
|
||||
// -->
|
||||
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
|
||||
// i8* %addr, i64 %stride64)
|
||||
static void combineLoadBitcast(LoadInst *LD, BitCastInst *Bitcast) {
|
||||
Value *Row = nullptr, *Col = nullptr;
|
||||
Use &U = *(Bitcast->use_begin());
|
||||
unsigned OpNo = U.getOperandNo();
|
||||
auto *II = cast<IntrinsicInst>(U.getUser());
|
||||
std::tie(Row, Col) = getShape(II, OpNo);
|
||||
IRBuilder<> Builder(Bitcast);
|
||||
// Use the maximun column as stride.
|
||||
Value *Stride = Builder.getInt64(64);
|
||||
Value *I8Ptr =
|
||||
Builder.CreateBitCast(Inst->getOperand(0), Type::getInt8PtrTy(Ctx));
|
||||
Builder.CreateBitCast(LD->getOperand(0), Builder.getInt8PtrTy());
|
||||
std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
|
||||
|
||||
NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
|
||||
None, Args);
|
||||
|
||||
Inst->replaceAllUsesWith(NewInst);
|
||||
}
|
||||
if (!NewInst)
|
||||
splitLD(Inst);
|
||||
}
|
||||
return true;
|
||||
Value *NewInst =
|
||||
Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal, None, Args);
|
||||
Bitcast->replaceAllUsesWith(NewInst);
|
||||
}
|
||||
|
||||
bool X86LowerAMXType::visit() {
|
||||
bool C;
|
||||
auto IsAMXType = [](FixedVectorType *VTy) {
|
||||
if (!VTy)
|
||||
return false;
|
||||
if (!VTy->getScalarType()->isIntegerTy(32))
|
||||
return false;
|
||||
if (VTy->getNumElements() != 256)
|
||||
return false;
|
||||
// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
|
||||
// %stride);
|
||||
// %13 = bitcast x86_amx %src to <256 x i32>
|
||||
// store <256 x i32> %13, <256 x i32>* %addr, align 64
|
||||
// -->
|
||||
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
|
||||
// %stride64, %13)
|
||||
static void combineBitcastStore(BitCastInst *Bitcast, StoreInst *ST) {
|
||||
|
||||
return true;
|
||||
Value *Tile = Bitcast->getOperand(0);
|
||||
auto *II = cast<IntrinsicInst>(Tile);
|
||||
// Tile is output from AMX intrinsic. The first operand of the
|
||||
// intrinsic is row, the second operand of the intrinsic is column.
|
||||
Value *Row = II->getOperand(0);
|
||||
Value *Col = II->getOperand(1);
|
||||
IRBuilder<> Builder(ST);
|
||||
// Use the maximum column as stride. It must be the same with load
|
||||
// stride.
|
||||
Value *Stride = Builder.getInt64(64);
|
||||
Value *I8Ptr =
|
||||
Builder.CreateBitCast(ST->getOperand(1), Builder.getInt8PtrTy());
|
||||
std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Tile};
|
||||
Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
|
||||
if (Bitcast->hasOneUse())
|
||||
return;
|
||||
// %13 = bitcast x86_amx %src to <256 x i32>
|
||||
// store <256 x i32> %13, <256 x i32>* %addr, align 64
|
||||
// %add = <256 x i32> %13, <256 x i32> %src2
|
||||
// -->
|
||||
// %13 = bitcast x86_amx %src to <256 x i32>
|
||||
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
|
||||
// %stride64, %13)
|
||||
// %14 = load <256 x i32>, %addr
|
||||
// %add = <256 x i32> %14, <256 x i32> %src2
|
||||
Value *Vec = Builder.CreateLoad(Bitcast->getType(), ST->getOperand(1));
|
||||
Bitcast->replaceAllUsesWith(Vec);
|
||||
}
|
||||
|
||||
// transform bitcast to <store, load> instructions.
|
||||
static bool transformBitcast(BitCastInst *Bitcast) {
|
||||
IRBuilder<> Builder(Bitcast);
|
||||
AllocaInst *AllocaAddr;
|
||||
Value *I8Ptr, *Stride;
|
||||
auto *Src = Bitcast->getOperand(0);
|
||||
|
||||
auto Prepare = [&]() {
|
||||
AllocaAddr = CreateAllocaInst(Builder, Bitcast->getParent());
|
||||
I8Ptr = Builder.CreateBitCast(AllocaAddr, Builder.getInt8PtrTy());
|
||||
Stride = Builder.getInt64(64);
|
||||
};
|
||||
|
||||
for (BasicBlock &BB : Func) {
|
||||
for (Instruction &Inst : BB) {
|
||||
LoadInst *LD = dyn_cast<LoadInst>(&Inst);
|
||||
// Check load instruction.
|
||||
// %3 = load <256 x i32>, <256 x i32>* %1, align 64
|
||||
if (LD) {
|
||||
FixedVectorType *VTy = dyn_cast<FixedVectorType>(Inst.getType());
|
||||
if (!IsAMXType(VTy))
|
||||
if (Bitcast->getType()->isX86_AMXTy()) {
|
||||
// %2 = bitcast <256 x i32> %src to x86_amx
|
||||
// -->
|
||||
// %addr = alloca <256 x i32>, align 64
|
||||
// store <256 x i32> %src, <256 x i32>* %addr, align 64
|
||||
// %addr2 = bitcast <256 x i32>* to i8*
|
||||
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
|
||||
// i8* %addr2,
|
||||
// i64 64)
|
||||
Use &U = *(Bitcast->use_begin());
|
||||
unsigned OpNo = U.getOperandNo();
|
||||
auto *II = dyn_cast<IntrinsicInst>(U.getUser());
|
||||
if (!II)
|
||||
return false; // May be bitcast from x86amx to <256 x i32>.
|
||||
Prepare();
|
||||
Builder.CreateStore(Src, AllocaAddr);
|
||||
// TODO we can pick an constant operand for the shape.
|
||||
Value *Row = nullptr, *Col = nullptr;
|
||||
std::tie(Row, Col) = getShape(II, OpNo);
|
||||
std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
|
||||
Value *NewInst = Builder.CreateIntrinsic(
|
||||
Intrinsic::x86_tileloadd64_internal, None, Args);
|
||||
Bitcast->replaceAllUsesWith(NewInst);
|
||||
} else {
|
||||
// %2 = bitcast x86_amx %src to <256 x i32>
|
||||
// -->
|
||||
// %addr = alloca <256 x i32>, align 64
|
||||
// %addr2 = bitcast <256 x i32>* to i8*
|
||||
// call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col,
|
||||
// i8* %addr2, i64 %stride)
|
||||
// %2 = load <256 x i32>, <256 x i32>* %addr, align 64
|
||||
auto *II = dyn_cast<IntrinsicInst>(Src);
|
||||
if (!II)
|
||||
return false; // May be bitcast from <256 x i32> to x86amx.
|
||||
Prepare();
|
||||
Value *Row = II->getOperand(0);
|
||||
Value *Col = II->getOperand(1);
|
||||
std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride, Src};
|
||||
Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
|
||||
Value *NewInst = Builder.CreateLoad(Bitcast->getType(), AllocaAddr);
|
||||
Bitcast->replaceAllUsesWith(NewInst);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace {
|
||||
class X86LowerAMXType {
|
||||
Function &Func;
|
||||
|
||||
public:
|
||||
X86LowerAMXType(Function &F) : Func(F) {}
|
||||
bool visit();
|
||||
};
|
||||
|
||||
bool X86LowerAMXType::visit() {
|
||||
SmallVector<Instruction *, 8> DeadInsts;
|
||||
|
||||
for (BasicBlock *BB : post_order(&Func)) {
|
||||
for (BasicBlock::reverse_iterator II = BB->rbegin(), IE = BB->rend();
|
||||
II != IE;) {
|
||||
Instruction &Inst = *II++;
|
||||
auto *Bitcast = dyn_cast<BitCastInst>(&Inst);
|
||||
if (!Bitcast)
|
||||
continue;
|
||||
LDSet.insert(&Inst);
|
||||
|
||||
Value *Src = Bitcast->getOperand(0);
|
||||
if (Bitcast->getType()->isX86_AMXTy()) {
|
||||
if (Bitcast->user_empty()) {
|
||||
DeadInsts.push_back(Bitcast);
|
||||
continue;
|
||||
}
|
||||
// Check store instruction.
|
||||
// store <256 x i32> %3, <256 x i32>* %2, align 64
|
||||
StoreInst *ST = dyn_cast<StoreInst>(&Inst);
|
||||
if (!ST)
|
||||
LoadInst *LD = dyn_cast<LoadInst>(Src);
|
||||
if (!LD) {
|
||||
if (transformBitcast(Bitcast))
|
||||
DeadInsts.push_back(Bitcast);
|
||||
continue;
|
||||
FixedVectorType *VTy =
|
||||
dyn_cast<FixedVectorType>(ST->getOperand(0)->getType());
|
||||
if (!IsAMXType(VTy))
|
||||
}
|
||||
// If load has mutli-user, duplicate a vector load.
|
||||
// %src = load <256 x i32>, <256 x i32>* %addr, align 64
|
||||
// %2 = bitcast <256 x i32> %src to x86_amx
|
||||
// %add = add <256 x i32> %src, <256 x i32> %src2
|
||||
// -->
|
||||
// %src = load <256 x i32>, <256 x i32>* %addr, align 64
|
||||
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
|
||||
// i8* %addr, i64 %stride64)
|
||||
// %add = add <256 x i32> %src, <256 x i32> %src2
|
||||
|
||||
// If load has one user, the load will be eliminated in DAG ISel.
|
||||
// %src = load <256 x i32>, <256 x i32>* %addr, align 64
|
||||
// %2 = bitcast <256 x i32> %src to x86_amx
|
||||
// -->
|
||||
// %2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col,
|
||||
// i8* %addr, i64 %stride64)
|
||||
combineLoadBitcast(LD, Bitcast);
|
||||
DeadInsts.push_back(Bitcast);
|
||||
if (LD->hasOneUse())
|
||||
DeadInsts.push_back(LD);
|
||||
} else if (Src->getType()->isX86_AMXTy()) {
|
||||
if (Bitcast->user_empty()) {
|
||||
DeadInsts.push_back(Bitcast);
|
||||
continue;
|
||||
STSet.insert(&Inst);
|
||||
}
|
||||
StoreInst *ST = nullptr;
|
||||
for (auto UI = Bitcast->use_begin(), UE = Bitcast->use_end();
|
||||
UI != UE;) {
|
||||
Value *I = (UI++)->getUser();
|
||||
ST = dyn_cast<StoreInst>(I);
|
||||
if (ST)
|
||||
break;
|
||||
}
|
||||
if (!ST) {
|
||||
if (transformBitcast(Bitcast))
|
||||
DeadInsts.push_back(Bitcast);
|
||||
continue;
|
||||
}
|
||||
// If bitcast (%13) has one use, combine bitcast and store to amx store.
|
||||
// %src = call x86_amx @llvm.x86.tileloadd64.internal(%row, %col, %addr,
|
||||
// %stride);
|
||||
// %13 = bitcast x86_amx %src to <256 x i32>
|
||||
// store <256 x i32> %13, <256 x i32>* %addr, align 64
|
||||
// -->
|
||||
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
|
||||
// %stride64, %13)
|
||||
//
|
||||
// If bitcast (%13) has multi-use, transform as below.
|
||||
// %13 = bitcast x86_amx %src to <256 x i32>
|
||||
// store <256 x i32> %13, <256 x i32>* %addr, align 64
|
||||
// %add = <256 x i32> %13, <256 x i32> %src2
|
||||
// -->
|
||||
// %13 = bitcast x86_amx %src to <256 x i32>
|
||||
// call void @llvm.x86.tilestored64.internal(%row, %col, %addr,
|
||||
// %stride64, %13)
|
||||
// %14 = load <256 x i32>, %addr
|
||||
// %add = <256 x i32> %14, <256 x i32> %src2
|
||||
//
|
||||
combineBitcastStore(Bitcast, ST);
|
||||
// Delete user first.
|
||||
DeadInsts.push_back(ST);
|
||||
DeadInsts.push_back(Bitcast);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
C = visitLD() | visitST();
|
||||
for (auto *Inst : STSet)
|
||||
Inst->eraseFromParent();
|
||||
for (auto *Inst : LDSet)
|
||||
bool C = !DeadInsts.empty();
|
||||
|
||||
for (auto *Inst : DeadInsts)
|
||||
Inst->eraseFromParent();
|
||||
|
||||
return C;
|
||||
}
|
||||
} // anonymous namespace
|
||||
|
@ -637,7 +637,7 @@ def BNDR : RegisterClass<"X86", [v2i64], 128, (sequence "BND%u", 0, 3)>;
|
||||
|
||||
// Tiles
|
||||
let CopyCost = -1 in // Don't allow copying of tile registers
|
||||
def TILE : RegisterClass<"X86", [v256i32], 8192,
|
||||
def TILE : RegisterClass<"X86", [x86amx], 8192,
|
||||
(sequence "TMM%u", 0, 7)> {let Size = 8192;}
|
||||
def TILECFG : RegisterClass<"X86", [untyped], 512, (add TMMCFG)> {
|
||||
let CopyCost = -1; // Don't allow copying of tile config registers.
|
||||
|
@ -1115,6 +1115,10 @@ static bool combineStoreToValueType(InstCombinerImpl &IC, StoreInst &SI) {
|
||||
// Fold away bit casts of the stored value by storing the original type.
|
||||
if (auto *BC = dyn_cast<BitCastInst>(V)) {
|
||||
V = BC->getOperand(0);
|
||||
// Don't transform when the type is x86_amx, it make the pass that lower
|
||||
// x86_amx type happy.
|
||||
if (BC->getType()->isX86_AMXTy() || V->getType()->isX86_AMXTy())
|
||||
return false;
|
||||
if (!SI.isAtomic() || isSupportedAtomicType(V->getType())) {
|
||||
combineStoreToNewValue(IC, SI, V);
|
||||
return true;
|
||||
|
@ -71,20 +71,20 @@ define dso_local void @test_api(i16 signext %0, i16 signext %1) local_unnamed_ad
|
||||
; CHECK-NEXT: .cfi_def_cfa_offset 8
|
||||
; CHECK-NEXT: tilerelease
|
||||
; CHECK-NEXT: retq
|
||||
%3 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %0, i16 8, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 0), i64 32) #4
|
||||
%4 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 8, i16 %1, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 1024), i64 32) #4
|
||||
%3 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %0, i16 8, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 0), i64 32) #4
|
||||
%4 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 8, i16 %1, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 1024), i64 32) #4
|
||||
tail call void (...) @foo() #4
|
||||
%5 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %0, i16 %1, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 2048), i64 32) #4
|
||||
%6 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %0, i16 %1, i16 8, <256 x i32> %5, <256 x i32> %3, <256 x i32> %4) #4
|
||||
tail call void @llvm.x86.tilestored64.internal(i16 %0, i16 %1, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 2048), i64 32, <256 x i32> %6) #4
|
||||
%5 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %0, i16 %1, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 2048), i64 32) #4
|
||||
%6 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %0, i16 %1, i16 8, x86_amx %5, x86_amx %3, x86_amx %4) #4
|
||||
tail call void @llvm.x86.tilestored64.internal(i16 %0, i16 %1, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 2048), i64 32, x86_amx %6) #4
|
||||
ret void
|
||||
}
|
||||
|
||||
declare dso_local void @foo(...) local_unnamed_addr #3
|
||||
|
||||
declare <256 x i32> @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #4
|
||||
declare <256 x i32> @llvm.x86.tdpbssd.internal(i16, i16, i16, <256 x i32>, <256 x i32>, <256 x i32>) #4
|
||||
declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, <256 x i32>) #4
|
||||
declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #4
|
||||
declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) #4
|
||||
declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) #4
|
||||
|
||||
attributes #2 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="8192" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
|
||||
attributes #3 = { "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
|
||||
|
@ -47,31 +47,31 @@ define dso_local void @test_api(i32 %0, i16 signext %1, i16 signext %2) local_un
|
||||
br i1 %4, label %11, label %7
|
||||
|
||||
7: ; preds = %3
|
||||
%8 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
|
||||
%9 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
|
||||
%10 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
|
||||
%8 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %6, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
|
||||
%9 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
|
||||
%10 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
|
||||
br label %15
|
||||
|
||||
11: ; preds = %3
|
||||
%12 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
|
||||
%13 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
|
||||
%14 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
|
||||
%12 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %6, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
|
||||
%13 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
|
||||
%14 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
|
||||
br label %15
|
||||
|
||||
15: ; preds = %11, %7
|
||||
%16 = phi <256 x i32> [ %12, %11 ], [ %8, %7 ]
|
||||
%17 = phi <256 x i32> [ %13, %11 ], [ %9, %7 ]
|
||||
%18 = phi <256 x i32> [ %14, %11 ], [ %10, %7 ]
|
||||
%19 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %6, i16 %2, i16 %1, <256 x i32> %18, <256 x i32> %16, <256 x i32> %17) #3
|
||||
tail call void @llvm.x86.tilestored64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32, <256 x i32> %19) #3
|
||||
%16 = phi x86_amx [ %12, %11 ], [ %8, %7 ]
|
||||
%17 = phi x86_amx [ %13, %11 ], [ %9, %7 ]
|
||||
%18 = phi x86_amx [ %14, %11 ], [ %10, %7 ]
|
||||
%19 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %6, i16 %2, i16 %1, x86_amx %18, x86_amx %16, x86_amx %17) #3
|
||||
tail call void @llvm.x86.tilestored64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32, x86_amx %19) #3
|
||||
ret void
|
||||
}
|
||||
|
||||
declare <256 x i32> @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #3
|
||||
declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #3
|
||||
|
||||
declare <256 x i32> @llvm.x86.tdpbssd.internal(i16, i16, i16, <256 x i32>, <256 x i32>, <256 x i32>) #3
|
||||
declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) #3
|
||||
|
||||
declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, <256 x i32>) #3
|
||||
declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) #3
|
||||
|
||||
attributes #2 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="8192" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+avx,+avx2,+avx512f,+cx8,+f16c,+fma,+fxsr,+mmx,+popcnt,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3,+x87,+xsave" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
|
||||
attributes #3 = { nounwind }
|
||||
|
@ -37,23 +37,23 @@ define dso_local void @test_chain(i8* %A_mem, i8* %B_mem, i8* %C_mem) local_unna
|
||||
; CHECK-NEXT: vzeroupper
|
||||
; CHECK-NEXT: retq
|
||||
entry:
|
||||
%a1 = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %A_mem, i64 64)
|
||||
%a1 = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %A_mem, i64 64)
|
||||
%addr = getelementptr inbounds i8, i8* %A_mem, i64 1024
|
||||
%a2 = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %addr, i64 64)
|
||||
%c1 = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %C_mem, i64 64)
|
||||
%a2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %addr, i64 64)
|
||||
%c1 = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %C_mem, i64 64)
|
||||
%caddr = getelementptr inbounds i8, i8* %C_mem, i64 1024
|
||||
%c2 = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %caddr, i64 64)
|
||||
%c2 = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %caddr, i64 64)
|
||||
br label %dotpd
|
||||
|
||||
dotpd:
|
||||
%b = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %B_mem, i64 64)
|
||||
%dp1 = call <256 x i32> @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, <256 x i32> %c1, <256 x i32> %a1, <256 x i32> %b)
|
||||
call void @llvm.x86.tilestored64.internal(i16 16, i16 64, i8* nonnull %C_mem, i64 64, <256 x i32> %dp1)
|
||||
%dp2 = call <256 x i32> @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, <256 x i32> %c2, <256 x i32> %a2, <256 x i32> %b)
|
||||
call void @llvm.x86.tilestored64.internal(i16 16, i16 64, i8* nonnull %caddr, i64 64, <256 x i32> %dp2)
|
||||
%b = call x86_amx @llvm.x86.tileloadd64.internal(i16 16, i16 64, i8* nonnull %B_mem, i64 64)
|
||||
%dp1 = call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, x86_amx %c1, x86_amx %a1, x86_amx %b)
|
||||
call void @llvm.x86.tilestored64.internal(i16 16, i16 64, i8* nonnull %C_mem, i64 64, x86_amx %dp1)
|
||||
%dp2 = call x86_amx @llvm.x86.tdpbssd.internal(i16 16, i16 64, i16 64, x86_amx %c2, x86_amx %a2, x86_amx %b)
|
||||
call void @llvm.x86.tilestored64.internal(i16 16, i16 64, i8* nonnull %caddr, i64 64, x86_amx %dp2)
|
||||
ret void
|
||||
}
|
||||
|
||||
declare <256 x i32> @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
|
||||
declare <256 x i32> @llvm.x86.tdpbssd.internal(i16, i16, i16, <256 x i32>, <256 x i32>, <256 x i32>)
|
||||
declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, <256 x i32>)
|
||||
declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
|
||||
declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
|
||||
declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)
|
||||
|
@ -70,43 +70,43 @@ define dso_local void @test_api(i32 %0, i16 signext %1, i16 signext %2) local_un
|
||||
; CHECK-NEXT: tilerelease
|
||||
; CHECK-NEXT: vzeroupper
|
||||
; CHECK-NEXT: retq
|
||||
%4 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
|
||||
%5 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
|
||||
%6 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
|
||||
%7 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
|
||||
%8 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
|
||||
%9 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
|
||||
%10 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
|
||||
%4 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
|
||||
%5 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
|
||||
%6 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
|
||||
%7 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
|
||||
%8 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
|
||||
%9 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
|
||||
%10 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
|
||||
%11 = icmp eq i32 %0, 0
|
||||
br i1 %11, label %16, label %12
|
||||
|
||||
12: ; preds = %3
|
||||
%13 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
|
||||
%14 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
|
||||
%15 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
|
||||
%13 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %1, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
|
||||
%14 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
|
||||
%15 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
|
||||
br label %20
|
||||
|
||||
16: ; preds = %3
|
||||
%17 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
|
||||
%18 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
|
||||
%19 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
|
||||
%17 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %1, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
|
||||
%18 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
|
||||
%19 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
|
||||
br label %20
|
||||
|
||||
20: ; preds = %16, %12
|
||||
%21 = phi <256 x i32> [ %17, %16 ], [ %13, %12 ]
|
||||
%22 = phi <256 x i32> [ %18, %16 ], [ %14, %12 ]
|
||||
%23 = phi <256 x i32> [ %19, %16 ], [ %15, %12 ]
|
||||
%24 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %1, <256 x i32> %23, <256 x i32> %21, <256 x i32> %22) #3
|
||||
%25 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %2, <256 x i32> %6, <256 x i32> %24, <256 x i32> %5) #3
|
||||
%26 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %2, <256 x i32> %8, <256 x i32> %25, <256 x i32> %7) #3
|
||||
%27 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %2, i16 %2, i16 %2, <256 x i32> %10, <256 x i32> %26, <256 x i32> %9) #3
|
||||
tail call void @llvm.x86.tilestored64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32, <256 x i32> %27) #3
|
||||
%21 = phi x86_amx [ %17, %16 ], [ %13, %12 ]
|
||||
%22 = phi x86_amx [ %18, %16 ], [ %14, %12 ]
|
||||
%23 = phi x86_amx [ %19, %16 ], [ %15, %12 ]
|
||||
%24 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %1, x86_amx %23, x86_amx %21, x86_amx %22) #3
|
||||
%25 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %2, x86_amx %6, x86_amx %24, x86_amx %5) #3
|
||||
%26 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %2, x86_amx %8, x86_amx %25, x86_amx %7) #3
|
||||
%27 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %2, i16 %2, i16 %2, x86_amx %10, x86_amx %26, x86_amx %9) #3
|
||||
tail call void @llvm.x86.tilestored64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32, x86_amx %27) #3
|
||||
ret void
|
||||
}
|
||||
|
||||
declare <256 x i32> @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #3
|
||||
declare <256 x i32> @llvm.x86.tdpbssd.internal(i16, i16, i16, <256 x i32>, <256 x i32>, <256 x i32>) #3
|
||||
declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, <256 x i32>) #3
|
||||
declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #3
|
||||
declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) #3
|
||||
declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) #3
|
||||
|
||||
attributes #2 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="8192" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
|
||||
attributes #3 = { nounwind }
|
||||
|
@ -8,18 +8,104 @@ target triple = "x86_64-unknown-linux-gnu"
|
||||
@buf = dso_local global [1024 x i8] zeroinitializer, align 16
|
||||
@buf2 = dso_local global [1024 x i8] zeroinitializer, align 16
|
||||
|
||||
; test bitcast x86_amx to <256 x i32>
|
||||
define dso_local void @test_user_empty(i16 %m, i16 %n, i8 *%buf, i64 %s) #2 {
|
||||
; CHECK-LABEL: @test_user_empty(
|
||||
; CHECK-NEXT: entry:
|
||||
; CHECK-NEXT: [[T1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[N:%.*]], i8* [[BUF:%.*]], i64 [[S:%.*]]) [[ATTR3:#.*]]
|
||||
; CHECK-NEXT: ret void
|
||||
;
|
||||
entry:
|
||||
%t1 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %n, i8* %buf, i64 %s) #3
|
||||
%t2 = bitcast x86_amx %t1 to <256 x i32>
|
||||
ret void
|
||||
}
|
||||
|
||||
; test bitcast <256 x i32> to x86_amx
|
||||
define dso_local void @test_user_empty2(<256 x i32> %in) #2 {
|
||||
; CHECK-LABEL: @test_user_empty2(
|
||||
; CHECK-NEXT: entry:
|
||||
; CHECK-NEXT: ret void
|
||||
;
|
||||
entry:
|
||||
%t = bitcast <256 x i32> %in to x86_amx
|
||||
ret void
|
||||
}
|
||||
|
||||
define dso_local <256 x i32> @test_amx_load_bitcast(<256 x i32>* %in, i16 %m, i16 %n, i8 *%buf, i64 %s) #2 {
|
||||
; CHECK-LABEL: @test_amx_load_bitcast(
|
||||
; CHECK-NEXT: entry:
|
||||
; CHECK-NEXT: [[T1:%.*]] = load <256 x i32>, <256 x i32>* [[IN:%.*]], align 64
|
||||
; CHECK-NEXT: [[TMP0:%.*]] = bitcast <256 x i32>* [[IN]] to i8*
|
||||
; CHECK-NEXT: [[TMP1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[N:%.*]], i8* [[TMP0]], i64 64)
|
||||
; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[N]], i8* [[BUF:%.*]], i64 [[S:%.*]], x86_amx [[TMP1]]) [[ATTR3]]
|
||||
; CHECK-NEXT: ret <256 x i32> [[T1]]
|
||||
;
|
||||
entry:
|
||||
%t1 = load <256 x i32>, <256 x i32>* %in, align 64
|
||||
%t2 = bitcast <256 x i32> %t1 to x86_amx
|
||||
call void @llvm.x86.tilestored64.internal(i16 %m, i16 %n, i8* %buf, i64 %s, x86_amx %t2) #3
|
||||
ret <256 x i32> %t1
|
||||
}
|
||||
|
||||
define dso_local <256 x i32> @test_amx_bitcast_store(<256 x i32>* %out, i16 %m, i16 %n, i8 *%buf, i64 %s) #2 {
|
||||
; CHECK-LABEL: @test_amx_bitcast_store(
|
||||
; CHECK-NEXT: entry:
|
||||
; CHECK-NEXT: [[T1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[M:%.*]], i16 [[M]], i8* [[BUF:%.*]], i64 [[S:%.*]]) [[ATTR3]]
|
||||
; CHECK-NEXT: [[TMP0:%.*]] = bitcast <256 x i32>* [[OUT:%.*]] to i8*
|
||||
; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[M]], i16 [[M]], i8* [[TMP0]], i64 64, x86_amx [[T1]])
|
||||
; CHECK-NEXT: [[TMP1:%.*]] = load <256 x i32>, <256 x i32>* [[OUT]], align 1024
|
||||
; CHECK-NEXT: ret <256 x i32> [[TMP1]]
|
||||
;
|
||||
entry:
|
||||
%t1 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %m, i16 %m, i8* %buf, i64 %s) #3
|
||||
%t2 = bitcast x86_amx %t1 to <256 x i32>
|
||||
store <256 x i32> %t2, <256 x i32>* %out
|
||||
ret <256 x i32> %t2
|
||||
}
|
||||
|
||||
define dso_local void @test_src_add(<256 x i32> %x, <256 x i32> %y, i16 %r, i16 %c, i8* %buf, i64 %s) #2 {
|
||||
; CHECK-LABEL: @test_src_add(
|
||||
; CHECK-NEXT: entry:
|
||||
; CHECK-NEXT: [[TMP0:%.*]] = alloca <256 x i32>, align 64
|
||||
; CHECK-NEXT: [[ADD:%.*]] = add <256 x i32> [[Y:%.*]], [[X:%.*]]
|
||||
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <256 x i32>* [[TMP0]] to i8*
|
||||
; CHECK-NEXT: store <256 x i32> [[ADD]], <256 x i32>* [[TMP0]], align 1024
|
||||
; CHECK-NEXT: [[TMP2:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[R:%.*]], i16 [[C:%.*]], i8* [[TMP1]], i64 64)
|
||||
; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[R]], i16 [[C]], i8* [[BUF:%.*]], i64 [[S:%.*]], x86_amx [[TMP2]]) [[ATTR3]]
|
||||
; CHECK-NEXT: ret void
|
||||
;
|
||||
entry:
|
||||
%add = add <256 x i32> %y, %x
|
||||
%t = bitcast <256 x i32> %add to x86_amx
|
||||
call void @llvm.x86.tilestored64.internal(i16 %r, i16 %c, i8* %buf, i64 %s, x86_amx %t) #3
|
||||
ret void
|
||||
}
|
||||
|
||||
define dso_local void @test_src_add2(<256 x i32> %x, i16 %r, i16 %c, i8* %buf, i64 %s) #2 {
|
||||
; CHECK-LABEL: @test_src_add2(
|
||||
; CHECK-NEXT: entry:
|
||||
; CHECK-NEXT: [[TMP0:%.*]] = alloca <256 x i32>, align 64
|
||||
; CHECK-NEXT: [[T1:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[R:%.*]], i16 [[C:%.*]], i8* [[BUF:%.*]], i64 [[S:%.*]]) [[ATTR3]]
|
||||
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <256 x i32>* [[TMP0]] to i8*
|
||||
; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[R]], i16 [[C]], i8* [[TMP1]], i64 64, x86_amx [[T1]])
|
||||
; CHECK-NEXT: [[TMP2:%.*]] = load <256 x i32>, <256 x i32>* [[TMP0]], align 1024
|
||||
; CHECK-NEXT: [[ADD:%.*]] = add <256 x i32> [[TMP2]], [[X:%.*]]
|
||||
; CHECK-NEXT: ret void
|
||||
;
|
||||
entry:
|
||||
%t1 = call x86_amx @llvm.x86.tileloadd64.internal(i16 %r, i16 %c, i8* %buf, i64 %s) #3
|
||||
%t2 = bitcast x86_amx %t1 to <256 x i32>
|
||||
%add = add <256 x i32> %t2, %x
|
||||
ret void
|
||||
}
|
||||
|
||||
define dso_local void @test_load(i8* %in, i8* %out) local_unnamed_addr #2 {
|
||||
; CHECK-LABEL: @test_load(
|
||||
; CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[IN:%.*]] to <256 x i32>*
|
||||
; CHECK-NEXT: [[TMP2:%.*]] = bitcast i8* [[OUT:%.*]] to <256 x i32>*
|
||||
; CHECK-NEXT: [[TMP3:%.*]] = bitcast <256 x i32>* [[TMP1]] to <128 x i32>*
|
||||
; CHECK-NEXT: [[TMP4:%.*]] = load <128 x i32>, <128 x i32>* [[TMP3]], align 64
|
||||
; CHECK-NEXT: [[TMP5:%.*]] = getelementptr <128 x i32>, <128 x i32>* [[TMP3]], i32 1
|
||||
; CHECK-NEXT: [[TMP6:%.*]] = load <128 x i32>, <128 x i32>* [[TMP5]], align 64
|
||||
; CHECK-NEXT: [[TMP7:%.*]] = bitcast <256 x i32>* [[TMP2]] to <128 x i32>*
|
||||
; CHECK-NEXT: store <128 x i32> [[TMP4]], <128 x i32>* [[TMP7]], align 64
|
||||
; CHECK-NEXT: [[TMP8:%.*]] = getelementptr <128 x i32>, <128 x i32>* [[TMP7]], i32 1
|
||||
; CHECK-NEXT: store <128 x i32> [[TMP6]], <128 x i32>* [[TMP8]], align 64
|
||||
; CHECK-NEXT: [[TMP3:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1]], align 64, [[TBAA2:!tbaa !.*]]
|
||||
; CHECK-NEXT: store <256 x i32> [[TMP3]], <256 x i32>* [[TMP2]], align 64, [[TBAA2]]
|
||||
; CHECK-NEXT: ret void
|
||||
;
|
||||
%1 = bitcast i8* %in to <256 x i32>*
|
||||
@ -29,18 +115,33 @@ define dso_local void @test_load(i8* %in, i8* %out) local_unnamed_addr #2 {
|
||||
ret void
|
||||
}
|
||||
|
||||
define dso_local <256 x i32> @foo(<256 x i32>* nocapture readonly byval(<256 x i32>) align 1024 %0, <256 x i32>* nocapture readonly byval(<256 x i32>) align 1024 %1) local_unnamed_addr #0 {
|
||||
; CHECK-LABEL: @foo(
|
||||
; CHECK-NEXT: entry:
|
||||
; CHECK-NEXT: [[X:%.*]] = load <256 x i32>, <256 x i32>* [[TMP0:%.*]], align 1024, [[TBAA5:!tbaa !.*]]
|
||||
; CHECK-NEXT: [[Y:%.*]] = load <256 x i32>, <256 x i32>* [[TMP1:%.*]], align 1024, [[TBAA5]]
|
||||
; CHECK-NEXT: [[ADD:%.*]] = add <256 x i32> [[Y]], [[X]]
|
||||
; CHECK-NEXT: ret <256 x i32> [[ADD]]
|
||||
;
|
||||
entry:
|
||||
%x = load <256 x i32>, <256 x i32>* %0, align 1024, !tbaa !2
|
||||
%y = load <256 x i32>, <256 x i32>* %1, align 1024, !tbaa !2
|
||||
%add = add <256 x i32> %y, %x
|
||||
ret <256 x i32> %add
|
||||
}
|
||||
|
||||
define dso_local void @__tile_loadd(%struct.__tile_str* nocapture %0, i8* %1, i64 %2) local_unnamed_addr #0 {
|
||||
; CHECK-LABEL: @__tile_loadd(
|
||||
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP0:%.*]], i64 0, i32 0
|
||||
; CHECK-NEXT: [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64, [[TBAA2:!tbaa !.*]]
|
||||
; CHECK-NEXT: [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64, [[TBAA5]]
|
||||
; CHECK-NEXT: [[TMP6:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0]], i64 0, i32 1
|
||||
; CHECK-NEXT: [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2, [[TBAA7:!tbaa !.*]]
|
||||
; CHECK-NEXT: [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2, [[TBAA8:!tbaa !.*]]
|
||||
; CHECK-NEXT: [[TMP8:%.*]] = shl i64 [[TMP2:%.*]], 32
|
||||
; CHECK-NEXT: [[TMP9:%.*]] = ashr exact i64 [[TMP8]], 32
|
||||
; CHECK-NEXT: [[TMP10:%.*]] = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP1:%.*]], i64 [[TMP9]]) [[ATTR3:#.*]]
|
||||
; CHECK-NEXT: [[TMP10:%.*]] = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP1:%.*]], i64 [[TMP9]]) [[ATTR3]]
|
||||
; CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0]], i64 0, i32 2
|
||||
; CHECK-NEXT: [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP11]] to i8*
|
||||
; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP12]], i64 64, <256 x i32> [[TMP10]])
|
||||
; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP12]], i64 64, x86_amx [[TMP10]])
|
||||
; CHECK-NEXT: ret void
|
||||
;
|
||||
%4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 0
|
||||
@ -49,32 +150,33 @@ define dso_local void @__tile_loadd(%struct.__tile_str* nocapture %0, i8* %1, i6
|
||||
%7 = load i16, i16* %6, align 2, !tbaa !7
|
||||
%8 = shl i64 %2, 32
|
||||
%9 = ashr exact i64 %8, 32
|
||||
%10 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %5, i16 %7, i8* %1, i64 %9) #3
|
||||
%11 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 2
|
||||
store <256 x i32> %10, <256 x i32>* %11, align 64, !tbaa !8
|
||||
%10 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 %5, i16 %7, i8* %1, i64 %9) #3
|
||||
%11 = bitcast x86_amx %10 to <256 x i32>
|
||||
%12 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 2
|
||||
store <256 x i32> %11, <256 x i32>* %12, align 64, !tbaa !8
|
||||
ret void
|
||||
}
|
||||
|
||||
define dso_local void @__tile_dpbsud(%struct.__tile_str* nocapture %0, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %1, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %2) local_unnamed_addr #0 {
|
||||
; CHECK-LABEL: @__tile_dpbsud(
|
||||
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP1:%.*]], i64 0, i32 0
|
||||
; CHECK-NEXT: [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64, [[TBAA2]]
|
||||
; CHECK-NEXT: [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64, [[TBAA5]]
|
||||
; CHECK-NEXT: [[TMP6:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2:%.*]], i64 0, i32 1
|
||||
; CHECK-NEXT: [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2, [[TBAA7]]
|
||||
; CHECK-NEXT: [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2, [[TBAA8]]
|
||||
; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 1
|
||||
; CHECK-NEXT: [[TMP9:%.*]] = load i16, i16* [[TMP8]], align 2, [[TBAA7]]
|
||||
; CHECK-NEXT: [[TMP9:%.*]] = load i16, i16* [[TMP8]], align 2, [[TBAA8]]
|
||||
; CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0:%.*]], i64 0, i32 2
|
||||
; CHECK-NEXT: [[TMP11:%.*]] = bitcast <256 x i32>* [[TMP10]] to i8*
|
||||
; CHECK-NEXT: [[TMP12:%.*]] = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP11]], i64 64)
|
||||
; CHECK-NEXT: [[TMP12:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP11]], i64 64)
|
||||
; CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 2
|
||||
; CHECK-NEXT: [[TMP14:%.*]] = bitcast <256 x i32>* [[TMP13]] to i8*
|
||||
; CHECK-NEXT: [[TMP15:%.*]] = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP9]], i8* [[TMP14]], i64 64)
|
||||
; CHECK-NEXT: [[TMP15:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP9]], i8* [[TMP14]], i64 64)
|
||||
; CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2
|
||||
; CHECK-NEXT: [[TMP17:%.*]] = bitcast <256 x i32>* [[TMP16]] to i8*
|
||||
; CHECK-NEXT: [[TMP18:%.*]] = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP9]], i16 [[TMP7]], i8* [[TMP17]], i64 64)
|
||||
; CHECK-NEXT: [[TMP19:%.*]] = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 [[TMP5]], i16 [[TMP7]], i16 [[TMP9]], <256 x i32> [[TMP12]], <256 x i32> [[TMP15]], <256 x i32> [[TMP18]]) [[ATTR3]]
|
||||
; CHECK-NEXT: [[TMP18:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP9]], i16 [[TMP7]], i8* [[TMP17]], i64 64)
|
||||
; CHECK-NEXT: [[TMP19:%.*]] = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 [[TMP5]], i16 [[TMP7]], i16 [[TMP9]], x86_amx [[TMP12]], x86_amx [[TMP15]], x86_amx [[TMP18]]) [[ATTR3]]
|
||||
; CHECK-NEXT: [[TMP20:%.*]] = bitcast <256 x i32>* [[TMP10]] to i8*
|
||||
; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP20]], i64 64, <256 x i32> [[TMP19]])
|
||||
; CHECK-NEXT: call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP20]], i64 64, x86_amx [[TMP19]])
|
||||
; CHECK-NEXT: ret void
|
||||
;
|
||||
%4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 0
|
||||
@ -85,27 +187,31 @@ define dso_local void @__tile_dpbsud(%struct.__tile_str* nocapture %0, %struct._
|
||||
%9 = load i16, i16* %8, align 2, !tbaa !7
|
||||
%10 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 2
|
||||
%11 = load <256 x i32>, <256 x i32>* %10, align 64, !tbaa !8
|
||||
%12 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 2
|
||||
%13 = load <256 x i32>, <256 x i32>* %12, align 64, !tbaa !8
|
||||
%14 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 2
|
||||
%15 = load <256 x i32>, <256 x i32>* %14, align 64, !tbaa !8
|
||||
%16 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %5, i16 %7, i16 %9, <256 x i32> %11, <256 x i32> %13, <256 x i32> %15) #3
|
||||
store <256 x i32> %16, <256 x i32>* %10, align 64, !tbaa !8
|
||||
%12 = bitcast <256 x i32> %11 to x86_amx
|
||||
%13 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 2
|
||||
%14 = load <256 x i32>, <256 x i32>* %13, align 64, !tbaa !8
|
||||
%15 = bitcast <256 x i32> %14 to x86_amx
|
||||
%16 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 2
|
||||
%17 = load <256 x i32>, <256 x i32>* %16, align 64, !tbaa !8
|
||||
%18 = bitcast <256 x i32> %17 to x86_amx
|
||||
%19 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 %5, i16 %7, i16 %9, x86_amx %12, x86_amx %15, x86_amx %18) #3
|
||||
%20 = bitcast x86_amx %19 to <256 x i32>
|
||||
store <256 x i32> %20, <256 x i32>* %10, align 64, !tbaa !8
|
||||
ret void
|
||||
}
|
||||
|
||||
define dso_local void @__tile_stored(i8* %0, i64 %1, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %2) local_unnamed_addr #1 {
|
||||
; CHECK-LABEL: @__tile_stored(
|
||||
; CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP2:%.*]], i64 0, i32 0
|
||||
; CHECK-NEXT: [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64, [[TBAA2]]
|
||||
; CHECK-NEXT: [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64, [[TBAA5]]
|
||||
; CHECK-NEXT: [[TMP6:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 1
|
||||
; CHECK-NEXT: [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2, [[TBAA7]]
|
||||
; CHECK-NEXT: [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2, [[TBAA8]]
|
||||
; CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2
|
||||
; CHECK-NEXT: [[TMP9:%.*]] = bitcast <256 x i32>* [[TMP8]] to i8*
|
||||
; CHECK-NEXT: [[TMP10:%.*]] = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP9]], i64 64)
|
||||
; CHECK-NEXT: [[TMP10:%.*]] = call x86_amx @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP9]], i64 64)
|
||||
; CHECK-NEXT: [[TMP11:%.*]] = shl i64 [[TMP1:%.*]], 32
|
||||
; CHECK-NEXT: [[TMP12:%.*]] = ashr exact i64 [[TMP11]], 32
|
||||
; CHECK-NEXT: tail call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP0:%.*]], i64 [[TMP12]], <256 x i32> [[TMP10]]) [[ATTR3]]
|
||||
; CHECK-NEXT: tail call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP0:%.*]], i64 [[TMP12]], x86_amx [[TMP10]]) [[ATTR3]]
|
||||
; CHECK-NEXT: ret void
|
||||
;
|
||||
%4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 0
|
||||
@ -114,15 +220,16 @@ define dso_local void @__tile_stored(i8* %0, i64 %1, %struct.__tile_str* nocaptu
|
||||
%7 = load i16, i16* %6, align 2, !tbaa !7
|
||||
%8 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 2
|
||||
%9 = load <256 x i32>, <256 x i32>* %8, align 64, !tbaa !8
|
||||
%10 = shl i64 %1, 32
|
||||
%11 = ashr exact i64 %10, 32
|
||||
tail call void @llvm.x86.tilestored64.internal(i16 %5, i16 %7, i8* %0, i64 %11, <256 x i32> %9) #3
|
||||
%10 = bitcast <256 x i32> %9 to x86_amx
|
||||
%11 = shl i64 %1, 32
|
||||
%12 = ashr exact i64 %11, 32
|
||||
tail call void @llvm.x86.tilestored64.internal(i16 %5, i16 %7, i8* %0, i64 %12, x86_amx %10) #3
|
||||
ret void
|
||||
}
|
||||
|
||||
declare <256 x i32> @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #3
|
||||
declare <256 x i32> @llvm.x86.tdpbssd.internal(i16, i16, i16, <256 x i32>, <256 x i32>, <256 x i32>) #3
|
||||
declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, <256 x i32>) #3
|
||||
declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #3
|
||||
declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx) #3
|
||||
declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx) #3
|
||||
|
||||
attributes #0 = { alwaysinline nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="8192" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
|
||||
attributes #1 = { alwaysinline nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
|
||||
|
@ -76,6 +76,7 @@ StringRef llvm::getEnumName(MVT::SimpleValueType T) {
|
||||
case MVT::f128: return "MVT::f128";
|
||||
case MVT::ppcf128: return "MVT::ppcf128";
|
||||
case MVT::x86mmx: return "MVT::x86mmx";
|
||||
case MVT::x86amx: return "MVT::x86amx";
|
||||
case MVT::Glue: return "MVT::Glue";
|
||||
case MVT::isVoid: return "MVT::isVoid";
|
||||
case MVT::v1i1: return "MVT::v1i1";
|
||||
|
@ -248,7 +248,8 @@ enum IIT_Info {
|
||||
IIT_V128 = 47,
|
||||
IIT_BF16 = 48,
|
||||
IIT_STRUCT9 = 49,
|
||||
IIT_V256 = 50
|
||||
IIT_V256 = 50,
|
||||
IIT_AMX = 51
|
||||
};
|
||||
|
||||
static void EncodeFixedValueType(MVT::SimpleValueType VT,
|
||||
@ -276,6 +277,7 @@ static void EncodeFixedValueType(MVT::SimpleValueType VT,
|
||||
case MVT::token: return Sig.push_back(IIT_TOKEN);
|
||||
case MVT::Metadata: return Sig.push_back(IIT_METADATA);
|
||||
case MVT::x86mmx: return Sig.push_back(IIT_MMX);
|
||||
case MVT::x86amx: return Sig.push_back(IIT_AMX);
|
||||
// MVT::OtherVT is used to mean the empty struct type here.
|
||||
case MVT::Other: return Sig.push_back(IIT_EMPTYSTRUCT);
|
||||
// MVT::isVoid is used to represent varargs here.
|
||||
|
Loading…
Reference in New Issue
Block a user