|  | //===-- X86TileConfig.cpp - Tile Register Configure----------------------===// | 
|  | // | 
|  | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | 
|  | // See https://llvm.org/LICENSE.txt for license information. | 
|  | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | 
|  | // | 
|  | //===----------------------------------------------------------------------===// | 
|  | // | 
|  | /// \file Pass to configure the shape of AMX physical registers. | 
|  | /// AMX registers need to be configured before use. In the X86PreTileConfig | 
|  | /// pass the pldtilecfg instruction is inserted; however, at that time we | 
|  | /// don't know the shape of each physical tile register, because register | 
|  | /// allocation is not done yet. This pass runs after the register allocation | 
|  | /// pass. It collects the shape information of each physical tile register | 
|  | /// and stores the shape in the stack slot that is allocated for loading the | 
|  | /// config into the tile config register. | 
|  | // | 
|  | //===----------------------------------------------------------------------===// | 
|  |  | 
|  | #include "X86.h" | 
|  | #include "X86InstrBuilder.h" | 
|  | #include "X86MachineFunctionInfo.h" | 
|  | #include "X86Subtarget.h" | 
|  | #include "llvm/CodeGen/LiveIntervals.h" | 
|  | #include "llvm/CodeGen/MachineFrameInfo.h" | 
|  | #include "llvm/CodeGen/MachineFunctionPass.h" | 
|  | #include "llvm/CodeGen/MachineInstr.h" | 
|  | #include "llvm/CodeGen/MachineRegisterInfo.h" | 
|  | #include "llvm/CodeGen/Passes.h" | 
|  | #include "llvm/CodeGen/TargetInstrInfo.h" | 
|  | #include "llvm/CodeGen/TargetRegisterInfo.h" | 
|  | #include "llvm/CodeGen/TileShapeInfo.h" | 
|  | #include "llvm/CodeGen/VirtRegMap.h" | 
|  | #include "llvm/InitializePasses.h" | 
|  |  | 
|  | using namespace llvm; | 
|  |  | 
|  | #define DEBUG_TYPE "tileconfig" | 
|  |  | 
|  | namespace { | 
|  |  | 
|  | struct X86TileConfig : public MachineFunctionPass { | 
|  |  | 
|  | X86TileConfig() : MachineFunctionPass(ID) {} | 
|  |  | 
|  | /// Return the pass name. | 
|  | StringRef getPassName() const override { return "Tile Register Configure"; } | 
|  |  | 
|  | /// X86TileConfig analysis usage. | 
|  | void getAnalysisUsage(AnalysisUsage &AU) const override { | 
|  | AU.setPreservesAll(); | 
|  | AU.addRequired<VirtRegMapWrapperLegacy>(); | 
|  | AU.addRequired<LiveIntervalsWrapperPass>(); | 
|  | MachineFunctionPass::getAnalysisUsage(AU); | 
|  | } | 
|  |  | 
|  | /// Perform register allocation. | 
|  | bool runOnMachineFunction(MachineFunction &mf) override; | 
|  |  | 
|  | MachineFunctionProperties getRequiredProperties() const override { | 
|  | return MachineFunctionProperties().setNoPHIs(); | 
|  | } | 
|  |  | 
|  | static char ID; | 
|  | }; | 
|  |  | 
|  | } // end anonymous namespace | 
|  |  | 
|  | char X86TileConfig::ID = 0; | 
|  |  | 
|  | INITIALIZE_PASS_BEGIN(X86TileConfig, DEBUG_TYPE, "Tile Register Configure", | 
|  | false, false) | 
|  | INITIALIZE_PASS_DEPENDENCY(VirtRegMapWrapperLegacy) | 
|  | INITIALIZE_PASS_END(X86TileConfig, DEBUG_TYPE, "Tile Register Configure", false, | 
|  | false) | 
|  |  | 
|  | unsigned getAMXRegNum(MachineRegisterInfo *MRI, Register Reg) { | 
|  | if (Reg.isVirtual()) { | 
|  | unsigned RegClassID = MRI->getRegClass(Reg)->getID(); | 
|  | if (RegClassID == X86::TILERegClassID) | 
|  | return 1; | 
|  | if (RegClassID == X86::TILEPAIRRegClassID) | 
|  | return 2; | 
|  | } else { | 
|  | if (Reg >= X86::TMM0 && Reg <= X86::TMM7) | 
|  | return 1; | 
|  | if (Reg >= X86::TMM0_TMM1 && Reg <= X86::TMM6_TMM7) | 
|  | return 2; | 
|  | } | 
|  | return 0; | 
|  | } | 
|  |  | 
|  | static void collectVirtRegShapes(MachineRegisterInfo *MRI, VirtRegMap &VRM, | 
|  | Register VirtReg, | 
|  | SmallVector<ShapeT, 8> &Phys2Shapes) { | 
|  | unsigned Num = getAMXRegNum(MRI, VirtReg); | 
|  | MCRegister PhysReg = VRM.getPhys(VirtReg); | 
|  | if (!PhysReg) | 
|  | return; | 
|  |  | 
|  | if (Num == 1) { | 
|  | unsigned Index = PhysReg - X86::TMM0; | 
|  | if (!Phys2Shapes[Index].isValid()) { | 
|  | ShapeT Shape = VRM.getShape(VirtReg); | 
|  | Phys2Shapes[Index] = std::move(Shape); | 
|  | return; | 
|  | } | 
|  | } | 
|  | // Split tile pair shape info to 2 single tile shape info. e.g: | 
|  | // Put TMM0_TMM1's Shape to TMM0's shape + TMM1's Shape in Phys2Shapes. | 
|  | if (Num == 2) { | 
|  | unsigned Index0 = (PhysReg - X86::TMM0_TMM1) * 2; | 
|  | unsigned Index1 = (PhysReg - X86::TMM0_TMM1) * 2 + 1; | 
|  |  | 
|  | ShapeT Shape = VRM.getShape(VirtReg); | 
|  | assert(Shape.getShapeNum() == 2 && "Unexpected shape number!"); | 
|  |  | 
|  | if (!Phys2Shapes[Index0].isValid()) { | 
|  | ShapeT Shape0(Shape.getRow(0), Shape.getCol(0), MRI); | 
|  | Phys2Shapes[Index0] = std::move(Shape0); | 
|  | } | 
|  |  | 
|  | if (!Phys2Shapes[Index1].isValid()) { | 
|  | ShapeT Shape1(Shape.getRow(1), Shape.getCol(1), MRI); | 
|  | Phys2Shapes[Index1] = std::move(Shape1); | 
|  | } | 
|  | } | 
|  | } | 
|  |  | 
|  | static bool isAMXRegClass(MachineRegisterInfo *MRI, Register Reg) { | 
|  | return getAMXRegNum(MRI, Reg) > 0; | 
|  | } | 
|  |  | 
|  | bool X86TileConfig::runOnMachineFunction(MachineFunction &MF) { | 
|  | X86MachineFunctionInfo *X86FI = MF.getInfo<X86MachineFunctionInfo>(); | 
|  | // Early exit in the common case of non-AMX code. | 
|  | if (X86FI->getAMXProgModel() != AMXProgModelEnum::ManagedRA) | 
|  | return false; | 
|  |  | 
|  | const X86Subtarget &ST = MF.getSubtarget<X86Subtarget>(); | 
|  | const TargetRegisterInfo *TRI = ST.getRegisterInfo(); | 
|  | const TargetInstrInfo *TII = ST.getInstrInfo(); | 
|  | MachineRegisterInfo &MRI = MF.getRegInfo(); | 
|  | LiveIntervals &LIS = getAnalysis<LiveIntervalsWrapperPass>().getLIS(); | 
|  | VirtRegMap &VRM = getAnalysis<VirtRegMapWrapperLegacy>().getVRM(); | 
|  |  | 
|  | if (VRM.isShapeMapEmpty()) | 
|  | return false; | 
|  |  | 
|  | int SS = INT_MAX; | 
|  | for (MachineBasicBlock &MBB : MF) { | 
|  | for (MachineInstr &MI : MBB) { | 
|  | if (MI.getOpcode() == X86::PLDTILECFGV) { | 
|  | SS = MI.getOperand(0).getIndex(); | 
|  | break; | 
|  | } | 
|  | } | 
|  | if (SS != INT_MAX) | 
|  | break; | 
|  | } | 
|  | // Didn't find PLDTILECFGV, just return false; | 
|  | if (SS == INT_MAX) | 
|  | return false; | 
|  |  | 
|  | // Try to find a point to insert MIs for constant shapes. | 
|  | // Here we are leveraging the palette id inserted in PreRA pass. | 
|  | unsigned ConstPos = 0; | 
|  | MachineInstr *ConstMI = nullptr; | 
|  | for (MachineInstr &MI : MF.front()) { | 
|  | if (MI.getOpcode() == X86::MOV8mi && SS == MI.getOperand(0).getIndex()) { | 
|  | ConstMI = &MI; | 
|  | break; | 
|  | } | 
|  | ++ConstPos; | 
|  | } | 
|  | assert(ConstMI && "Cannot find an insertion point"); | 
|  |  | 
|  | unsigned AMXRegNum = TRI->getRegClass(X86::TILERegClassID)->getNumRegs(); | 
|  | SmallVector<ShapeT, 8> Phys2Shapes(AMXRegNum, ShapeT()); | 
|  | for (unsigned I = 0, E = MRI.getNumVirtRegs(); I != E; ++I) { | 
|  | Register VirtReg = Register::index2VirtReg(I); | 
|  | if (MRI.reg_nodbg_empty(VirtReg)) | 
|  | continue; | 
|  | if (!isAMXRegClass(&MRI, VirtReg)) | 
|  | continue; | 
|  | collectVirtRegShapes(&MRI, VRM, VirtReg, Phys2Shapes); | 
|  | } | 
|  |  | 
|  | // Fill in the shape of each tile physical register. | 
|  | for (unsigned I = 0; I < AMXRegNum; ++I) { | 
|  | ShapeT Shape = Phys2Shapes[I]; | 
|  | if (!Shape.isValid()) | 
|  | continue; | 
|  | DebugLoc DL; | 
|  | bool IsRow = true; | 
|  | MachineInstr *NewMI = nullptr; | 
|  | for (auto &R : {Shape.getRow()->getReg(), Shape.getCol()->getReg()}) { | 
|  | // Here is the data format for the tile config. | 
|  | // 0      palette | 
|  | // 1      start_row | 
|  | // 2-15   reserved, must be zero | 
|  | // 16-17  tile0.colsb Tile 0 bytes per row. | 
|  | // 18-19  tile1.colsb Tile 1 bytes per row. | 
|  | // 20-21  tile2.colsb Tile 2 bytes per row. | 
|  | // ... (sequence continues) | 
|  | // 30-31  tile7.colsb Tile 7 bytes per row. | 
|  | // 32-47  reserved, must be zero | 
|  | // 48     tile0.rows Tile 0 rows. | 
|  | // 49     tile1.rows Tile 1 rows. | 
|  | // 50     tile2.rows Tile 2 rows. | 
|  | // ... (sequence continues) | 
|  | // 55     tile7.rows Tile 7 rows. | 
|  | // 56-63  reserved, must be zero | 
|  | int64_t Imm = INT64_MAX; | 
|  | int Offset = IsRow ? 48 + I : 16 + I * 2; | 
|  | for (auto &DefMI : MRI.def_instructions(R)) { | 
|  | MachineBasicBlock &MBB = *DefMI.getParent(); | 
|  | if (DefMI.isMoveImmediate()) { | 
|  | if (Imm != INT64_MAX) { | 
|  | // FIXME: We should handle this case in future. | 
|  | assert(Imm == DefMI.getOperand(1).getImm() && | 
|  | "Cannot initialize with different shapes"); | 
|  | continue; | 
|  | } | 
|  | if (DefMI.getOperand(1).isImm()) { | 
|  | Imm = DefMI.getOperand(1).getImm(); | 
|  | } else { | 
|  | assert(DefMI.getOpcode() == X86::MOV32r0 && | 
|  | "The opcode is assumed to be MOV32r0 if the operand is not " | 
|  | "immediate."); | 
|  | Imm = 0; | 
|  | } | 
|  |  | 
|  | NewMI = addFrameReference( | 
|  | BuildMI(MF.front(), ++ConstMI->getIterator(), DL, | 
|  | TII->get(IsRow ? X86::MOV8mi : X86::MOV16mi)), | 
|  | SS, Offset) | 
|  | .addImm(Imm); | 
|  | ConstMI = NewMI; | 
|  | LIS.InsertMachineInstrInMaps(*NewMI); | 
|  | } else { | 
|  | unsigned SubIdx = IsRow ? X86::sub_8bit : X86::sub_16bit; | 
|  | unsigned RegSize = TRI->getRegSizeInBits(*MRI.getRegClass(R)); | 
|  | if ((IsRow && RegSize == 8) || (!IsRow && RegSize == 16)) | 
|  | SubIdx = 0; | 
|  | auto Iter = DefMI.getIterator(); | 
|  | if (&MBB == &MF.front() && | 
|  | (unsigned)std::distance(MBB.instr_begin(), Iter) < ConstPos) | 
|  | Iter = ConstMI->getIterator(); | 
|  | NewMI = addFrameReference( | 
|  | BuildMI(MBB, ++Iter, DL, | 
|  | TII->get(IsRow ? X86::MOV8mr : X86::MOV16mr)), | 
|  | SS, Offset) | 
|  | .addReg(R, 0, SubIdx); | 
|  | SlotIndex SIdx = LIS.InsertMachineInstrInMaps(*NewMI); | 
|  | LIS.extendToIndices(LIS.getInterval(R), {SIdx.getRegSlot()}); | 
|  | } | 
|  | } | 
|  | IsRow = false; | 
|  | } | 
|  | } | 
|  | return true; | 
|  | } | 
|  |  | 
|  | FunctionPass *llvm::createX86TileConfigPass() { return new X86TileConfig(); } |