| /* |
| * Copyright 2023 Intel Corporation |
| * SPDX-License-Identifier: MIT |
| */ |
| |
| /** |
| * \file brw_nir_lower_cooperative_matrix.c |
| * Lower cooperative matrix to subgroup operations. |
| * |
| * All supported matrix types are assumed to have either 8 rows or 8 |
| * columns. The other dimension of the matrix is typically 8 times the number |
| * of data elements that can be stored in a 32-bit dword. Matrix data is |
| * indexed by a combination of an array element and a subgroup invocation ID. |
| * |
| * Two layouts for matrix data are used. In the first layout, |
| * subgroupShuffle(slice[N], ...) accesses row N of the matrix. This will be |
| * called row-major hereafter. In the other layout, |
| * subgroupShuffle(slice[...], M) accesses column M of the matrix. This will |
| * be called column-major hereafter. In cases where a single 32-bit value is |
| * stored in each entry, these layouts are identical. |
| * |
| * The subtle difference arises when multiple values are packed into a single |
| * 32-bit dword. If two 16-bit values are packed in a single 32-bit value in |
 * column-major, subgroupShuffle(slice[0], 1) holds matrix entries m[0][1] and
 * m[1][1] (in m[row][column] notation). In row-major, that same shuffle holds
| * m[0][2] and m[0][3]. |
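 *
 * As a concrete sketch (assuming, purely for illustration, that the low half
 * of each dword holds the first of the two packed 16-bit values), slice[N]
 * shuffled from invocation M contains:
 *
 *                     low half      high half
 *    row-major        m[N][2M]      m[N][2M+1]
 *    column-major     m[2N][M]      m[2N+1][M]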
| * |
| * There is an alternate way to think about the matrix layouts. Every matrix |
| * size supported by the Intel driver is either Sx8 (e.g., 16x8 for float16 B |
 * matrix) or 8xS (e.g., 8x32 for int8 A matrix). The A matrix and B matrix
 * layouts are such that a single 8-dword register holds an entire row of the
| * matrix. |
| * |
| * Consider a matrix stored starting in register g32. In an A matrix, the |
| * packed dwords of g32 contain only the data for a single row of the |
| * matrix. g32 is row 0, g33 is row 1, etc. In a B matrix, the packed dwords |
| * of g(32+N).X contain only the data for a single column of the |
| * matrix. g[32:40].0 is column 0, g[32:40].1 is column 1, etc. |
| * |
| * This leads to some shenanigans in \c lower_cmat_load_store. |
| * |
| * In the common case, A, C, and result matrices are stored row major while B |
| * matrices are stored column major. This arrangement facilitates efficient |
| * dot product operations using DPAS or DP4A instructions. |
| * |
| * Future optimizations are possible when row and column major are |
| * flipped. That is, efficient dot products are also possible when A, C, and |
| * result matrices are column major while B is row major. |
| */ |
| |
| #include "brw_nir.h" |
| |
| typedef struct { |
   /* Vector type that holds the packed matrix elements. */
| const glsl_type *type; |
| |
| /* How many cmat elements per slice element. */ |
| unsigned packing_factor; |
| |
| struct glsl_cmat_description desc; |
| |
   /* Used by the lookup tables. Variable holding a slice or an
    * array-of-arrays of slices.
    *
    * If present, var->type (with arrays stripped) should match
    * the type above.
    */
| nir_variable *var; |
| } slice_info; |
| |
| #define BRW_MAX_PACKING_FACTOR 4 |
| |
| struct lower_cmat_state { |
| void *temp_ctx; |
| |
| nir_shader *shader; |
| |
| struct hash_table *slice_var_to_slice_info; |
| |
| struct hash_table *mat_var_to_slice_info; |
| |
| unsigned subgroup_size; |
| |
| struct { |
| nir_def *tmp[NIR_MAX_VEC_COMPONENTS * BRW_MAX_PACKING_FACTOR]; |
| } scratch; |
| }; |
| |
| static bool |
| cmat_descriptions_are_equal(struct glsl_cmat_description a, |
| struct glsl_cmat_description b) |
| { |
| return a.element_type == b.element_type && |
| a.scope == b.scope && |
| a.rows == b.rows && |
| a.cols == b.cols && |
| a.use == b.use; |
| } |
| |
| static void |
| print_coop_types(struct lower_cmat_state *state) |
| { |
| fprintf(stderr, "--- Slices to Cooperative Matrix type table\n"); |
| hash_table_foreach(state->slice_var_to_slice_info, e) { |
| nir_variable *var = (void *)e->key; |
| const slice_info *info = e->data; |
| fprintf(stderr, "%p: %s -> %s\n", var, var->name, |
| glsl_get_type_name(glsl_cmat_type(&info->desc))); |
| } |
| fprintf(stderr, "\n\n"); |
| } |
| |
| static const slice_info * |
| get_slice_info(struct lower_cmat_state *state, nir_deref_instr *deref) |
| { |
| nir_variable *var = nir_deref_instr_get_variable(deref); |
| struct hash_entry *entry = _mesa_hash_table_search(state->slice_var_to_slice_info, var); |
| |
| assert(entry != NULL); |
| |
| return entry->data; |
| } |
| |
| static bool |
| lower_cmat_filter(const nir_instr *instr, const void *_state) |
| { |
| if (instr->type == nir_instr_type_deref) { |
| nir_deref_instr *deref = nir_instr_as_deref(instr); |
| return glsl_type_is_cmat(deref->type); |
| } |
| |
| if (instr->type != nir_instr_type_intrinsic) |
| return false; |
| |
| nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); |
| switch (intrin->intrinsic) { |
| case nir_intrinsic_cmat_construct: |
| case nir_intrinsic_cmat_load: |
| case nir_intrinsic_cmat_store: |
| case nir_intrinsic_cmat_length: |
| case nir_intrinsic_cmat_muladd: |
| case nir_intrinsic_cmat_convert: |
| case nir_intrinsic_cmat_unary_op: |
| case nir_intrinsic_cmat_binary_op: |
| case nir_intrinsic_cmat_scalar_op: |
| case nir_intrinsic_cmat_bitcast: |
| case nir_intrinsic_cmat_insert: |
| case nir_intrinsic_cmat_extract: |
| case nir_intrinsic_cmat_copy: |
| return true; |
| |
| default: |
| return false; |
| } |
| } |
| |
| static void |
| init_slice_info(struct lower_cmat_state *state, |
| struct glsl_cmat_description desc, |
| slice_info *info) |
| { |
| enum glsl_base_type base_type; |
| |
| /* Number of matrix elements stored by each subgroup invocation. If the |
| * data is packed, the slice size will be less than this. |
| */ |
| const unsigned elements_per_invocation = |
| (desc.rows * desc.cols) / state->subgroup_size; |
| |
| assert(elements_per_invocation > 0); |
| |
| const unsigned element_bits = 32; |
| const unsigned bits = glsl_base_type_get_bit_size(desc.element_type); |
| |
| /* Each invocation must have at least one dword of data, and that dword |
| * must be tightly packed with values. No matter the matrix dimensions, a |
| * matrix of uint8_t data must pack 4 values in each entry. |
| */ |
| const unsigned packing_factor = element_bits / bits; |
| assert(packing_factor <= BRW_MAX_PACKING_FACTOR); |
| |
| assert(elements_per_invocation >= packing_factor); |
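
   /* For example, float16 data gives packing_factor == 32 / 16 == 2, while
    * uint8_t data gives packing_factor == 32 / 8 == 4.
    */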
| |
| switch (desc.element_type) { |
| case GLSL_TYPE_FLOAT: |
| base_type = GLSL_TYPE_FLOAT; |
| break; |
| case GLSL_TYPE_UINT: |
| case GLSL_TYPE_FLOAT16: |
| case GLSL_TYPE_BFLOAT16: |
| case GLSL_TYPE_UINT8: |
| case GLSL_TYPE_UINT16: |
| base_type = GLSL_TYPE_UINT; |
| break; |
| case GLSL_TYPE_INT: |
| case GLSL_TYPE_INT8: |
| case GLSL_TYPE_INT16: |
| base_type = GLSL_TYPE_INT; |
| break; |
| default: |
| unreachable("Invalid cooperative matrix element type."); |
| } |
| |
| unsigned len = elements_per_invocation / packing_factor; |
| |
| /* Supported matrix sizes are designed to fill either 4 or 8 SIMD8 |
| * registers on DG2. That means: |
| * |
| * 4 regsiters 8 registers |
| * SIMD32 len = 1 len = 2 |
| * SIMD16 len = 2 len = 4 |
| * SIMD8 len = 4 len = 8 |
| * |
| * On Xe2, supported matrix sizes are still designed to fill 4 registers |
| * (e.g., 8x32 uint8_t) or 8 registers (e.g., 16x16 float16). However, the |
| * 16x16 float16 matrix will assign 16 elements per channel at SIMD16. |
| */ |
| assert(len == 1 || len == 2 || len == 4 || len == 8 || len == 16); |
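
   /* Worked example: an 8x32 int8 A matrix with a subgroup size of 16 has
    * elements_per_invocation == (8 * 32) / 16 == 16 and packing_factor == 4,
    * giving len == 4 (the 8-register SIMD16 case in the table above).
    */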
| |
| const struct glsl_type *slice_type = glsl_vector_type(base_type, len); |
| |
| info->type = slice_type; |
| info->desc = desc; |
| info->packing_factor = packing_factor; |
| } |
| |
| static void |
| lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intrin, |
| struct lower_cmat_state *state) |
| { |
| const bool load = intrin->intrinsic == nir_intrinsic_cmat_load; |
| const unsigned mat_src = load ? 0 : 1; |
| const unsigned ptr_src = load ? 1 : 0; |
| |
| nir_deref_instr *slice = nir_src_as_deref(intrin->src[mat_src]); |
| const slice_info *info = get_slice_info(state, slice); |
| const struct glsl_cmat_description desc = info->desc; |
| |
| nir_def *results[NIR_MAX_VEC_COMPONENTS]; |
| const unsigned num_components = glsl_get_vector_elements(slice->type); |
| |
| nir_deref_instr *pointer = nir_src_as_deref(intrin->src[ptr_src]); |
| const unsigned ptr_comp_width = glsl_get_bit_size(pointer->type); |
| const unsigned ptr_num_comps = glsl_get_vector_elements(pointer->type); |
| |
| /* The stride is given in number of elements of the pointed type, which |
| * doesn't necessarily match the matrix element type, so we need to adjust |
| * it considering it may be a vector and have a different bit-width. |
| */ |
| nir_def *stride = nir_udiv_imm(b, |
| nir_imul_imm(b, |
| intrin->src[2].ssa, |
| ptr_comp_width * ptr_num_comps), |
| glsl_base_type_get_bit_size(desc.element_type)); |
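
   /* For example (with illustrative types): loading a float16 matrix through
    * a uvec2 pointer scales the stride by (32 * 2) / 16, so a stride of S
    * uvec2 elements becomes 4 * S float16 elements.
    */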
| |
| /* The data that will be packed is in successive columns for A and |
| * accumulator matrices. The data that will be packed for B matrices is in |
| * successive rows. |
| */ |
| const unsigned cols = |
| desc.use != GLSL_CMAT_USE_B ? desc.cols / info->packing_factor : desc.cols; |
| |
| nir_def *invocation = nir_load_subgroup_invocation(b); |
| nir_def *invocation_div_cols = nir_udiv_imm(b, invocation, cols); |
| nir_def *invocation_mod_cols = nir_umod_imm(b, invocation, cols); |
| |
| nir_def *i_stride; |
| |
| const bool memory_layout_matches_register_layout = |
| (nir_intrinsic_matrix_layout(intrin) == GLSL_MATRIX_LAYOUT_ROW_MAJOR) == |
| (desc.use != GLSL_CMAT_USE_B); |
| |
| if (memory_layout_matches_register_layout) { |
| /* In the row-major arrangement, data is loaded a dword at a time |
| * instead of a single element at a time. For this reason the stride is |
| * divided by the packing factor. |
| */ |
| i_stride = nir_udiv_imm(b, stride, info->packing_factor); |
| } else { |
| /* In the column-major arrangement, data is loaded a single element at a |
| * time. Because the data elements are transposed, the step direction |
| * that moves a single (packed) element in the row-major arrangement has |
| * to explicitly step over the packing factor count of elements. For |
| * this reason the stride is multiplied by the packing factor. |
| * |
| * NOTE: The unscaled stride is also still needed when stepping from one |
| * packed element to the next. This occurs in the for-j loop below. |
| */ |
| i_stride = nir_imul_imm(b, stride, info->packing_factor); |
| } |
| |
| nir_def *base_offset; |
| nir_def *i_step; |
| |
| if (nir_intrinsic_matrix_layout(intrin) == GLSL_MATRIX_LAYOUT_ROW_MAJOR) { |
| base_offset = nir_iadd(b, |
| nir_imul(b, |
| invocation_div_cols, |
| i_stride), |
| invocation_mod_cols); |
| |
| i_step = nir_imul_imm(b, i_stride, state->subgroup_size / cols); |
| } else { |
| base_offset = nir_iadd(b, |
| nir_imul(b, |
| invocation_mod_cols, |
| i_stride), |
| invocation_div_cols); |
| |
| i_step = nir_imm_int(b, state->subgroup_size / cols); |
| } |
| |
| if (memory_layout_matches_register_layout) { |
| const struct glsl_type *element_type = |
| glsl_scalar_type(glsl_get_base_type(slice->type)); |
| |
| pointer = nir_build_deref_cast(b, &pointer->def, pointer->modes, |
| element_type, |
| glsl_get_bit_size(element_type) / 8); |
| |
| for (unsigned i = 0; i < num_components; i++) { |
| nir_def *offset = nir_imul_imm(b, i_step, i); |
| |
| nir_deref_instr *memory_deref = |
| nir_build_deref_ptr_as_array(b, pointer, |
| nir_i2iN(b, |
| nir_iadd(b, |
| base_offset, |
| offset), |
| pointer->def.bit_size)); |
| |
| if (load) { |
| results[i] = nir_load_deref(b, memory_deref); |
| } else { |
| nir_def *src = nir_channel(b, nir_load_deref(b, slice), i); |
| nir_store_deref(b, memory_deref, src, 0x1); |
| } |
| } |
| } else { |
| const struct glsl_type *element_type = glsl_scalar_type(desc.element_type); |
| const unsigned element_bits = glsl_base_type_get_bit_size(desc.element_type); |
| const unsigned element_stride = element_bits / 8; |
| |
| pointer = nir_build_deref_cast(b, &pointer->def, pointer->modes, element_type, |
| element_stride); |
| |
| for (unsigned i = 0; i < num_components; i++) { |
| nir_def *i_offset = nir_imul_imm(b, i_step, i); |
         nir_def *v[BRW_MAX_PACKING_FACTOR];
| |
| for (unsigned j = 0; j < info->packing_factor; j++) { |
| nir_def *offset = nir_iadd(b, nir_imul_imm(b, stride, j), i_offset); |
| |
| nir_deref_instr *memory_deref = |
| nir_build_deref_ptr_as_array(b, pointer, |
| nir_i2iN(b, |
| nir_iadd(b, |
| base_offset, |
| offset), |
| pointer->def.bit_size)); |
| |
| if (load) { |
| v[j] = nir_load_deref(b, memory_deref); |
| } else { |
| nir_def *src = nir_channel(b, nir_load_deref(b, slice), i); |
| |
               nir_def *elem =
                  nir_channel(b, nir_unpack_bits(b, src, element_bits), j);

               nir_store_deref(b, memory_deref, elem, 0x1);
| } |
| } |
| |
| if (load) { |
| results[i] = nir_pack_bits(b, nir_vec(b, v, info->packing_factor), |
| info->packing_factor * element_bits); |
| } |
| } |
| } |
| |
| if (load) |
| nir_store_deref(b, slice, nir_vec(b, results, num_components), |
| nir_component_mask(num_components)); |
| } |
| |
| /* Unpack, apply operation, then pack again. */ |
| static nir_def * |
| emit_packed_alu1(nir_builder *b, |
| struct lower_cmat_state *state, |
| const slice_info *src_info, |
| const slice_info *dst_info, |
| nir_op op, |
| nir_def *src) |
| { |
   const unsigned dst_bits = glsl_base_type_get_bit_size(dst_info->desc.element_type);
   const unsigned src_bits = glsl_base_type_get_bit_size(src_info->desc.element_type);
| |
| const unsigned src_components = glsl_get_vector_elements(src_info->type); |
| const unsigned dst_components = glsl_get_vector_elements(dst_info->type); |
| assert(src_components * src_info->packing_factor == |
| dst_components * dst_info->packing_factor); |
| |
| /* Store the result of all individual unpacked values. */ |
| assert(src_components * src_info->packing_factor <= ARRAY_SIZE(state->scratch.tmp)); |
| nir_def **tmp = state->scratch.tmp; |
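
   /* For example, a float16 -> float32 conversion has packing factors 2 and
    * 1: each source channel unpacks into two 16-bit values, each value is
    * converted, and the results land in twice as many destination channels.
    */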
| |
| for (unsigned i = 0; i < src_components; i++) { |
| nir_def *chan = nir_channel(b, src, i); |
| |
| for (unsigned j = 0; j < src_info->packing_factor; j++) { |
| const unsigned pos = (i * src_info->packing_factor) + j; |
| nir_def *val = nir_channel(b, nir_unpack_bits(b, chan, src_bits), j); |
| tmp[pos] = nir_build_alu1(b, op, val); |
| } |
| } |
| |
   /* Build each element of the result; each may pack multiple values. */
| nir_def *results[NIR_MAX_VEC_COMPONENTS] = {}; |
| assert(dst_components <= ARRAY_SIZE(results)); |
| |
   /* Unpacked values for one destination element, to be packed together
    * into results.
    */
| nir_def *partial[BRW_MAX_PACKING_FACTOR]; |
| |
| for (unsigned i = 0; i < dst_components; i++) { |
| for (unsigned j = 0; j < dst_info->packing_factor; j++) { |
| const unsigned pos = (i * dst_info->packing_factor) + j; |
| partial[j] = tmp[pos]; |
| } |
| |
| results[i] = |
| nir_pack_bits(b, nir_vec(b, partial, dst_info->packing_factor), |
| dst_info->packing_factor * dst_bits); |
| } |
| |
| return nir_vec(b, results, dst_components); |
| } |
| |
| static nir_op |
| get_cmat_conversion_op(enum glsl_base_type src, |
| enum glsl_base_type dst) |
| { |
| if (src == GLSL_TYPE_BFLOAT16) { |
| assert(dst == GLSL_TYPE_FLOAT); |
| return nir_op_bf2f; |
| |
| } else if (dst == GLSL_TYPE_BFLOAT16) { |
| assert(src == GLSL_TYPE_FLOAT); |
| return nir_op_f2bf; |
| |
| } else { |
| return nir_type_conversion_op(nir_get_nir_type_for_glsl_base_type(src), |
| nir_get_nir_type_for_glsl_base_type(dst), |
| nir_rounding_mode_undef); |
| } |
| } |
| |
| static void |
| lower_cmat_convert(nir_builder *b, nir_intrinsic_instr *intrin, |
| struct lower_cmat_state *state) |
| { |
| nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]); |
| nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[1]); |
| |
| const slice_info *dst_info = get_slice_info(state, dst_slice); |
| const slice_info *src_info = get_slice_info(state, src_slice); |
| |
| const nir_cmat_signed cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intrin); |
| |
| enum glsl_base_type src_element_type = glsl_apply_signedness_to_base_type( |
| src_info->desc.element_type, cmat_signed_mask & NIR_CMAT_A_SIGNED); |
| enum glsl_base_type dst_element_type = glsl_apply_signedness_to_base_type( |
| dst_info->desc.element_type, cmat_signed_mask & NIR_CMAT_RESULT_SIGNED); |
| |
| bool needs_intermediate = |
| (src_element_type == GLSL_TYPE_BFLOAT16 && dst_element_type != GLSL_TYPE_FLOAT) || |
| (src_element_type != GLSL_TYPE_FLOAT && dst_element_type == GLSL_TYPE_BFLOAT16); |
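
   /* For example (an assumed combination, for illustration): bfloat16 ->
    * float16 has no direct conversion op, so it is lowered as nir_op_bf2f
    * into a float32 intermediate followed by nir_op_f2f16.
    */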
| |
| nir_def *result; |
| nir_def *src = nir_load_deref(b, src_slice); |
| |
| if (needs_intermediate) { |
| /* Cooperative matrices must have the same "shape" to be converted. */ |
| assert(src_info->desc.rows == dst_info->desc.rows); |
| assert(src_info->desc.cols == dst_info->desc.cols); |
| assert(src_info->desc.use == dst_info->desc.use); |
| assert(src_info->desc.scope == dst_info->desc.scope); |
| |
| struct glsl_cmat_description float_desc = src_info->desc; |
| float_desc.element_type = GLSL_TYPE_FLOAT; |
| |
| slice_info float_info = {}; |
| init_slice_info(state, float_desc, &float_info); |
| |
| nir_op op1 = get_cmat_conversion_op(src_element_type, GLSL_TYPE_FLOAT); |
| nir_op op2 = get_cmat_conversion_op(GLSL_TYPE_FLOAT, dst_element_type); |
| |
| nir_def *tmp = emit_packed_alu1(b, state, src_info, &float_info, op1, src); |
| result = emit_packed_alu1(b, state, &float_info, dst_info, op2, tmp); |
| |
| } else { |
| const unsigned dst_components = glsl_get_vector_elements(dst_info->type); |
      const unsigned dst_bits = glsl_base_type_get_bit_size(dst_info->desc.element_type);
| |
| result = |
| nir_convert_cmat_intel(b, |
| dst_components, |
| dst_info->packing_factor * dst_bits, |
| src, |
| .dst_cmat_desc = dst_info->desc, |
| .src_cmat_desc = src_info->desc); |
| } |
| |
| nir_store_deref(b, dst_slice, result, nir_component_mask(result->num_components)); |
| } |
| |
| static void |
| lower_cmat_unary_op(nir_builder *b, nir_intrinsic_instr *intrin, |
| struct lower_cmat_state *state) |
| { |
| nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]); |
| nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[1]); |
| |
| const slice_info *dst_info = get_slice_info(state, dst_slice); |
| const slice_info *src_info = get_slice_info(state, src_slice); |
| assert(cmat_descriptions_are_equal(src_info->desc, dst_info->desc)); |
| |
| nir_def *result = emit_packed_alu1(b, state, src_info, dst_info, |
| nir_intrinsic_alu_op(intrin), |
| nir_load_deref(b, src_slice)); |
| |
| nir_store_deref(b, dst_slice, result, nir_component_mask(result->num_components)); |
| } |
| |
| static void |
| lower_cmat_binary_op(nir_builder *b, nir_intrinsic_instr *intrin, |
| struct lower_cmat_state *state) |
| { |
| nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]); |
| nir_deref_instr *src_a_slice = nir_src_as_deref(intrin->src[1]); |
| nir_deref_instr *src_b_slice = nir_src_as_deref(intrin->src[2]); |
| |
| nir_def *src_a = nir_load_deref(b, src_a_slice); |
| nir_def *src_b = nir_load_deref(b, src_b_slice); |
| nir_def *results[NIR_MAX_VEC_COMPONENTS]; |
| const unsigned num_components = glsl_get_vector_elements(dst_slice->type); |
| |
| const slice_info *info = get_slice_info(state, dst_slice); |
| ASSERTED const slice_info *src_a_info = get_slice_info(state, src_a_slice); |
| ASSERTED const slice_info *src_b_info = get_slice_info(state, src_b_slice); |
| |
| assert(cmat_descriptions_are_equal(info->desc, src_a_info->desc)); |
| assert(cmat_descriptions_are_equal(info->desc, src_b_info->desc)); |
| |
   const unsigned bits = glsl_base_type_get_bit_size(info->desc.element_type);
| |
| for (unsigned i = 0; i < num_components; i++) { |
| nir_def *val_a = nir_channel(b, src_a, i); |
| nir_def *val_b = nir_channel(b, src_b, i); |
| |
| results[i] = |
| nir_pack_bits(b, nir_build_alu2(b, nir_intrinsic_alu_op(intrin), |
| nir_unpack_bits(b, val_a, bits), |
| nir_unpack_bits(b, val_b, bits)), |
| info->packing_factor * bits); |
| } |
| |
| nir_store_deref(b, dst_slice, nir_vec(b, results, num_components), |
| nir_component_mask(num_components)); |
| } |
| |
| static void |
| lower_cmat_scalar_op(nir_builder *b, nir_intrinsic_instr *intrin, |
| struct lower_cmat_state *state) |
| { |
| nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]); |
| nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[1]); |
| nir_def *scalar = intrin->src[2].ssa; |
| |
| nir_def *src = nir_load_deref(b, src_slice); |
| nir_def *results[NIR_MAX_VEC_COMPONENTS]; |
| const unsigned num_components = glsl_get_vector_elements(dst_slice->type); |
| |
| const slice_info *info = get_slice_info(state, dst_slice); |
| ASSERTED const slice_info *src_info = get_slice_info(state, src_slice); |
| assert(cmat_descriptions_are_equal(info->desc, src_info->desc)); |
| |
   const unsigned bits = glsl_base_type_get_bit_size(info->desc.element_type);
| |
| for (unsigned i = 0; i < num_components; i++) { |
| nir_def *val = nir_channel(b, src, i); |
| |
| results[i] = |
| nir_pack_bits(b, nir_build_alu2(b, nir_intrinsic_alu_op(intrin), |
| nir_unpack_bits(b, val, bits), |
| scalar), |
| info->packing_factor * bits); |
| } |
| |
| nir_store_deref(b, dst_slice, nir_vec(b, results, num_components), |
| nir_component_mask(num_components)); |
| } |
| |
| static nir_deref_instr * |
| lower_cmat_deref(nir_builder *b, nir_deref_instr *deref, |
| struct lower_cmat_state *state) |
| { |
| nir_deref_instr *parent = nir_deref_instr_parent(deref); |
| if (parent) { |
| assert(deref->deref_type == nir_deref_type_array); |
| parent = lower_cmat_deref(b, parent, state); |
| return nir_build_deref_array(b, parent, deref->arr.index.ssa); |
| } else { |
| assert(deref->deref_type == nir_deref_type_var); |
| assert(deref->var); |
| assert(glsl_type_is_cmat(glsl_without_array(deref->var->type))); |
| |
| struct hash_entry *entry = _mesa_hash_table_search(state->mat_var_to_slice_info, deref->var); |
| assert(entry); |
| const slice_info *info = entry->data; |
| return nir_build_deref_var(b, info->var); |
| } |
| } |
| |
| static nir_def * |
| lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state) |
| { |
| struct lower_cmat_state *state = _state; |
| |
| if (instr->type == nir_instr_type_deref) { |
| nir_deref_instr *deref = lower_cmat_deref(b, nir_instr_as_deref(instr), state); |
| return &deref->def; |
| } |
| |
| nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); |
| switch (intrin->intrinsic) { |
| case nir_intrinsic_cmat_load: |
| case nir_intrinsic_cmat_store: |
| lower_cmat_load_store(b, intrin, state); |
| return NIR_LOWER_INSTR_PROGRESS_REPLACE; |
| |
| case nir_intrinsic_cmat_construct: { |
| nir_deref_instr *slice = nir_src_as_deref(intrin->src[0]); |
| nir_def *src = intrin->src[1].ssa; |
| |
| const slice_info *info = get_slice_info(state, slice); |
| |
| if (info->packing_factor > 1) { |
| src = nir_pack_bits(b, nir_replicate(b, src, info->packing_factor), |
| info->packing_factor * glsl_base_type_get_bit_size(info->desc.element_type)); |
| } |
| |
| const unsigned num_components = glsl_get_vector_elements(slice->type); |
| |
| nir_store_deref(b, slice, nir_replicate(b, src, num_components), |
| nir_component_mask(num_components)); |
| return NIR_LOWER_INSTR_PROGRESS_REPLACE; |
| } |
| |
| case nir_intrinsic_cmat_convert: |
| lower_cmat_convert(b, intrin, state); |
| return NIR_LOWER_INSTR_PROGRESS_REPLACE; |
| |
| case nir_intrinsic_cmat_unary_op: |
| lower_cmat_unary_op(b, intrin, state); |
| return NIR_LOWER_INSTR_PROGRESS_REPLACE; |
| |
| case nir_intrinsic_cmat_binary_op: |
| lower_cmat_binary_op(b, intrin, state); |
| return NIR_LOWER_INSTR_PROGRESS_REPLACE; |
| |
| case nir_intrinsic_cmat_scalar_op: |
| lower_cmat_scalar_op(b, intrin, state); |
| return NIR_LOWER_INSTR_PROGRESS_REPLACE; |
| |
| case nir_intrinsic_cmat_length: { |
| slice_info info = {}; |
| init_slice_info(state, nir_intrinsic_cmat_desc(intrin), &info); |
| return nir_imm_intN_t(b, info.packing_factor * |
| glsl_get_vector_elements(info.type), 32); |
| } |
| |
| case nir_intrinsic_cmat_muladd: { |
| nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]); |
| nir_deref_instr *A_slice = nir_src_as_deref(intrin->src[1]); |
| nir_deref_instr *B_slice = nir_src_as_deref(intrin->src[2]); |
| nir_deref_instr *accum_slice = nir_src_as_deref(intrin->src[3]); |
| |
| const slice_info *dst_info = get_slice_info(state, dst_slice); |
| const slice_info *src_info = get_slice_info(state, A_slice); |
| |
| const unsigned num_components = glsl_get_vector_elements(dst_slice->type); |
| |
| const nir_cmat_signed cmat_signed_mask = |
| nir_intrinsic_cmat_signed_mask(intrin); |
| |
| assert(((cmat_signed_mask & NIR_CMAT_A_SIGNED) == 0) == |
| ((cmat_signed_mask & NIR_CMAT_B_SIGNED) == 0)); |
| assert(((cmat_signed_mask & NIR_CMAT_A_SIGNED) == 0) == |
| ((cmat_signed_mask & NIR_CMAT_C_SIGNED) == 0)); |
| assert(((cmat_signed_mask & NIR_CMAT_A_SIGNED) == 0) == |
| ((cmat_signed_mask & NIR_CMAT_RESULT_SIGNED) == 0)); |
| |
| enum glsl_base_type src_type = src_info->desc.element_type; |
| enum glsl_base_type dst_type = dst_info->desc.element_type; |
| |
| /* For integer types, the signedness is determined by flags on the |
| * muladd instruction. The types of the sources play no role. Adjust the |
| * types passed to the dpas_intel intrinsic to match. |
| */ |
| if (glsl_base_type_is_integer(src_type)) { |
| if ((cmat_signed_mask & NIR_CMAT_A_SIGNED) == 0) { |
| src_type = glsl_unsigned_base_type_of(src_type); |
| dst_type = glsl_unsigned_base_type_of(dst_type); |
| } else { |
| src_type = glsl_signed_base_type_of(src_type); |
| dst_type = glsl_signed_base_type_of(dst_type); |
| } |
| } |
| |
| nir_def *result = |
| nir_dpas_intel(b, |
| dst_info->packing_factor * glsl_base_type_get_bit_size(dst_info->desc.element_type), |
| nir_load_deref(b, accum_slice), |
| nir_load_deref(b, A_slice), |
| nir_load_deref(b, B_slice), |
| .dest_base_type = dst_type, |
| .src_base_type = src_type, |
| .saturate = nir_intrinsic_saturate(intrin), |
| .systolic_depth = 8, |
| .repeat_count = 8); |
| |
| nir_store_deref(b, dst_slice, result, |
| nir_component_mask(num_components)); |
| |
| return NIR_LOWER_INSTR_PROGRESS_REPLACE; |
| } |
| |
| case nir_intrinsic_cmat_bitcast: { |
| nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]); |
| nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[1]); |
| |
| const unsigned num_components = glsl_get_vector_elements(dst_slice->type); |
| |
| assert(glsl_get_vector_elements(src_slice->type) == num_components); |
| |
| nir_store_deref(b, dst_slice, nir_load_deref(b, src_slice), |
| nir_component_mask(num_components)); |
| return NIR_LOWER_INSTR_PROGRESS_REPLACE; |
| } |
| |
| case nir_intrinsic_cmat_copy: |
| nir_copy_deref(b, |
| nir_src_as_deref(intrin->src[0]), |
| nir_src_as_deref(intrin->src[1])); |
| return NIR_LOWER_INSTR_PROGRESS_REPLACE; |
| |
| case nir_intrinsic_cmat_insert: { |
| nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]); |
| nir_def *scalar = intrin->src[1].ssa; |
| nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[2]); |
| const nir_src dst_index = intrin->src[3]; |
| |
| const slice_info *info = get_slice_info(state, dst_slice); |
| ASSERTED const slice_info *src_info = get_slice_info(state, src_slice); |
| assert(cmat_descriptions_are_equal(info->desc, src_info->desc)); |
| |
      const unsigned bits = glsl_base_type_get_bit_size(info->desc.element_type);
| const unsigned num_components = glsl_get_vector_elements(dst_slice->type); |
| |
| nir_def *slice_index = nir_udiv_imm(b, dst_index.ssa, info->packing_factor); |
| nir_def *vector_index = nir_umod_imm(b, dst_index.ssa, info->packing_factor); |
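
      /* E.g., with packing_factor == 2, element index 5 lives in slice
       * channel 2 (slice_index) at packed position 1 (vector_index).
       */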
| nir_def *results[NIR_MAX_VEC_COMPONENTS]; |
| |
| const int slice_constant_index = nir_src_is_const(dst_index) |
| ? nir_src_as_uint(dst_index) / info->packing_factor |
| : -1; |
| |
| for (unsigned i = 0; i < num_components; i++) { |
| nir_def *val = nir_channel(b, nir_load_deref(b, src_slice), i); |
| nir_def *insert; |
| |
| if (slice_constant_index < 0 || slice_constant_index == i) { |
| if (info->packing_factor == 1) { |
| insert = scalar; |
| } else { |
| nir_def *unpacked = nir_unpack_bits(b, val, bits); |
| nir_def *v = nir_vector_insert(b, unpacked, scalar, vector_index); |
| |
| insert = nir_pack_bits(b, v, bits * info->packing_factor); |
| } |
| } else { |
| insert = val; |
| } |
| |
| results[i] = slice_constant_index < 0 |
| ? nir_bcsel(b, nir_ieq_imm(b, slice_index, i), insert, val) |
| : insert; |
| } |
| |
| nir_store_deref(b, dst_slice, nir_vec(b, results, num_components), |
| nir_component_mask(num_components)); |
| |
| return NIR_LOWER_INSTR_PROGRESS_REPLACE; |
| } |
| |
| case nir_intrinsic_cmat_extract: { |
| nir_deref_instr *slice = nir_src_as_deref(intrin->src[0]); |
| const slice_info *info = get_slice_info(state, slice); |
| nir_def *index = intrin->src[1].ssa; |
| |
      const unsigned bits = glsl_base_type_get_bit_size(info->desc.element_type);
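
      /* E.g., with packing_factor == 4, element index 6 is packed position 2
       * within slice channel 1.
       */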
| |
| nir_def *src = |
| nir_vector_extract(b, nir_load_deref(b, slice), |
| nir_udiv_imm(b, index, info->packing_factor)); |
| |
| if (info->packing_factor == 1) { |
| return src; |
| } else { |
| return nir_vector_extract(b, |
| nir_unpack_bits(b, src, bits), |
| nir_umod_imm(b, index, info->packing_factor)); |
| } |
| |
| return NIR_LOWER_INSTR_PROGRESS_REPLACE; |
| } |
| |
| default: |
| unreachable("invalid cooperative matrix intrinsic"); |
| } |
| } |
| |
| static const glsl_type * |
| make_aoa_slice_type(const glsl_type *t, const glsl_type *slice_type) |
| { |
| if (glsl_type_is_array(t)) { |
| const glsl_type *s = make_aoa_slice_type(glsl_get_array_element(t), slice_type); |
| return glsl_array_type(s, glsl_array_size(t), 0); |
| } |
| |
| assert(glsl_type_is_cmat(t)); |
| return slice_type; |
| } |
| |
| static void |
| create_slice_var(struct lower_cmat_state *state, nir_variable *var, |
| nir_function_impl *impl) |
| { |
| const struct glsl_type *mat_type = glsl_without_array(var->type); |
| |
| assert(glsl_type_is_cmat(mat_type)); |
| assert((!impl && var->data.mode == nir_var_shader_temp) || |
| ( impl && var->data.mode == nir_var_function_temp)); |
| |
| slice_info *info = rzalloc(state->temp_ctx, slice_info); |
| init_slice_info(state, *glsl_get_cmat_description(mat_type), info); |
| |
| const glsl_type *aoa_slice_type = make_aoa_slice_type(var->type, info->type); |
| |
| const char *slice_name = ralloc_asprintf(state->shader, "%s_slice", var->name); |
| info->var = impl ? |
| nir_local_variable_create(impl, aoa_slice_type, slice_name) : |
| nir_variable_create(state->shader, var->data.mode, aoa_slice_type, slice_name); |
| |
| _mesa_hash_table_insert(state->mat_var_to_slice_info, var, info); |
| _mesa_hash_table_insert(state->slice_var_to_slice_info, info->var, info); |
| } |
| |
| bool |
| brw_nir_lower_cmat(nir_shader *shader, unsigned subgroup_size) |
| { |
| void *temp_ctx = ralloc_context(NULL); |
| |
| struct lower_cmat_state state = { |
| .temp_ctx = temp_ctx, |
| .shader = shader, |
| .slice_var_to_slice_info = _mesa_pointer_hash_table_create(temp_ctx), |
| .mat_var_to_slice_info = _mesa_pointer_hash_table_create(temp_ctx), |
| .subgroup_size = subgroup_size, |
| }; |
| |
| /* Create a slice array for each variable and add a map from the original |
| * variable back to it, so it can be reached during lowering. |
| * |
| * TODO: Cooperative matrix inside struct? |
| */ |
| nir_foreach_variable_in_shader(var, shader) { |
| if (glsl_type_is_cmat(glsl_without_array(var->type))) |
| create_slice_var(&state, var, NULL); |
| } |
| nir_foreach_function(func, shader) { |
| nir_foreach_function_temp_variable(var, func->impl) { |
| if (glsl_type_is_cmat(glsl_without_array(var->type))) |
| create_slice_var(&state, var, func->impl); |
| } |
| } |
| |
| bool progress = nir_shader_lower_instructions(shader, |
| lower_cmat_filter, |
| lower_cmat_instr, |
| &state); |
| |
| ralloc_free(temp_ctx); |
| |
| return progress; |
| } |