//===- Target/X86/X86PreAMXConfig.cpp - ------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// Insert a tile config (tilecfg) for each key AMX intrinsic area.
/// All tile operands of the key AMX intrinsic must come from tileloads, and
/// the tile defined by the key AMX intrinsic must be stored by a tilestore.
/// Take tdpbssd as an example:
/// --------------------------------------------------------------------------
/// %t1 = call x86_amx @llvm.x86.tileloadd64.internal(...)                key
/// %t2 = call x86_amx @llvm.x86.tileloadd64.internal(...)                 |
/// %t3 = call x86_amx @llvm.x86.tileloadd64.internal(...)                amx
/// %td = tail call x86_amx @llvm.x86.tdpbssd.internal(t1, t2, t3)         |
/// call void @llvm.x86.tilestored64.internal(... td)                     area
/// --------------------------------------------------------------------------
/// This pass inserts a tile config before every key AMX area, roughly like:
/// --------------------------------------------------------------------------
/// %cfgmem = alloca <16 x i32>, align 4                        * allocate mem
/// store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem       * zero init
/// ...
/// ... pre-config shape of %t1                                 *
/// store i8 %m, i8* %amx.tmm.0.shape.row, align 1              *
/// store i16 %k, i16* %amx.tmm.0.shape.col, align 2            * pre-config
/// ...                                                         *
/// ... pre-config shape of %t2                                 * shapes
/// store i8 %k, i8* %amx.tmm.1.shape.row, align 1              *
/// store i16 %n, i16* %amx.tmm.1.shape.col, align 2            *
/// ...
/// call void @llvm.x86.ldtilecfg(i8* %cfgmem) * tile config // //===----------------------------------------------------------------------===// // #include "X86.h" #include "llvm/ADT/SmallSet.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/CodeGen/Passes.h" #include "llvm/CodeGen/TargetPassConfig.h" #include "llvm/CodeGen/ValueTypes.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/IntrinsicsX86.h" #include "llvm/IR/PatternMatch.h" #include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Target/TargetMachine.h" using namespace llvm; using namespace PatternMatch; #define DEBUG_TYPE "pre-amx-config" static bool isAMXIntrinsic(IntrinsicInst *II) { for (Value *Operand : II->operands()) if (Operand->getType()->isX86_AMXTy()) return true; return II->getType()->isX86_AMXTy(); } static bool isTileLoad(IntrinsicInst *II) { return II->getIntrinsicID() == Intrinsic::x86_tileloadd64_internal || II->getIntrinsicID() == Intrinsic::x86_tileloaddt164_internal; } static bool isTileStore(IntrinsicInst *II) { return II->getIntrinsicID() == Intrinsic::x86_tilestored64_internal; } #ifndef NDEBUG static bool onlyTileDef(IntrinsicInst *II) { for (Value *Operand : II->operands()) if (Operand->getType()->isX86_AMXTy()) return false; return II->getType()->isX86_AMXTy(); } static bool brokenVolatile(Instruction *I) { // Todo: it is weak to identify a normal call here. 
if ((isa(I) && !isa(I)) || I->isTerminator()) return true; return false; } #endif namespace { class X86PreAMXConfig { Function &F; public: X86PreAMXConfig(Function &Func) : F(Func) {} bool preTileConfig(); bool addTileConfig(Instruction *ModelStart, SmallVector &Shapes); bool findConfigShapes( DenseMap> &PosAndShapes); bool getKeyAMXShapes(IntrinsicInst *KeyAMX, SmallVector &Shapes); bool preWriteTileCfg(Value *I8Ptr, Instruction *Pos, SmallVector &Shapes); BasicBlock::iterator getShapesAndConfigPosEnd(BasicBlock::iterator Iter, SmallVector &Shapes); bool checkVolatileModel(SmallSet &Loads, IntrinsicInst *Store, IntrinsicInst *KeyAMX); }; // Orderly write the shapes in tilecfg's mem. This maybe not right. // Because the first shape may not corresponding to the first tmm register, // so we need to handle at at X86FastTileConfig::materializeTileCfg() // after register allocation. // For example: // -------------------------------------------------------------------------- // zeroinitialize tilecfg's mem (of ldtilecfg) // -------------------------------------------------------------------------- // ... pre-config shape of %t1 * // %amx.tmm.0.shape.row = getelementptr i8, i8* %mem, i64 48 * // %amx.tmm.0.shape.col = getelementptr i16, i16* %mem, i64 16 * // store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1 * // store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2 * pre-config // ... * // ... pre-config shape of %t2 * // %amx.tmm.1.shape.row = getelementptr i8, i8* %mem, i64 49 * // %amx.tmm.1.shape.col = getelementptr i16, i16* %mem, i64 18 * // store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1 * shapes // store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2 * // ... * // ... 
pre-config shape of %t3 * of // %amx.tmm.2.shape.row = getelementptr i8, i8* %mem, i64 50 * // %amx.tmm.2.shape.col = getelementptr i16, i16* %mem, i64 20 * // store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1 * // store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2 * // ... * tiles // ... pre-config shape of %td * // %amx.tmm.3.shape.row = getelementptr i8, i8* %mem, i64 51 * // %amx.tmm.3.shape.col = getelementptr i16, i16* %mem, i64 22 * // store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1 * // store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2 * // -------------------------------------------------------------------------- // call void @llvm.x86.ldtilecfg(i8* %mem) * tile config // -------------------------------------------------------------------------- // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...) // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx // %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3) // call void @llvm.x86.tilestored64.internal(... td) area // -------------------------------------------------------------------------- bool X86PreAMXConfig::preWriteTileCfg(Value *I8Ptr, Instruction *Pos, SmallVector &Shapes) { bool Write = false; LLVMContext &Ctx = Pos->getParent()->getContext(); Type *I8Ty = Type::getInt8Ty(Ctx); Type *I16Ty = Type::getInt16Ty(Ctx); // TODO: Currently we defaultly set Palette = 1, it may be assigned to // other value in the future. 
Value *PaletteOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 0); Value *PaletteValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1); Value *PalettePos = GetElementPtrInst::Create(I8Ty, I8Ptr, PaletteOffset, "", Pos); new StoreInst(PaletteValue, PalettePos, Pos); for (int I = 0, E = Shapes.size() / 2; I < E; I++) { Value *RowOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 48 + I); Value *ColOffset = ConstantInt::get(Type::getInt64Ty(Ctx), 16 + I * 2); const std::string ShapeName = "amx.tmm." + itostr(I); Value *RowPos = GetElementPtrInst::Create(I8Ty, I8Ptr, RowOffset, ShapeName + ".shape.row", Pos); Value *ColPos = GetElementPtrInst::Create(I8Ty, I8Ptr, ColOffset, "", Pos); ColPos = new BitCastInst(ColPos, PointerType::get(I16Ty, 0), ShapeName + ".shape.col", Pos); Value *Row = Shapes[I * 2]; Value *Col = Shapes[I * 2 + 1]; Row = new TruncInst(Row, I8Ty, "", Pos); new StoreInst(Row, RowPos, Pos); new StoreInst(Col, ColPos, Pos); Write = true; } return Write; } bool X86PreAMXConfig::addTileConfig(Instruction *ModelStart, SmallVector &Shapes) { Module *M = F.getParent(); IRBuilder<> Builder(ModelStart); const DataLayout &DL = M->getDataLayout(); unsigned AddrSpace = DL.getAllocaAddrSpace(); LLVMContext &Ctx = Builder.getContext(); Type *V512Ty = VectorType::get(Builder.getInt32Ty(), 16, false); Align Alignment = DL.getPrefTypeAlign(Type::getInt32Ty(Ctx)); AllocaInst *Addr = new AllocaInst(V512Ty, AddrSpace, "", &F.getEntryBlock().front()); Addr->setAlignment(Alignment); Value *I8Ptr = Builder.CreateBitCast(Addr, Builder.getInt8PtrTy()); std::array Args = {I8Ptr}; Instruction *Cfg = Builder.CreateIntrinsic(Intrinsic::x86_ldtilecfg_internal, None, Args); Value *Val0 = Constant::getNullValue(V512Ty); Instruction *Init0 = new StoreInst(Val0, Addr, false, Alignment, Cfg); assert(Init0 && "Not Zero initilizate the cfg mem!"); preWriteTileCfg(I8Ptr, Cfg, Shapes); return Init0; } // Todo: We may need to handle "more than one store" case in the future. 
bool X86PreAMXConfig::checkVolatileModel(SmallSet &Loads, IntrinsicInst *Store, IntrinsicInst *KeyAMX) { Value *ST = Store->getOperand(4); // Only has tileload and tilestore. if (!KeyAMX) return (Loads.size() == 1) && Loads.contains(ST); // All Loads should be operands of KeyAMX. // All tile operands of KeyAMX should come from Loads. for (Value *Op : KeyAMX->operands()) { if (Op->getType()->isX86_AMXTy()) if (!Loads.erase(Op)) return false; } // The def of KeyAMX should be stored into mem. // Todo: is it key amx can be no def? return Loads.empty() && (ST == cast(KeyAMX)); } bool X86PreAMXConfig::getKeyAMXShapes(IntrinsicInst *KeyAMX, SmallVector &Shapes) { for (unsigned I = 0; I < KeyAMX->getNumOperands(); I++) { Value *Op = KeyAMX->getOperand(I); if (!Op->getType()->isX86_AMXTy()) continue; IntrinsicInst *TileDef = dyn_cast(Op); assert((TileDef && isTileLoad(TileDef)) && "All KeyAMX's tile definiation should comes from TileLoad!"); Shapes.push_back(TileDef->getOperand(0)); Shapes.push_back(TileDef->getOperand(1)); } if (!isTileStore(KeyAMX)) { Shapes.push_back(KeyAMX->getOperand(0)); Shapes.push_back(KeyAMX->getOperand(1)); } return Shapes.size() != 0; } // Collect the shapes and skip the area of current key amx intrinsic. // // For example: // ... // -------------------------------------------------------------------------- // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) record (m,k) // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...) record (m,k) // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) record (m,k) // %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3) // call void @llvm.x86.tilestored64.internal(m, n,... 
td) <--PosEnd record (m,k) // -------------------------------------------------------------------------- BasicBlock::iterator X86PreAMXConfig::getShapesAndConfigPosEnd(BasicBlock::iterator Iter, SmallVector &Shapes) { IntrinsicInst *KeyAMX = nullptr; BasicBlock *BB = Iter->getParent(); BasicBlock::iterator PosEnd = BB->end(); SmallSet Loads; // See TileStore as "Config Position End" and check volatile model. for (auto I = Iter, E = BB->end(); I != E; ++I) { assert(!brokenVolatile(&*I) && "Not reach tile store!"); IntrinsicInst *II = dyn_cast(&*I); if (!II || !isAMXIntrinsic(II)) continue; if (isTileLoad(II)) { Loads.insert(II); } else if (isTileStore(II)) { if (!checkVolatileModel(Loads, II, KeyAMX)) report_fatal_error("Not Volatile AMX Model!"); PosEnd = I; break; } else { assert(!KeyAMX && "Too many key amx intrinsic!"); KeyAMX = II; } } assert(PosEnd != BB->end() && "Not find TileStore!"); // See KeyAMX as TileStore if only TileLoad and TileStore. if (!KeyAMX) KeyAMX = dyn_cast(&*PosEnd); // Get Shapes in order. assert(Shapes.empty() && "Shapes should be clean."); getKeyAMXShapes(KeyAMX, Shapes); return PosEnd; } // Record a key amx area's shapes with its position. // Use the first tileload as its position. // For example: // ... // -------------------------------------------------------------------------- // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) <-- pos // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...) / // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) shapes: // %td = call x86_amx @llvm.x86.tdpbssd.internal(...t1, t2, t3) (m,k)(k,n) // call void @llvm.x86.tilestored64.internal(m, n,... 
td) (m,n)(m,n) // -------------------------------------------------------------------------- bool X86PreAMXConfig::findConfigShapes( DenseMap> &PosAndShapes) { bool Find = false; for (BasicBlock &BB : F) { for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ++I) { IntrinsicInst *II = dyn_cast(&*I); if (!II) continue; if (!isAMXIntrinsic(II)) continue; assert(onlyTileDef(II) && "Not volatile model for AMX at O0!"); I = getShapesAndConfigPosEnd(I, PosAndShapes[&*I]); Find = true; } } return Find; } // Insert ldtilecfg and preconfig the shapes for each area of key AMX intrinsic. // e.g. (key amx = tdpbssd) // -------------------------------------------------------------------------- // %cfgmem = alloca <16 x i32>, align 4 * allocate mem // store <16 x i32> zeroinitializer, <16 x i32>* %cfgmem * zero init // ... // ... pre-config shape of %t1 * // store volatile i8 %m, i8* %amx.tmm.0.shape.row, align 1 * // store volatile i16 %k, i16* %amx.tmm.0.shape.col, align 2 * pre-config // ... * // ... pre-config shape of %t2 * // store volatile i8 %k, i8* %amx.tmm.1.shape.row, align 1 * shapes // store volatile i16 %n, i16* %amx.tmm.1.shape.col, align 2 * // ... * // ... pre-config shape of %t3 * of // store volatile i8 %m, i8* %amx.tmm.2.shape.row, align 1 * // store volatile i16 %n, i16* %amx.tmm.2.shape.col, align 2 * // ... * tiles // ... pre-config shape of %td * // store volatile i8 %m, i8* %amx.tmm.3.shape.row, align 1 * // store volatile i16 %n, i16* %amx.tmm.3.shape.col, align 2 * // // call void @llvm.x86.ldtilecfg(i8* %cfgmem) * pre-config // -------------------------------------------------------------------------- // %t1 = call x86_amx @llvm.x86.tileloadd64.internal(m, k, ...) key // %t2 = call x86_amx @llvm.x86.tileloadd64.internal(k, n, ...) // %t3 = call x86_amx @llvm.x86.tileloadd64.internal(m, n, ...) amx // %td = tail call x86_amx @llvm.x86.tdpbssd.internal(m, n, k, t1, t2, t3) // call void @llvm.x86.tilestored64.internal(... 
td) area // -------------------------------------------------------------------------- bool X86PreAMXConfig::preTileConfig() { DenseMap> PosAndShapes; bool NeedCfg = findConfigShapes(PosAndShapes); if (!NeedCfg) return false; for (auto &IPAndShapes : PosAndShapes) addTileConfig(IPAndShapes.first, IPAndShapes.second); return true; } } // anonymous namespace namespace { class X86PreAMXConfigPass : public FunctionPass { public: static char ID; X86PreAMXConfigPass() : FunctionPass(ID) { initializeX86PreAMXConfigPassPass(*PassRegistry::getPassRegistry()); } bool runOnFunction(Function &F) override { TargetMachine *TM = &getAnalysis().getTM(); bool C = false; // Prepare for fast register allocation at O0. if (TM->getOptLevel() == CodeGenOpt::None) { // We pre-config each key AMX intrinsic at O0. // In theory, one tile config can cover several AMX intrinsics, but // it is very diffcult to classify the tile shapes at O0. So here we // let thing be easy, pre-config every key AMX intrinsic. X86PreAMXConfig PCFG(F); C = PCFG.preTileConfig(); } return C; } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); AU.addRequired(); } }; } // anonymous namespace static const char PassName[] = "Pre AMX Tile Config"; char X86PreAMXConfigPass::ID = 0; INITIALIZE_PASS_BEGIN(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false) INITIALIZE_PASS_DEPENDENCY(TargetPassConfig) INITIALIZE_PASS_END(X86PreAMXConfigPass, DEBUG_TYPE, PassName, false, false) FunctionPass *llvm::createX86PreAMXConfigPass() { return new X86PreAMXConfigPass(); }