| from typing import Optional, List |
| |
| |
# Base class for an instruction. To implement a basic instruction that doesn't
# impact the control-flow, create a new class inheriting from this.
class Instruction:
    """One line of SPIR-V assembly, split into result, opcode and operands.

    Accepted line shapes (tokens are separated by whitespace):
        <result> = <opcode> [<operand> ...]
        <opcode> [<operand> ...]

    Raises ValueError from the constructor when the line is empty or has a
    '=' but no opcode after it.
    """

    # Contains the name of the output register, if any.
    _result: Optional[str]
    # Contains the instruction opcode.
    _opcode: str
    # Contains all the instruction operands, except result and opcode.
    _operands: List[str]

    def __init__(self, line: str):
        # Keep the raw text; subclasses may need more than the split tokens.
        self.line = line
        tokens = line.split()
        if not tokens:
            raise ValueError(f"Empty instruction line: {line!r}")
        if len(tokens) > 1 and tokens[1] == "=":
            if len(tokens) < 3:
                raise ValueError(f"Missing opcode after '=': {line!r}")
            self._result = tokens[0]
            self._opcode = tokens[2]
            # Slicing past the end yields [], so no length check is needed.
            self._operands = tokens[3:]
        else:
            self._result = None
            self._opcode = tokens[0]
            self._operands = tokens[1:]

    def __str__(self):
        if self._result is None:
            return f" {self._opcode} {self._operands}"
        return f"{self._result:3} = {self._opcode} {self._operands}"

    def opcode(self) -> str:
        """Return the instruction opcode (e.g. "OpIAdd")."""
        return self._opcode

    def operands(self) -> List[str]:
        """Return the operands, excluding the result register and opcode."""
        return self._operands

    def output_register(self) -> str:
        """Return the output register.

        Calling this is only allowed if has_output_register() is true.
        """
        assert self._result is not None
        return self._result

    def has_output_register(self) -> bool:
        """Return True if this instruction has an output register."""
        return self._result is not None

    def static_execution(self, lane):
        """Initialize state for this instruction before module execution.

        For example, global Input variables can use this to store the lane
        ID into their register. The default implementation does nothing.
        """

    def runtime_execution(self, module, lane):
        """Execute this instruction; called every time a tangle runs it.

        Do not override this directly; override _impl and/or _advance_ip.
        """
        self._impl(module, lane)
        self._advance_ip(module, lane)

    def _impl(self, module, lane):
        """Instruction logic; must be overridden by executable instructions.

        'Static' instructions like OpConstant should not override this,
        since they are not supposed to be executed at runtime.

        Raises:
            RuntimeError: always, in this base implementation.
        """
        raise RuntimeError(f"Unimplemented instruction {self}")

    def _advance_ip(self, module, lane):
        """Point `lane` at the next instruction.

        Instructions that modify the IP themselves (like OpBranch) must
        override this.
        """
        lane.set_ip(lane.ip() + 1)
| |
| |
# These instructions are parsed but never executed. None of them overrides
# _impl, so running one would hit the base class "Unimplemented" error.
class OpEntryPoint(Instruction):
    """OpEntryPoint: parsed only, never executed."""


class OpFunction(Instruction):
    """OpFunction: parsed only, never executed."""


class OpFunctionEnd(Instruction):
    """OpFunctionEnd: parsed only, never executed."""


class OpLabel(Instruction):
    """OpLabel: parsed only, never executed."""


class OpVariable(Instruction):
    """OpVariable: parsed only, never executed."""
| |
| |
class OpName(Instruction):
    """OpName <target> "<name>": attaches a debug name to a register."""

    def name(self) -> str:
        """Return the name literal without its surrounding double quotes.

        The operands come from a whitespace split, so a name that contains
        spaces is spread over several tokens. Read the literal from the raw
        line instead (between the first and last double quote). If the line
        has no quoted literal, fall back to stripping the first and last
        character of operand 1.
        """
        first = self.line.find('"')
        last = self.line.rfind('"')
        if first != -1 and last > first:
            return self.line[first + 1 : last]
        return self._operands[1][1:-1]

    def decoratedRegister(self) -> str:
        """Return the register (result id) being named."""
        return self._operands[0]
| |
| |
# The only decorations handled are LinkageAttributes (ignored) and the
# BuiltIn SubgroupLocalInvocationId one, used to initialize the values.
class OpDecorate(Instruction):
    """OpDecorate <target> <decoration> [<decoration operands> ...]."""

    def static_execution(self, lane):
        """Seed the decorated register with the lane ID, if applicable.

        Raises:
            RuntimeError: for any decoration other than LinkageAttributes or
                BuiltIn SubgroupLocalInvocationId.
        """
        target, decoration = self._operands[0], self._operands[1]
        if decoration == "LinkageAttributes":
            return
        # An explicit raise (not assert) keeps unsupported decorations from
        # silently writing the lane ID into `target` under `python -O`.
        builtin = self._operands[2:3]
        if decoration != "BuiltIn" or builtin != ["SubgroupLocalInvocationId"]:
            raise RuntimeError(f"Unsupported decoration: {self}")
        lane.set_register(target, lane.tid())
| |
| |
| # Constants |
class OpConstant(Instruction):
    """%r = OpConstant <type> <literal>: seeds %r with the integer literal."""

    def static_execution(self, lane):
        literal = self._operands[1]
        lane.set_register(self._result, int(literal))
| |
| |
class OpConstantTrue(OpConstant):
    """%r = OpConstantTrue <type>: seeds %r with True."""

    def static_execution(self, lane):
        value = True
        lane.set_register(self._result, value)
| |
| |
class OpConstantFalse(OpConstant):
    """%r = OpConstantFalse <type>: seeds %r with False."""

    def static_execution(self, lane):
        value = False
        lane.set_register(self._result, value)
| |
| |
class OpConstantComposite(OpConstant):
    """%r = OpConstantComposite <type> <constituent> ...

    Stores a list of the constituents' register values in %r. The
    constituents must already have been initialized.
    """

    def static_execution(self, lane):
        members = [lane.get_register(reg) for reg in self._operands[1:]]
        lane.set_register(self._result, members)
| |
| |
| # Control flow instructions |
class OpFunctionCall(Instruction):
    """%r = OpFunctionCall <type> <function> [<argument> ...]

    All the work is a control-flow change: the lane jumps to the callee's
    entry via lane.do_call, which also receives %r (presumably as the
    destination of the return value).
    NOTE(review): the call arguments (operands[2:]) are not forwarded here.
    """

    def _impl(self, module, lane):
        # Nothing to compute before the jump.
        return

    def _advance_ip(self, module, lane):
        callee = self._operands[1]
        lane.do_call(module.get_function_entry(callee), self._result)
| |
| |
class OpReturn(Instruction):
    """OpReturn: return from a void function via lane.do_return(None)."""

    def _impl(self, module, lane):
        return

    def _advance_ip(self, module, lane):
        # A void return has no value to hand back.
        lane.do_return(None)
| |
| |
class OpReturnValue(Instruction):
    """OpReturnValue <value>: return the value held in register <value>."""

    def _impl(self, module, lane):
        return

    def _advance_ip(self, module, lane):
        returned = lane.get_register(self._operands[0])
        lane.do_return(returned)
| |
| |
class OpBranch(Instruction):
    """OpBranch <label>: unconditional jump to the start of <label>."""

    def _impl(self, module, lane):
        return

    def _advance_ip(self, module, lane):
        target = module.get_bb_entry(self._operands[0])
        lane.set_ip(target)
| |
| |
class OpBranchConditional(Instruction):
    """OpBranchConditional <cond> <true label> <false label> [weights...]

    Jumps to <true label> when the condition register is truthy, otherwise
    to <false label>. Branch weights, if any, are ignored.
    """

    def _impl(self, module, lane):
        return

    def _advance_ip(self, module, lane):
        taken = lane.get_register(self._operands[0])
        label = self._operands[1] if taken else self._operands[2]
        lane.set_ip(module.get_bb_entry(label))
| |
| |
class OpSwitch(Instruction):
    """OpSwitch <selector> <default> [<literal> <label>] ...

    Jumps to the label of the first literal equal to the selector register,
    or to <default> when none matches. Literals are parsed with int().
    """

    def _impl(self, module, lane):
        return

    def _advance_ip(self, module, lane):
        selector = lane.get_register(self._operands[0])
        target = self._operands[1]
        for idx in range(2, len(self._operands), 2):
            literal, label = int(self._operands[idx]), self._operands[idx + 1]
            if selector == literal:
                target = label
                break
        lane.set_ip(module.get_bb_entry(target))
| |
| |
class OpUnreachable(Instruction):
    """OpUnreachable: executing this instruction is always an error."""

    def _impl(self, module, lane):
        message = "This instruction should never be executed."
        raise RuntimeError(message)
| |
| |
| # Convergence instructions |
class MergeInstruction(Instruction):
    """Common base for OpLoopMerge and OpSelectionMerge.

    Operand 0 is the merge block label. When at least three operands are
    present (as in OpLoopMerge <merge> <continue> <control>), operand 1 is
    the continue target; otherwise there is none.
    """

    def merge_location(self):
        return self._operands[0]

    def continue_location(self):
        if len(self._operands) >= 3:
            return self._operands[1]
        return None

    def _impl(self, module, lane):
        # Hand the header to the lane's convergence bookkeeping.
        lane.handle_convergence_header(self)
| |
| |
class OpLoopMerge(MergeInstruction):
    """OpLoopMerge <merge> <continue> <loop control>: loop header marker."""


class OpSelectionMerge(MergeInstruction):
    """OpSelectionMerge <merge> <selection control>: selection header marker."""
| |
| |
| # Other instructions |
class OpBitcast(Instruction):
    """%r = OpBitcast <type> <operand>

    Only the %int result type is supported; the value goes through int().
    """

    def _impl(self, module, lane):
        # TODO: find out the type from the defining instruction.
        # This can only work for DXC (it relies on the type id being "%int").
        if self._operands[0] != "%int":
            raise RuntimeError("Unsupported OpBitcast operand")
        source = lane.get_register(self._operands[1])
        lane.set_register(self._result, int(source))
| |
| |
class OpAccessChain(Instruction):
    """%r = OpAccessChain <type> <base> <index id> ...

    Walks into the base value by subscripting it with each index register's
    value. Python's dynamic typing makes this a plain chain of [] lookups;
    any indexable composite (list or tuple) works, and as long as the SPIR-V
    is legal this should be fine.
    NOTE(review): %r receives the selected value itself, not a reference
    into the base, so a later OpStore to %r presumably leaves the base
    composite untouched.
    """

    def _impl(self, module, lane):
        current = lane.get_register(self._operands[1])
        for index_reg in self._operands[2:]:
            current = current[lane.get_register(index_reg)]
        lane.set_register(self._result, current)
| |
| |
class OpCompositeConstruct(Instruction):
    """%r = OpCompositeConstruct <type> <constituent> ...

    Stores a list of the constituents' register values in %r.
    """

    def _impl(self, module, lane):
        parts = [lane.get_register(reg) for reg in self._operands[1:]]
        lane.set_register(self._result, parts)
| |
| |
class OpCompositeExtract(Instruction):
    """%r = OpCompositeExtract <type> <composite> <literal index> ...

    Unlike OpAccessChain, the indices are integer literals, not registers.
    """

    def _impl(self, module, lane):
        element = lane.get_register(self._operands[1])
        for literal in self._operands[2:]:
            element = element[int(literal)]
        lane.set_register(self._result, element)
| |
| |
class OpStore(Instruction):
    """OpStore <pointer> <object>: copy register <object> into <pointer>."""

    def _impl(self, module, lane):
        pointer, obj = self._operands[0], self._operands[1]
        lane.set_register(pointer, lane.get_register(obj))
| |
| |
class OpLoad(Instruction):
    """%r = OpLoad <type> <pointer>: copy register <pointer> into %r."""

    def _impl(self, module, lane):
        loaded = lane.get_register(self._operands[1])
        lane.set_register(self._result, loaded)
| |
| |
class OpIAdd(Instruction):
    """%r = OpIAdd <type> <a> <b>: %r = a + b.

    NOTE(review): Python ints do not wrap, so integer overflow is not modelled.
    """

    def _impl(self, module, lane):
        a = lane.get_register(self._operands[1])
        b = lane.get_register(self._operands[2])
        lane.set_register(self._result, a + b)
| |
| |
class OpISub(Instruction):
    """%r = OpISub <type> <a> <b>: %r = a - b.

    NOTE(review): Python ints do not wrap, so underflow is not modelled.
    """

    def _impl(self, module, lane):
        a = lane.get_register(self._operands[1])
        b = lane.get_register(self._operands[2])
        lane.set_register(self._result, a - b)
| |
| |
class OpIMul(Instruction):
    """%r = OpIMul <type> <a> <b>: %r = a * b.

    NOTE(review): Python ints do not wrap, so overflow is not modelled.
    """

    def _impl(self, module, lane):
        a = lane.get_register(self._operands[1])
        b = lane.get_register(self._operands[2])
        lane.set_register(self._result, a * b)
| |
| |
class OpLogicalNot(Instruction):
    """%r = OpLogicalNot <type> <operand>: %r = not operand."""

    def _impl(self, module, lane):
        operand = lane.get_register(self._operands[1])
        lane.set_register(self._result, not operand)
| |
| |
class _LessThan(Instruction):
    """Shared body of the less-than comparisons: %r = a < b.

    NOTE(review): the signed and unsigned variants both use Python ordering
    on the stored values; negative values are not reinterpreted as unsigned.
    """

    def _impl(self, module, lane):
        a = lane.get_register(self._operands[1])
        b = lane.get_register(self._operands[2])
        lane.set_register(self._result, a < b)
| |
| |
class _GreaterThan(Instruction):
    """Shared body of the greater-than comparisons: %r = a > b.

    NOTE(review): the signed and unsigned variants both use Python ordering
    on the stored values; negative values are not reinterpreted as unsigned.
    """

    def _impl(self, module, lane):
        a = lane.get_register(self._operands[1])
        b = lane.get_register(self._operands[2])
        lane.set_register(self._result, a > b)
| |
| |
class OpSLessThan(_LessThan):
    """%r = OpSLessThan <type> <a> <b> (signed a < b)."""


class OpULessThan(_LessThan):
    """%r = OpULessThan <type> <a> <b> (unsigned a < b)."""


class OpSGreaterThan(_GreaterThan):
    """%r = OpSGreaterThan <type> <a> <b> (signed a > b)."""


class OpUGreaterThan(_GreaterThan):
    """%r = OpUGreaterThan <type> <a> <b> (unsigned a > b)."""
| |
| |
class OpIEqual(Instruction):
    """%r = OpIEqual <type> <a> <b>: %r = (a == b)."""

    def _impl(self, module, lane):
        a = lane.get_register(self._operands[1])
        b = lane.get_register(self._operands[2])
        lane.set_register(self._result, a == b)
| |
| |
class OpINotEqual(Instruction):
    """%r = OpINotEqual <type> <a> <b>: %r = (a != b)."""

    def _impl(self, module, lane):
        a = lane.get_register(self._operands[1])
        b = lane.get_register(self._operands[2])
        lane.set_register(self._result, a != b)
| |
| |
class OpPhi(Instruction):
    """%r = OpPhi <type> <value0> <label0> <value1> <label1> ...

    Copies into %r the value paired with the basic block the lane came from,
    as reported by lane.get_previous_bb_name().

    Raises:
        RuntimeError: if that block is not among the listed labels.
    """

    def _impl(self, module, lane):
        came_from = lane.get_previous_bb_name()
        for pos in range(1, len(self._operands), 2):
            if self._operands[pos + 1] == came_from:
                lane.set_register(self._result, lane.get_register(self._operands[pos]))
                return
        raise RuntimeError("previousBB not in the OpPhi _operands")
| |
| |
class OpSelect(Instruction):
    """%r = OpSelect <type> <cond> <a> <b>: %r = a if cond else b.

    Only the chosen register is read.
    """

    def _impl(self, module, lane):
        condition = lane.get_register(self._operands[1])
        chosen = self._operands[2] if condition else self._operands[3]
        lane.set_register(self._result, lane.get_register(chosen))
| |
| |
| # Wave intrinsics |
class OpGroupNonUniformBroadcastFirst(Instruction):
    """%r = OpGroupNonUniformBroadcastFirst <type> <scope> <value>

    Only Subgroup scope is supported. The first active lane passes its
    <value> to lane.broadcast_register for %r; other lanes do nothing here.

    Raises:
        RuntimeError: if the scope register does not hold Subgroup scope.
    """

    # SPIR-V Scope enumerant value for Subgroup.
    _SUBGROUP_SCOPE = 3

    def _impl(self, module, lane):
        scope = lane.get_register(self._operands[1])
        # Explicit raise (not assert) so the check survives `python -O`.
        if scope != self._SUBGROUP_SCOPE:
            raise RuntimeError(f"Unsupported scope {scope} in {self}")
        if lane.is_first_active_lane():
            lane.broadcast_register(self._result, lane.get_register(self._operands[2]))
| |
| |
class OpGroupNonUniformElect(Instruction):
    """%r = OpGroupNonUniformElect <type> <scope>

    %r is the result of lane.is_first_active_lane(). The scope operand is
    not checked.
    """

    def _impl(self, module, lane):
        elected = lane.is_first_active_lane()
        lane.set_register(self._result, elected)