| /* Copyright 2019 Google LLC. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| #include <algorithm> |
| #include <cstdint> |
| |
| #include "ruy/check_macros.h" |
| #include "ruy/kernel_x86.h" |
| #include "ruy/opt_set.h" |
| #include "ruy/platform.h" |
| #include "ruy/profiler/instrumentation.h" |
| |
| #if RUY_PLATFORM_AVX512 && RUY_OPT(ASM) |
| #include <immintrin.h> // IWYU pragma: keep |
| #endif |
| |
| namespace ruy { |
| |
| #if !(RUY_PLATFORM_AVX512 && RUY_OPT(ASM)) |
| |
| void Kernel8bitAvx512(const KernelParams8bit<16, 16>&) { |
| // CPU-ID-based checks should disable the path that would reach this point. |
| RUY_DCHECK(false); |
| } |
| |
| void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>&) { |
| // CPU-ID-based checks should disable the path that would reach this point. |
| RUY_DCHECK(false); |
| } |
| |
| void KernelFloatAvx512(const KernelParamsFloat<16, 16>&) { |
| // CPU-ID-based checks should disable the path that would reach this point. |
| RUY_DCHECK(false); |
| } |
| |
| void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>&) { |
| // CPU-ID-based checks should disable the path that would reach this point. |
| RUY_DCHECK(false); |
| } |
| |
| #else // RUY_PLATFORM_AVX512 && RUY_OPT(ASM) |
| |
| void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) { |
| profiler::ScopeLabel label("Kernel kAvx512 8-bit"); |
| |
| std::int32_t dst_stride = 0; |
| if ((params.dst_type_id == DstTypeId<std::int8_t>::kValue) || |
| (params.dst_type_id == DstTypeId<std::uint8_t>::kValue)) { |
| dst_stride = params.dst_stride; |
| } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) { |
| dst_stride = params.dst_stride / sizeof(std::int16_t); |
| } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) { |
| dst_stride = params.dst_stride / sizeof(std::int32_t); |
| } else { |
| RUY_DCHECK(false); |
| } |
| |
| int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0; |
| |
| const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; |
| void* dst_col_ptr = params.dst_base_ptr; |
| const std::int32_t* bias_col_ptr = params.bias; |
| if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { |
| bias_col_ptr += params.start_row; |
| } |
| |
| for (int col = params.start_col; col <= params.last_col; col += 16) { |
| const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; |
| void* dst_ptr = dst_col_ptr; |
| const std::int32_t* bias_ptr = bias_col_ptr; |
| |
| const std::int32_t lhs_zero_point = params.lhs_zero_point; |
| const bool has_rhs_sums_offsets = |
| (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point; |
| std::int32_t rhs_sums_offsets[16]; |
| if (has_rhs_sums_offsets) { |
| const __m512i rhs_sums_offset_v = |
| _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point), |
| _mm512_loadu_si512(¶ms.rhs_sums[col])); |
| _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets), |
| rhs_sums_offset_v); |
| } |
| |
| for (int row = params.start_row; row <= params.last_row; row += 16) { |
| const int residual_rows = std::min(params.dst_rows - row, 16); |
| const int residual_cols = std::min(params.dst_cols - col, 16); |
| |
| __m512i accum_data_v0; |
| __m512i accum_data_v1; |
| __m512i accum_data_v2; |
| __m512i accum_data_v3; |
| __m512i accum_data_v4; |
| __m512i accum_data_v5; |
| __m512i accum_data_v6; |
| __m512i accum_data_v7; |
| __m512i accum_data_v8; |
| __m512i accum_data_v9; |
| __m512i accum_data_va; |
| __m512i accum_data_vb; |
| __m512i accum_data_vc; |
| __m512i accum_data_vd; |
| __m512i accum_data_ve; |
| __m512i accum_data_vf; |
| |
| // Initialize with bias. |
| const __mmask16 row_mask = |
| (static_cast<std::uint32_t>(1) << residual_rows) - 1; |
| __m512i initial_accum_data = _mm512_maskz_loadu_epi32(row_mask, bias_ptr); |
| bias_ptr += bias_ptr_block_increment; |
| |
| const std::int32_t rhs_zero_point = params.rhs_zero_point; |
| if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { |
| const __m512i lhs_sums_offset = |
| _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point), |
| _mm512_loadu_si512(¶ms.lhs_sums[row])); |
| initial_accum_data = |
| _mm512_sub_epi32(initial_accum_data, lhs_sums_offset); |
| } |
| |
| const std::int32_t prod_zp_depth = params.prod_zp_depth; |
| if (prod_zp_depth != 0) { |
| initial_accum_data = _mm512_add_epi32(initial_accum_data, |
| _mm512_set1_epi32(prod_zp_depth)); |
| } |
| |
| // Adjustments differing across columns. |
| if (has_rhs_sums_offsets) { |
| accum_data_v0 = _mm512_sub_epi32( |
| initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[0])); |
| accum_data_v1 = _mm512_sub_epi32( |
| initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[1])); |
| accum_data_v2 = _mm512_sub_epi32( |
| initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[2])); |
| accum_data_v3 = _mm512_sub_epi32( |
| initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[3])); |
| accum_data_v4 = _mm512_sub_epi32( |
| initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[4])); |
| accum_data_v5 = _mm512_sub_epi32( |
| initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[5])); |
| accum_data_v6 = _mm512_sub_epi32( |
| initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[6])); |
| accum_data_v7 = _mm512_sub_epi32( |
| initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[7])); |
| accum_data_v8 = _mm512_sub_epi32( |
| initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[8])); |
| accum_data_v9 = _mm512_sub_epi32( |
| initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[9])); |
| accum_data_va = _mm512_sub_epi32( |
| initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[10])); |
| accum_data_vb = _mm512_sub_epi32( |
| initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[11])); |
| accum_data_vc = _mm512_sub_epi32( |
| initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[12])); |
| accum_data_vd = _mm512_sub_epi32( |
| initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[13])); |
| accum_data_ve = _mm512_sub_epi32( |
| initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[14])); |
| accum_data_vf = _mm512_sub_epi32( |
| initial_accum_data, _mm512_set1_epi32(rhs_sums_offsets[15])); |
| } else { |
| accum_data_v0 = initial_accum_data; |
| accum_data_v1 = initial_accum_data; |
| accum_data_v2 = initial_accum_data; |
| accum_data_v3 = initial_accum_data; |
| accum_data_v4 = initial_accum_data; |
| accum_data_v5 = initial_accum_data; |
| accum_data_v6 = initial_accum_data; |
| accum_data_v7 = initial_accum_data; |
| accum_data_v8 = initial_accum_data; |
| accum_data_v9 = initial_accum_data; |
| accum_data_va = initial_accum_data; |
| accum_data_vb = initial_accum_data; |
| accum_data_vc = initial_accum_data; |
| accum_data_vd = initial_accum_data; |
| accum_data_ve = initial_accum_data; |
| accum_data_vf = initial_accum_data; |
| } |
| |
| const std::int8_t* lhs_ptr = lhs_col_ptr; |
| const std::int8_t* rhs_ptr = rhs_col_ptr; |
| for (int d = 0; d < params.depth; d += 4) { |
| const __m512i lhs_data = _mm512_loadu_si512(lhs_ptr); |
| __m512i rhs_data_8bit = _mm512_loadu_si512(rhs_ptr); |
| |
| // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. |
| std::int32_t rhs_data[32]; |
| const __m256i rhs_data_bottom_lane = |
| _mm512_castsi512_si256(rhs_data_8bit); |
| const __m256i rhs_data_top_lane = |
| _mm512_extracti32x8_epi32(rhs_data_8bit, 1); |
| const __m512i rhs_16_bit_dup_low = |
| _mm512_cvtepi8_epi16(rhs_data_bottom_lane); |
| const __m512i rhs_16_bit_dup_high = |
| _mm512_cvtepi8_epi16(rhs_data_top_lane); |
| // Now that we have cast the RHS data, we store it so that each value |
| // can be separately loaded in the accumulation loop. |
| _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data), |
| rhs_16_bit_dup_low); |
| _mm512_storeu_si512(reinterpret_cast<__m256i*>(rhs_data + 16), |
| rhs_16_bit_dup_high); |
| |
| // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit. |
| const __m512i lhs_16_bit_low = |
| _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data)); |
| // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit. |
| const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16( |
| _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16))); |
| |
| // Process column 0. |
| { |
| __m512i accum_v = accum_data_v0; |
| constexpr int index = 0; |
| |
| const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); |
| const __m512i rhs_16_bit_dup_high = |
| _mm512_set1_epi32(rhs_data[index + 1]); |
| |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); |
| accum_data_v0 = accum_v; |
| } |
| // Process column 1. |
| { |
| __m512i accum_v = accum_data_v1; |
| constexpr int index = 2; |
| |
| const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); |
| const __m512i rhs_16_bit_dup_high = |
| _mm512_set1_epi32(rhs_data[index + 1]); |
| |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); |
| accum_data_v1 = accum_v; |
| } |
| // Process column 2. |
| { |
| __m512i accum_v = accum_data_v2; |
| constexpr int index = 4; |
| |
| const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); |
| const __m512i rhs_16_bit_dup_high = |
| _mm512_set1_epi32(rhs_data[index + 1]); |
| |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); |
| accum_data_v2 = accum_v; |
| } |
| // Process column 3. |
| { |
| __m512i accum_v = accum_data_v3; |
| constexpr int index = 6; |
| |
| const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); |
| const __m512i rhs_16_bit_dup_high = |
| _mm512_set1_epi32(rhs_data[index + 1]); |
| |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); |
| accum_data_v3 = accum_v; |
| } |
| // Process column 4. |
| { |
| __m512i accum_v = accum_data_v4; |
| constexpr int index = 8; |
| |
| const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); |
| const __m512i rhs_16_bit_dup_high = |
| _mm512_set1_epi32(rhs_data[index + 1]); |
| |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); |
| accum_data_v4 = accum_v; |
| } |
| // Process column 5. |
| { |
| __m512i accum_v = accum_data_v5; |
| constexpr int index = 10; |
| |
| const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); |
| const __m512i rhs_16_bit_dup_high = |
| _mm512_set1_epi32(rhs_data[index + 1]); |
| |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); |
| accum_data_v5 = accum_v; |
| } |
| // Process column 6. |
| { |
| __m512i accum_v = accum_data_v6; |
| constexpr int index = 12; |
| |
| const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); |
| const __m512i rhs_16_bit_dup_high = |
| _mm512_set1_epi32(rhs_data[index + 1]); |
| |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); |
| accum_data_v6 = accum_v; |
| } |
| // Process column 7. |
| { |
| __m512i accum_v = accum_data_v7; |
| constexpr int index = 14; |
| |
| const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); |
| const __m512i rhs_16_bit_dup_high = |
| _mm512_set1_epi32(rhs_data[index + 1]); |
| |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); |
| accum_data_v7 = accum_v; |
| } |
| // Process column 8. |
| { |
| __m512i accum_v = accum_data_v8; |
| constexpr int index = 16; |
| |
| const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); |
| const __m512i rhs_16_bit_dup_high = |
| _mm512_set1_epi32(rhs_data[index + 1]); |
| |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); |
| accum_data_v8 = accum_v; |
| } |
| // Process column 9. |
| { |
| __m512i accum_v = accum_data_v9; |
| constexpr int index = 18; |
| |
| const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); |
| const __m512i rhs_16_bit_dup_high = |
| _mm512_set1_epi32(rhs_data[index + 1]); |
| |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); |
| accum_data_v9 = accum_v; |
| } |
| // Process column 10. |
| { |
| __m512i accum_v = accum_data_va; |
| constexpr int index = 20; |
| |
| const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); |
| const __m512i rhs_16_bit_dup_high = |
| _mm512_set1_epi32(rhs_data[index + 1]); |
| |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); |
| accum_data_va = accum_v; |
| } |
| // Process column 11. |
| { |
| __m512i accum_v = accum_data_vb; |
| constexpr int index = 22; |
| |
| const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); |
| const __m512i rhs_16_bit_dup_high = |
| _mm512_set1_epi32(rhs_data[index + 1]); |
| |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); |
| accum_data_vb = accum_v; |
| } |
| // Process column 12. |
| { |
| __m512i accum_v = accum_data_vc; |
| constexpr int index = 24; |
| |
| const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); |
| const __m512i rhs_16_bit_dup_high = |
| _mm512_set1_epi32(rhs_data[index + 1]); |
| |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); |
| accum_data_vc = accum_v; |
| } |
| // Process column 13. |
| { |
| __m512i accum_v = accum_data_vd; |
| constexpr int index = 26; |
| |
| const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); |
| const __m512i rhs_16_bit_dup_high = |
| _mm512_set1_epi32(rhs_data[index + 1]); |
| |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); |
| accum_data_vd = accum_v; |
| } |
| // Process column 14. |
| { |
| __m512i accum_v = accum_data_ve; |
| constexpr int index = 28; |
| |
| const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); |
| const __m512i rhs_16_bit_dup_high = |
| _mm512_set1_epi32(rhs_data[index + 1]); |
| |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); |
| accum_data_ve = accum_v; |
| } |
| // Process column 15. |
| { |
| __m512i accum_v = accum_data_vf; |
| constexpr int index = 30; |
| |
| const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); |
| const __m512i rhs_16_bit_dup_high = |
| _mm512_set1_epi32(rhs_data[index + 1]); |
| |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); |
| accum_data_vf = accum_v; |
| } |
| |
| lhs_ptr += 16 * 4; |
| rhs_ptr += 16 * 4; |
| } |
| |
| if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) { |
| __m512i m_vector; |
| __m512i e_vector; |
| // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. |
| if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) { |
| m_vector = _mm512_maskz_loadu_epi32( |
| row_mask, ¶ms.multiplier_fixedpoint[row]); |
| e_vector = _mm512_maskz_loadu_epi32(row_mask, |
| ¶ms.multiplier_exponent[row]); |
| } else { |
| // These arrays have size LhsCols, and are pre-filled. |
| m_vector = _mm512_set1_epi32(params.multiplier_fixedpoint[0]); |
| e_vector = _mm512_set1_epi32(params.multiplier_exponent[0]); |
| } |
| |
| const __m512i m_64bit_low = |
| _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0)); |
| const __m512i m_64bit_high = |
| _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1)); |
| |
| const __m512i zero_vector = _mm512_setzero_epi32(); |
| const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector); |
| const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector); |
| const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector); |
| const __m512i final_right_shift = |
| _mm512_add_epi32(right_shift, _mm512_set1_epi32(31)); |
| const __m512i final_right_shift_low = _mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(final_right_shift, 0)); |
| const __m512i final_right_shift_high = _mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(final_right_shift, 1)); |
| |
| const __m512i offset_vector = |
| _mm512_slli_epi64(_mm512_set1_epi64(1), 30); |
| // Really these should be shifted by neg_e_vector, but tests pass when |
| // using right_shift. |
| const __m512i offset_vector_low = _mm512_sllv_epi64( |
| offset_vector, |
| _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0))); |
| const __m512i offset_vector_high = _mm512_sllv_epi64( |
| offset_vector, |
| _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1))); |
| |
| // Shift and round column 0. |
| { |
| accum_data_v0 = _mm512_sllv_epi32(accum_data_v0, left_shift); |
| // Apply the fixed-point part of the multiplier. |
| __m512i scaled_v_low = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_v0, 0)), |
| m_64bit_low); |
| __m512i scaled_v_high = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_v0, 1)), |
| m_64bit_high); |
| |
| scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); |
| scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); |
| |
| scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); |
| scaled_v_high = |
| _mm512_srav_epi64(scaled_v_high, final_right_shift_high); |
| |
| accum_data_v0 = |
| _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); |
| accum_data_v0 = _mm512_inserti32x8( |
| accum_data_v0, _mm512_cvtepi64_epi32(scaled_v_high), 1); |
| } |
| // Shift and round column 1. |
| { |
| accum_data_v1 = _mm512_sllv_epi32(accum_data_v1, left_shift); |
| // Apply the fixed-point part of the multiplier. |
| __m512i scaled_v_low = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_v1, 0)), |
| m_64bit_low); |
| __m512i scaled_v_high = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_v1, 1)), |
| m_64bit_high); |
| |
| scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); |
| scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); |
| |
| scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); |
| scaled_v_high = |
| _mm512_srav_epi64(scaled_v_high, final_right_shift_high); |
| |
| accum_data_v1 = |
| _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); |
| accum_data_v1 = _mm512_inserti32x8( |
| accum_data_v1, _mm512_cvtepi64_epi32(scaled_v_high), 1); |
| } |
| // Shift and round column 2. |
| { |
| accum_data_v2 = _mm512_sllv_epi32(accum_data_v2, left_shift); |
| // Apply the fixed-point part of the multiplier. |
| __m512i scaled_v_low = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_v2, 0)), |
| m_64bit_low); |
| __m512i scaled_v_high = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_v2, 1)), |
| m_64bit_high); |
| |
| scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); |
| scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); |
| |
| scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); |
| scaled_v_high = |
| _mm512_srav_epi64(scaled_v_high, final_right_shift_high); |
| |
| accum_data_v2 = |
| _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); |
| accum_data_v2 = _mm512_inserti32x8( |
| accum_data_v2, _mm512_cvtepi64_epi32(scaled_v_high), 1); |
| } |
| // Shift and round column 3. |
| { |
| accum_data_v3 = _mm512_sllv_epi32(accum_data_v3, left_shift); |
| // Apply the fixed-point part of the multiplier. |
| __m512i scaled_v_low = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_v3, 0)), |
| m_64bit_low); |
| __m512i scaled_v_high = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_v3, 1)), |
| m_64bit_high); |
| |
| scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); |
| scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); |
| |
| scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); |
| scaled_v_high = |
| _mm512_srav_epi64(scaled_v_high, final_right_shift_high); |
| |
| accum_data_v3 = |
| _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); |
| accum_data_v3 = _mm512_inserti32x8( |
| accum_data_v3, _mm512_cvtepi64_epi32(scaled_v_high), 1); |
| } |
| // Shift and round column 4. |
| { |
| accum_data_v4 = _mm512_sllv_epi32(accum_data_v4, left_shift); |
| // Apply the fixed-point part of the multiplier. |
| __m512i scaled_v_low = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_v4, 0)), |
| m_64bit_low); |
| __m512i scaled_v_high = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_v4, 1)), |
| m_64bit_high); |
| |
| scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); |
| scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); |
| |
| scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); |
| scaled_v_high = |
| _mm512_srav_epi64(scaled_v_high, final_right_shift_high); |
| |
| accum_data_v4 = |
| _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); |
| accum_data_v4 = _mm512_inserti32x8( |
| accum_data_v4, _mm512_cvtepi64_epi32(scaled_v_high), 1); |
| } |
| // Shift and round column 5. |
| { |
| accum_data_v5 = _mm512_sllv_epi32(accum_data_v5, left_shift); |
| // Apply the fixed-point part of the multiplier. |
| __m512i scaled_v_low = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_v5, 0)), |
| m_64bit_low); |
| __m512i scaled_v_high = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_v5, 1)), |
| m_64bit_high); |
| |
| scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); |
| scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); |
| |
| scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); |
| scaled_v_high = |
| _mm512_srav_epi64(scaled_v_high, final_right_shift_high); |
| |
| accum_data_v5 = |
| _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); |
| accum_data_v5 = _mm512_inserti32x8( |
| accum_data_v5, _mm512_cvtepi64_epi32(scaled_v_high), 1); |
| } |
| // Shift and round column 6. |
| { |
| accum_data_v6 = _mm512_sllv_epi32(accum_data_v6, left_shift); |
| // Apply the fixed-point part of the multiplier. |
| __m512i scaled_v_low = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_v6, 0)), |
| m_64bit_low); |
| __m512i scaled_v_high = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_v6, 1)), |
| m_64bit_high); |
| |
| scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); |
| scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); |
| |
| scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); |
| scaled_v_high = |
| _mm512_srav_epi64(scaled_v_high, final_right_shift_high); |
| |
| accum_data_v6 = |
| _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); |
| accum_data_v6 = _mm512_inserti32x8( |
| accum_data_v6, _mm512_cvtepi64_epi32(scaled_v_high), 1); |
| } |
| // Shift and round column 7. |
| { |
| accum_data_v7 = _mm512_sllv_epi32(accum_data_v7, left_shift); |
| // Apply the fixed-point part of the multiplier. |
| __m512i scaled_v_low = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_v7, 0)), |
| m_64bit_low); |
| __m512i scaled_v_high = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_v7, 1)), |
| m_64bit_high); |
| |
| scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); |
| scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); |
| |
| scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); |
| scaled_v_high = |
| _mm512_srav_epi64(scaled_v_high, final_right_shift_high); |
| |
| accum_data_v7 = |
| _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); |
| accum_data_v7 = _mm512_inserti32x8( |
| accum_data_v7, _mm512_cvtepi64_epi32(scaled_v_high), 1); |
| } |
| // Shift and round column 8. |
| { |
| accum_data_v8 = _mm512_sllv_epi32(accum_data_v8, left_shift); |
| // Apply the fixed-point part of the multiplier. |
| __m512i scaled_v_low = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_v8, 0)), |
| m_64bit_low); |
| __m512i scaled_v_high = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_v8, 1)), |
| m_64bit_high); |
| |
| scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); |
| scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); |
| |
| scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); |
| scaled_v_high = |
| _mm512_srav_epi64(scaled_v_high, final_right_shift_high); |
| |
| accum_data_v8 = |
| _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); |
| accum_data_v8 = _mm512_inserti32x8( |
| accum_data_v8, _mm512_cvtepi64_epi32(scaled_v_high), 1); |
| } |
| // Shift and round column 9. |
| { |
| accum_data_v9 = _mm512_sllv_epi32(accum_data_v9, left_shift); |
| // Apply the fixed-point part of the multiplier. |
| __m512i scaled_v_low = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_v9, 0)), |
| m_64bit_low); |
| __m512i scaled_v_high = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_v9, 1)), |
| m_64bit_high); |
| |
| scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); |
| scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); |
| |
| scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); |
| scaled_v_high = |
| _mm512_srav_epi64(scaled_v_high, final_right_shift_high); |
| |
| accum_data_v9 = |
| _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); |
| accum_data_v9 = _mm512_inserti32x8( |
| accum_data_v9, _mm512_cvtepi64_epi32(scaled_v_high), 1); |
| } |
| // Shift and round column 10. |
| { |
| accum_data_va = _mm512_sllv_epi32(accum_data_va, left_shift); |
| // Apply the fixed-point part of the multiplier. |
| __m512i scaled_v_low = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_va, 0)), |
| m_64bit_low); |
| __m512i scaled_v_high = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_va, 1)), |
| m_64bit_high); |
| |
| scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); |
| scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); |
| |
| scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); |
| scaled_v_high = |
| _mm512_srav_epi64(scaled_v_high, final_right_shift_high); |
| |
| accum_data_va = |
| _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); |
| accum_data_va = _mm512_inserti32x8( |
| accum_data_va, _mm512_cvtepi64_epi32(scaled_v_high), 1); |
| } |
| // Shift and round column 11. |
| { |
| accum_data_vb = _mm512_sllv_epi32(accum_data_vb, left_shift); |
| // Apply the fixed-point part of the multiplier. |
| __m512i scaled_v_low = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_vb, 0)), |
| m_64bit_low); |
| __m512i scaled_v_high = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_vb, 1)), |
| m_64bit_high); |
| |
| scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); |
| scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); |
| |
| scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); |
| scaled_v_high = |
| _mm512_srav_epi64(scaled_v_high, final_right_shift_high); |
| |
| accum_data_vb = |
| _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); |
| accum_data_vb = _mm512_inserti32x8( |
| accum_data_vb, _mm512_cvtepi64_epi32(scaled_v_high), 1); |
| } |
| // Shift and round column 12. |
| { |
| accum_data_vc = _mm512_sllv_epi32(accum_data_vc, left_shift); |
| // Apply the fixed-point part of the multiplier. |
| __m512i scaled_v_low = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_vc, 0)), |
| m_64bit_low); |
| __m512i scaled_v_high = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_vc, 1)), |
| m_64bit_high); |
| |
| scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); |
| scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); |
| |
| scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); |
| scaled_v_high = |
| _mm512_srav_epi64(scaled_v_high, final_right_shift_high); |
| |
| accum_data_vc = |
| _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); |
| accum_data_vc = _mm512_inserti32x8( |
| accum_data_vc, _mm512_cvtepi64_epi32(scaled_v_high), 1); |
| } |
| // Shift and round column 13. |
| { |
| accum_data_vd = _mm512_sllv_epi32(accum_data_vd, left_shift); |
| // Apply the fixed-point part of the multiplier. |
| __m512i scaled_v_low = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_vd, 0)), |
| m_64bit_low); |
| __m512i scaled_v_high = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_vd, 1)), |
| m_64bit_high); |
| |
| scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); |
| scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); |
| |
| scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); |
| scaled_v_high = |
| _mm512_srav_epi64(scaled_v_high, final_right_shift_high); |
| |
| accum_data_vd = |
| _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); |
| accum_data_vd = _mm512_inserti32x8( |
| accum_data_vd, _mm512_cvtepi64_epi32(scaled_v_high), 1); |
| } |
| // Shift and round column 14. |
| { |
| accum_data_ve = _mm512_sllv_epi32(accum_data_ve, left_shift); |
| // Apply the fixed-point part of the multiplier. |
| __m512i scaled_v_low = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_ve, 0)), |
| m_64bit_low); |
| __m512i scaled_v_high = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_ve, 1)), |
| m_64bit_high); |
| |
| scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); |
| scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); |
| |
| scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); |
| scaled_v_high = |
| _mm512_srav_epi64(scaled_v_high, final_right_shift_high); |
| |
| accum_data_ve = |
| _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); |
| accum_data_ve = _mm512_inserti32x8( |
| accum_data_ve, _mm512_cvtepi64_epi32(scaled_v_high), 1); |
| } |
| // Shift and round column 15. |
| { |
| accum_data_vf = _mm512_sllv_epi32(accum_data_vf, left_shift); |
| // Apply the fixed-point part of the multiplier. |
| __m512i scaled_v_low = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_vf, 0)), |
| m_64bit_low); |
| __m512i scaled_v_high = |
| _mm512_mul_epi32(_mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(accum_data_vf, 1)), |
| m_64bit_high); |
| |
| scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); |
| scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); |
| |
| scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); |
| scaled_v_high = |
| _mm512_srav_epi64(scaled_v_high, final_right_shift_high); |
| |
| accum_data_vf = |
| _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); |
| accum_data_vf = _mm512_inserti32x8( |
| accum_data_vf, _mm512_cvtepi64_epi32(scaled_v_high), 1); |
| } |
| |
| if (params.dst_zero_point != 0) { |
| __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point); |
| accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point); |
| accum_data_v1 = _mm512_add_epi32(accum_data_v1, dst_zero_point); |
| accum_data_v2 = _mm512_add_epi32(accum_data_v2, dst_zero_point); |
| accum_data_v3 = _mm512_add_epi32(accum_data_v3, dst_zero_point); |
| accum_data_v4 = _mm512_add_epi32(accum_data_v4, dst_zero_point); |
| accum_data_v5 = _mm512_add_epi32(accum_data_v5, dst_zero_point); |
| accum_data_v6 = _mm512_add_epi32(accum_data_v6, dst_zero_point); |
| accum_data_v7 = _mm512_add_epi32(accum_data_v7, dst_zero_point); |
| accum_data_v8 = _mm512_add_epi32(accum_data_v8, dst_zero_point); |
| accum_data_v9 = _mm512_add_epi32(accum_data_v9, dst_zero_point); |
| accum_data_va = _mm512_add_epi32(accum_data_va, dst_zero_point); |
| accum_data_vb = _mm512_add_epi32(accum_data_vb, dst_zero_point); |
| accum_data_vc = _mm512_add_epi32(accum_data_vc, dst_zero_point); |
| accum_data_vd = _mm512_add_epi32(accum_data_vd, dst_zero_point); |
| accum_data_ve = _mm512_add_epi32(accum_data_ve, dst_zero_point); |
| accum_data_vf = _mm512_add_epi32(accum_data_vf, dst_zero_point); |
| } |
| } |
| |
| const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max); |
| const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min); |
| |
| const bool store_full_block = |
| (residual_rows == 16) && (residual_cols == 16); |
| |
| __m512i accum_data_v[16]; |
| |
| // In most cases we would make this conditional on (!store_full_block) and |
| // unwind the clamp-and-store loop, but the benefit appears small. |
| { |
| accum_data_v[0] = accum_data_v0; |
| accum_data_v[1] = accum_data_v1; |
| accum_data_v[2] = accum_data_v2; |
| accum_data_v[3] = accum_data_v3; |
| accum_data_v[4] = accum_data_v4; |
| accum_data_v[5] = accum_data_v5; |
| accum_data_v[6] = accum_data_v6; |
| accum_data_v[7] = accum_data_v7; |
| accum_data_v[8] = accum_data_v8; |
| accum_data_v[9] = accum_data_v9; |
| accum_data_v[10] = accum_data_va; |
| accum_data_v[11] = accum_data_vb; |
| accum_data_v[12] = accum_data_vc; |
| accum_data_v[13] = accum_data_vd; |
| accum_data_v[14] = accum_data_ve; |
| accum_data_v[15] = accum_data_vf; |
| } |
| |
| if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) { |
| std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr); |
| const int block_col_offset = dst_stride; |
| if (store_full_block) { |
| for (int j = 0; j < 16; ++j) { |
| __m512i result = accum_data_v[j]; |
| result = _mm512_min_epi32(result, clamp_max_v); |
| result = _mm512_max_epi32(result, clamp_min_v); |
| _mm_storeu_si128( |
| reinterpret_cast<__m128i*>(tmp_ptr + j * block_col_offset), |
| _mm512_cvtepi32_epi8(result)); |
| } |
| } else { |
| for (int j = 0; j < residual_cols; ++j) { |
| __m512i result = accum_data_v[j]; |
| result = _mm512_min_epi32(result, clamp_max_v); |
| result = _mm512_max_epi32(result, clamp_min_v); |
| _mm_mask_storeu_epi8(tmp_ptr + j * block_col_offset, row_mask, |
| _mm512_cvtepi32_epi8(result)); |
| } |
| } |
| dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + 16); |
| } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) { |
| std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr); |
| const int block_col_offset = dst_stride; |
| if (store_full_block) { |
| for (int j = 0; j < residual_cols; ++j) { |
| __m512i result = accum_data_v[j]; |
| result = _mm512_min_epi32(result, clamp_max_v); |
| result = _mm512_max_epi32(result, clamp_min_v); |
| _mm_storeu_si128( |
| reinterpret_cast<__m128i*>(tmp_ptr + j * block_col_offset), |
| _mm512_cvtepi32_epi8(result)); |
| } |
| } else { |
| for (int j = 0; j < residual_cols; ++j) { |
| __m512i result = accum_data_v[j]; |
| result = _mm512_min_epi32(result, clamp_max_v); |
| result = _mm512_max_epi32(result, clamp_min_v); |
| _mm_mask_storeu_epi8(tmp_ptr + j * block_col_offset, row_mask, |
| _mm512_cvtepi32_epi8(result)); |
| } |
| } |
| dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + 16); |
| } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) { |
| std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr); |
| const int block_col_offset = dst_stride; |
| if (store_full_block) { |
| for (int j = 0; j < 16; ++j) { |
| __m512i result = accum_data_v[j]; |
| result = _mm512_min_epi32(result, clamp_max_v); |
| result = _mm512_max_epi32(result, clamp_min_v); |
| _mm256_storeu_si256( |
| reinterpret_cast<__m256i*>(tmp_ptr + j * block_col_offset), |
| _mm512_cvtepi32_epi16(result)); |
| } |
| } else { |
| for (int j = 0; j < residual_cols; ++j) { |
| __m512i result = accum_data_v[j]; |
| result = _mm512_min_epi32(result, clamp_max_v); |
| result = _mm512_max_epi32(result, clamp_min_v); |
| _mm256_mask_storeu_epi16(tmp_ptr + j * block_col_offset, row_mask, |
| _mm512_cvtepi32_epi16(result)); |
| } |
| } |
| dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + 16); |
| } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) { |
| if (store_full_block) { |
| std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr); |
| for (int j = 0; j < 16; ++j) { |
| _mm512_storeu_si512(tmp_ptr + j * dst_stride, accum_data_v[j]); |
| } |
| } else { |
| std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr); |
| for (int j = 0; j < residual_cols; ++j) { |
| _mm512_mask_storeu_epi32(tmp_ptr + j * dst_stride, row_mask, |
| accum_data_v[j]); |
| } |
| } |
| dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + 16); |
| } else { |
| RUY_DCHECK(false); |
| } |
| |
| lhs_col_ptr += 16 * params.lhs_stride; |
| } // End row-block loop. |
| |
| dst_col_ptr = static_cast<void*>(static_cast<char*>(dst_col_ptr) + |
| 16 * params.dst_stride); |
| rhs_col_ptr += 16 * params.rhs_stride; |
| } // End col-block loop. |
| } // NOLINT(readability/fn_size) |
| |
| void Kernel8bitAvx512SingleCol(const KernelParams8bit<16, 16>& params) { |
| profiler::ScopeLabel label("Kernel kAvx512 8-bit GEMV"); |
| |
| RUY_DCHECK_EQ(params.dst_cols, 1); |
| RUY_DCHECK_EQ(params.last_col, 0); |
| RUY_DCHECK_EQ(params.start_col, 0); |
| |
| int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 16 : 0; |
| |
| const std::int8_t* rhs_col_ptr = params.rhs_base_ptr; |
| void* dst_col_ptr = params.dst_base_ptr; |
| const std::int32_t* bias_col_ptr = params.bias; |
| if (params.flags & RUY_ASM_FLAG_HAS_BIAS) { |
| bias_col_ptr += params.start_row; |
| } |
| |
| const std::int8_t* lhs_col_ptr = params.lhs_base_ptr; |
| void* dst_ptr = dst_col_ptr; |
| const std::int32_t* bias_ptr = bias_col_ptr; |
| |
| const std::int32_t lhs_zero_point = params.lhs_zero_point; |
| const bool has_rhs_sums_offsets = |
| (params.flags & RUY_ASM_FLAG_HAS_RHS_SUMS) && lhs_zero_point; |
| std::int32_t rhs_sums_offsets[16]; |
| if (has_rhs_sums_offsets) { |
| const __m512i rhs_sums_offset_v = |
| _mm512_mullo_epi32(_mm512_set1_epi32(lhs_zero_point), |
| _mm512_loadu_si512(¶ms.rhs_sums[0])); |
| _mm512_storeu_si512(reinterpret_cast<__m512i*>(rhs_sums_offsets), |
| rhs_sums_offset_v); |
| } |
| |
| for (int row = params.start_row; row <= params.last_row; row += 16) { |
| const int residual_rows = std::min(params.dst_rows - row, 16); |
| |
| __m512i accum_data_v0; |
| |
| // Initialize with bias. |
| const __mmask16 row_mask = |
| (static_cast<std::uint32_t>(1) << residual_rows) - 1; |
| __m512i initial_accum_data = _mm512_maskz_loadu_epi32(row_mask, bias_ptr); |
| bias_ptr += bias_ptr_block_increment; |
| |
| const std::int32_t rhs_zero_point = params.rhs_zero_point; |
| if ((params.flags & RUY_ASM_FLAG_HAS_LHS_SUMS) && rhs_zero_point) { |
| const __m512i lhs_sums_offset = |
| _mm512_mullo_epi32(_mm512_set1_epi32(rhs_zero_point), |
| _mm512_loadu_si512(¶ms.lhs_sums[row])); |
| initial_accum_data = |
| _mm512_sub_epi32(initial_accum_data, lhs_sums_offset); |
| } |
| |
| const std::int32_t prod_zp_depth = params.prod_zp_depth; |
| if (prod_zp_depth != 0) { |
| initial_accum_data = _mm512_add_epi32(initial_accum_data, |
| _mm512_set1_epi32(prod_zp_depth)); |
| } |
| |
| // Adjustments differing across columns. |
| if (has_rhs_sums_offsets) { |
| accum_data_v0 = _mm512_sub_epi32(initial_accum_data, |
| _mm512_set1_epi32(rhs_sums_offsets[0])); |
| } else { |
| accum_data_v0 = initial_accum_data; |
| } |
| |
| const std::int8_t* lhs_ptr = lhs_col_ptr; |
| const std::int8_t* rhs_ptr = rhs_col_ptr; |
| for (int d = 0; d < params.depth; d += 4) { |
| const __m512i lhs_data = _mm512_loadu_si512(lhs_ptr); |
| const __m128i rhs_data_8bit = |
| _mm_loadu_si128(reinterpret_cast<const __m128i*>(rhs_ptr)); |
| |
| // Each "int32" is two 16-bit RHS values, sign extended from 8-bit. |
| // For simplicity we load 4x the data that we need and process twice the |
| // data that we need and store only the data we need. |
| std::int32_t rhs_data[2]; |
| const __m128i rhs_16_bit_dup = _mm_cvtepi8_epi16(rhs_data_8bit); |
| // Now that we have cast the RHS data, we store it so that each value |
| // can be separately loaded in the accumulation loop. |
| _mm_storeu_si64(reinterpret_cast<__m128i*>(rhs_data), rhs_16_bit_dup); |
| |
| // Take bytes 0, 1, 4, 5, 8, 9, ... and expand to 16-bit. |
| const __m512i lhs_16_bit_low = |
| _mm512_cvtepi8_epi16(_mm512_cvtepi32_epi16(lhs_data)); |
| // Take bytes 2, 3, 6, 7, 10, 11, ... and expand to 16-bit. |
| const __m512i lhs_16_bit_high = _mm512_cvtepi8_epi16( |
| _mm512_cvtepi32_epi16(_mm512_srli_epi32(lhs_data, 16))); |
| |
| // Process column 0. |
| __m512i accum_v = accum_data_v0; |
| constexpr int index = 0; |
| |
| const __m512i rhs_16_bit_dup_low = _mm512_set1_epi32(rhs_data[index]); |
| const __m512i rhs_16_bit_dup_high = |
| _mm512_set1_epi32(rhs_data[index + 1]); |
| |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low)); |
| accum_v = _mm512_add_epi32( |
| accum_v, _mm512_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high)); |
| accum_data_v0 = accum_v; |
| |
| lhs_ptr += 16 * 4; |
| rhs_ptr += 16 * 4; |
| } |
| |
| if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) { |
| __m512i m_vector; |
| __m512i e_vector; |
| // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT. |
| if (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) { |
| m_vector = _mm512_maskz_loadu_epi32(row_mask, |
| ¶ms.multiplier_fixedpoint[row]); |
| e_vector = _mm512_maskz_loadu_epi32(row_mask, |
| ¶ms.multiplier_exponent[row]); |
| } else { |
| // These arrays have size LhsCols, and are pre-filled. |
| m_vector = _mm512_set1_epi32(params.multiplier_fixedpoint[0]); |
| e_vector = _mm512_set1_epi32(params.multiplier_exponent[0]); |
| } |
| |
| const __m512i m_64bit_low = |
| _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 0)); |
| const __m512i m_64bit_high = |
| _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(m_vector, 1)); |
| |
| const __m512i zero_vector = _mm512_setzero_epi32(); |
| const __m512i left_shift = _mm512_max_epi32(e_vector, zero_vector); |
| const __m512i neg_e_vector = _mm512_sub_epi32(zero_vector, e_vector); |
| const __m512i right_shift = _mm512_max_epi32(neg_e_vector, zero_vector); |
| const __m512i final_right_shift = |
| _mm512_add_epi32(right_shift, _mm512_set1_epi32(31)); |
| const __m512i final_right_shift_low = _mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(final_right_shift, 0)); |
| const __m512i final_right_shift_high = _mm512_cvtepi32_epi64( |
| _mm512_extracti32x8_epi32(final_right_shift, 1)); |
| |
| const __m512i offset_vector = _mm512_slli_epi64(_mm512_set1_epi64(1), 30); |
| // Really these should be shifted by neg_e_vector, but tests pass when |
| // using right_shift. |
| const __m512i offset_vector_low = _mm512_sllv_epi64( |
| offset_vector, |
| _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 0))); |
| const __m512i offset_vector_high = _mm512_sllv_epi64( |
| offset_vector, |
| _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1))); |
| |
| // Shift and round column 0. |
| accum_data_v0 = _mm512_sllv_epi32(accum_data_v0, left_shift); |
| // Apply the fixed-point part of the multiplier. |
| __m512i scaled_v_low = _mm512_mul_epi32( |
| _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 0)), |
| m_64bit_low); |
| __m512i scaled_v_high = _mm512_mul_epi32( |
| _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum_data_v0, 1)), |
| m_64bit_high); |
| |
| scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low); |
| scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high); |
| |
| scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low); |
| scaled_v_high = _mm512_srav_epi64(scaled_v_high, final_right_shift_high); |
| |
| accum_data_v0 = |
| _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low)); |
| accum_data_v0 = _mm512_inserti32x8( |
| accum_data_v0, _mm512_cvtepi64_epi32(scaled_v_high), 1); |
| |
| if (params.dst_zero_point != 0) { |
| __m512i dst_zero_point = _mm512_set1_epi32(params.dst_zero_point); |
| accum_data_v0 = _mm512_add_epi32(accum_data_v0, dst_zero_point); |
| } |
| } |
| |
| const __m512i clamp_max_v = _mm512_set1_epi32(params.clamp_max); |
| const __m512i clamp_min_v = _mm512_set1_epi32(params.clamp_min); |
| |
| if (params.dst_type_id == DstTypeId<std::int8_t>::kValue) { |
| std::int8_t* tmp_ptr = static_cast<std::int8_t*>(dst_ptr); |
| __m512i result = accum_data_v0; |
| result = _mm512_min_epi32(result, clamp_max_v); |
| result = _mm512_max_epi32(result, clamp_min_v); |
| _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result)); |
| dst_ptr = static_cast<void*>(static_cast<std::int8_t*>(dst_ptr) + 16); |
| } else if (params.dst_type_id == DstTypeId<std::uint8_t>::kValue) { |
| std::uint8_t* tmp_ptr = static_cast<std::uint8_t*>(dst_ptr); |
| __m512i result = accum_data_v0; |
| result = _mm512_min_epi32(result, clamp_max_v); |
| result = _mm512_max_epi32(result, clamp_min_v); |
| _mm_mask_storeu_epi8(tmp_ptr, row_mask, _mm512_cvtepi32_epi8(result)); |
| dst_ptr = static_cast<void*>(static_cast<std::uint8_t*>(dst_ptr) + 16); |
| } else if (params.dst_type_id == DstTypeId<std::int16_t>::kValue) { |
| std::int16_t* tmp_ptr = static_cast<std::int16_t*>(dst_ptr); |
| __m512i result = accum_data_v0; |
| result = _mm512_min_epi32(result, clamp_max_v); |
| result = _mm512_max_epi32(result, clamp_min_v); |
| _mm256_mask_storeu_epi16(tmp_ptr, row_mask, |
| _mm512_cvtepi32_epi16(result)); |
| dst_ptr = static_cast<void*>(static_cast<std::int16_t*>(dst_ptr) + 16); |
| } else if (params.dst_type_id == DstTypeId<std::int32_t>::kValue) { |
| std::int32_t* tmp_ptr = static_cast<std::int32_t*>(dst_ptr); |
| _mm512_mask_storeu_epi32(tmp_ptr, row_mask, accum_data_v0); |
| dst_ptr = static_cast<void*>(static_cast<std::int32_t*>(dst_ptr) + 16); |
| } else { |
| RUY_DCHECK(false); |
| } |
| |
| lhs_col_ptr += 16 * params.lhs_stride; |
| } // End row-block loop. |
| } // NOLINT(readability/fn_size) |
| |
| void KernelFloatAvx512(const KernelParamsFloat<16, 16>& params) { |
| profiler::ScopeLabel label("Kernel kAvx512 float"); |
| |
| // As parameters are defined, we need to scale by sizeof(float). |
| const std::int64_t lhs_stride = params.lhs_stride >> 2; |
| const std::int64_t dst_stride = params.dst_stride >> 2; |
| const std::int64_t rhs_stride = params.rhs_stride >> 2; |
| |
| int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0; |
| const int end_row = std::min(params.dst_rows, params.last_row + 16); |
| const int end_col = std::min(params.dst_cols, params.last_col + 16); |
| |
| const float* adj_rhs_col_ptr = |
| params.rhs_base_ptr - params.start_col * rhs_stride; |
| float* adj_dst_col_ptr = |
| params.dst_base_ptr - params.start_col * dst_stride - params.start_row; |
| const float* adj_lhs_col_ptr = |
| params.lhs_base_ptr - params.start_row * lhs_stride; |
| const float* bias_col_ptr = params.bias; |
| |
| const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max); |
| const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min); |
| |
| int col = params.start_col; |
| for (; col <= end_col - 16; col += 16) { |
| const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride; |
| float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride; |
| |
| int row = params.start_row; |
| for (; row <= end_row - 16; row += 16) { |
| const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; |
| float* dst_ptr = dst_col_ptr + row; |
| const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; |
| |
| // Initialize with bias. |
| const __m512 initial_accum_data = _mm512_loadu_ps(bias_ptr); |
| |
| // Process block in two halves, split by columns. |
| { |
| constexpr int mmm = 0; |
| |
| __m512 accum_data_v0 = initial_accum_data; |
| __m512 accum_data_v1 = initial_accum_data; |
| __m512 accum_data_v2 = initial_accum_data; |
| __m512 accum_data_v3 = initial_accum_data; |
| __m512 accum_data_v4 = initial_accum_data; |
| __m512 accum_data_v5 = initial_accum_data; |
| __m512 accum_data_v6 = initial_accum_data; |
| __m512 accum_data_v7 = initial_accum_data; |
| |
| const float* lhs_ptr = lhs_col_ptr; |
| const float* rhs_ptr = rhs_col_ptr + 8 * mmm; |
| for (int d = 0; d < (params.depth - 1); ++d) { |
| const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); |
| // In this version RHS values are loaded individually rather than |
| // first loading together and then extract with broadcasting. This is |
| // because AVX flavours and instrinsics and compilers in combination |
| // do not handle this pattern of extraction very well. |
| const float* rhs_data = rhs_ptr; |
| lhs_ptr += 16; |
| rhs_ptr += 16; |
| |
| { |
| const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); |
| accum_data_v0 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); |
| const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); |
| accum_data_v1 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); |
| const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); |
| accum_data_v2 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); |
| const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); |
| accum_data_v3 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); |
| const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); |
| accum_data_v4 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); |
| const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); |
| accum_data_v5 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); |
| const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); |
| accum_data_v6 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); |
| const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); |
| accum_data_v7 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); |
| } |
| } |
| { |
| const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); |
| const float* rhs_data = rhs_ptr; |
| { |
| const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); |
| accum_data_v0 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); |
| const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); |
| accum_data_v1 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); |
| const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); |
| accum_data_v2 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); |
| const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); |
| accum_data_v3 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); |
| const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); |
| accum_data_v4 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); |
| const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); |
| accum_data_v5 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); |
| const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); |
| accum_data_v6 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); |
| const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); |
| accum_data_v7 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); |
| } |
| { |
| float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride; |
| accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v); |
| accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v); |
| _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0); |
| accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v); |
| accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v); |
| _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1); |
| accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v); |
| accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v); |
| _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2); |
| accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v); |
| accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v); |
| _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3); |
| accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v); |
| accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v); |
| _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4); |
| accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v); |
| accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v); |
| _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5); |
| accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v); |
| accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v); |
| _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6); |
| accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v); |
| accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v); |
| _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7); |
| } |
| } |
| } // Inner half-block loop, unrolled, first iteration. |
| { |
| constexpr int mmm = 1; |
| |
| __m512 accum_data_v0 = initial_accum_data; |
| __m512 accum_data_v1 = initial_accum_data; |
| __m512 accum_data_v2 = initial_accum_data; |
| __m512 accum_data_v3 = initial_accum_data; |
| __m512 accum_data_v4 = initial_accum_data; |
| __m512 accum_data_v5 = initial_accum_data; |
| __m512 accum_data_v6 = initial_accum_data; |
| __m512 accum_data_v7 = initial_accum_data; |
| |
| const float* lhs_ptr = lhs_col_ptr; |
| const float* rhs_ptr = rhs_col_ptr + 8 * mmm; |
| for (int d = 0; d < (params.depth - 1); ++d) { |
| const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); |
| const float* rhs_data = rhs_ptr; |
| lhs_ptr += 16; |
| rhs_ptr += 16; |
| { |
| const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); |
| accum_data_v0 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); |
| const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); |
| accum_data_v1 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); |
| const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); |
| accum_data_v2 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); |
| const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); |
| accum_data_v3 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); |
| const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); |
| accum_data_v4 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); |
| const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); |
| accum_data_v5 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); |
| const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); |
| accum_data_v6 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); |
| const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); |
| accum_data_v7 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); |
| } |
| } |
| { |
| const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); |
| const float* rhs_data = rhs_ptr; |
| { |
| const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); |
| accum_data_v0 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); |
| const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); |
| accum_data_v1 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); |
| const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); |
| accum_data_v2 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); |
| const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); |
| accum_data_v3 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); |
| const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); |
| accum_data_v4 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); |
| const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); |
| accum_data_v5 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); |
| const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); |
| accum_data_v6 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); |
| const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); |
| accum_data_v7 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); |
| } |
| { |
| float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride; |
| accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v); |
| accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v); |
| _mm512_storeu_ps(block_ptr + 0 * dst_stride, accum_data_v0); |
| accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v); |
| accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v); |
| _mm512_storeu_ps(block_ptr + 1 * dst_stride, accum_data_v1); |
| accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v); |
| accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v); |
| _mm512_storeu_ps(block_ptr + 2 * dst_stride, accum_data_v2); |
| accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v); |
| accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v); |
| _mm512_storeu_ps(block_ptr + 3 * dst_stride, accum_data_v3); |
| accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v); |
| accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v); |
| _mm512_storeu_ps(block_ptr + 4 * dst_stride, accum_data_v4); |
| accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v); |
| accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v); |
| _mm512_storeu_ps(block_ptr + 5 * dst_stride, accum_data_v5); |
| accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v); |
| accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v); |
| _mm512_storeu_ps(block_ptr + 6 * dst_stride, accum_data_v6); |
| accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v); |
| accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v); |
| _mm512_storeu_ps(block_ptr + 7 * dst_stride, accum_data_v7); |
| } |
| } |
| } // Inner half-block loop, unrolled, second iteration. |
| } // End row-block loop. |
| |
| // The unrolling within this conditional may be somewhat pointless. It |
| // depends on the kinds of models. |
| if (row < end_row) { |
| const int residual_rows = end_row - row; |
| |
| const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; |
| float* dst_ptr = dst_col_ptr + row; |
| const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; |
| |
| // Initialize with bias. |
| const __mmask16 row_mask = |
| (static_cast<std::uint32_t>(1) << residual_rows) - 1; |
| const __m512 initial_accum_data = |
| _mm512_maskz_loadu_ps(row_mask, bias_ptr); |
| |
| // Process block in two halves, split by columns. |
| for (int mmm = 0; mmm < 2; ++mmm) { |
| __m512 accum_data_v0 = initial_accum_data; |
| __m512 accum_data_v1 = initial_accum_data; |
| __m512 accum_data_v2 = initial_accum_data; |
| __m512 accum_data_v3 = initial_accum_data; |
| __m512 accum_data_v4 = initial_accum_data; |
| __m512 accum_data_v5 = initial_accum_data; |
| __m512 accum_data_v6 = initial_accum_data; |
| __m512 accum_data_v7 = initial_accum_data; |
| |
| const float* lhs_ptr = lhs_col_ptr; |
| const float* rhs_ptr = rhs_col_ptr + 8 * mmm; |
| for (int d = 0; d < (params.depth - 1); ++d) { |
| const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); |
| const float* rhs_data = rhs_ptr; |
| lhs_ptr += 16; |
| rhs_ptr += 16; |
| { |
| const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); |
| accum_data_v0 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); |
| const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); |
| accum_data_v1 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); |
| const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); |
| accum_data_v2 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); |
| const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); |
| accum_data_v3 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); |
| const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); |
| accum_data_v4 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); |
| const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); |
| accum_data_v5 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); |
| const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); |
| accum_data_v6 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); |
| const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); |
| accum_data_v7 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); |
| } |
| } |
| { |
| const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); |
| const float* rhs_data = rhs_ptr; |
| { |
| const __m512 dup_rhs_element_j0 = _mm512_set1_ps(rhs_data[0]); |
| accum_data_v0 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j0, accum_data_v0); |
| const __m512 dup_rhs_element_j1 = _mm512_set1_ps(rhs_data[1]); |
| accum_data_v1 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j1, accum_data_v1); |
| const __m512 dup_rhs_element_j2 = _mm512_set1_ps(rhs_data[2]); |
| accum_data_v2 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j2, accum_data_v2); |
| const __m512 dup_rhs_element_j3 = _mm512_set1_ps(rhs_data[3]); |
| accum_data_v3 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j3, accum_data_v3); |
| const __m512 dup_rhs_element_j4 = _mm512_set1_ps(rhs_data[4]); |
| accum_data_v4 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j4, accum_data_v4); |
| const __m512 dup_rhs_element_j5 = _mm512_set1_ps(rhs_data[5]); |
| accum_data_v5 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j5, accum_data_v5); |
| const __m512 dup_rhs_element_j6 = _mm512_set1_ps(rhs_data[6]); |
| accum_data_v6 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j6, accum_data_v6); |
| const __m512 dup_rhs_element_j7 = _mm512_set1_ps(rhs_data[7]); |
| accum_data_v7 = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j7, accum_data_v7); |
| } |
| { |
| float* block_ptr = dst_ptr + (mmm * 8 + 0) * dst_stride; |
| accum_data_v0 = _mm512_min_ps(accum_data_v0, clamp_max_v); |
| accum_data_v0 = _mm512_max_ps(accum_data_v0, clamp_min_v); |
| _mm512_mask_storeu_ps(block_ptr + 0 * dst_stride, row_mask, |
| accum_data_v0); |
| accum_data_v1 = _mm512_min_ps(accum_data_v1, clamp_max_v); |
| accum_data_v1 = _mm512_max_ps(accum_data_v1, clamp_min_v); |
| _mm512_mask_storeu_ps(block_ptr + 1 * dst_stride, row_mask, |
| accum_data_v1); |
| accum_data_v2 = _mm512_min_ps(accum_data_v2, clamp_max_v); |
| accum_data_v2 = _mm512_max_ps(accum_data_v2, clamp_min_v); |
| _mm512_mask_storeu_ps(block_ptr + 2 * dst_stride, row_mask, |
| accum_data_v2); |
| accum_data_v3 = _mm512_min_ps(accum_data_v3, clamp_max_v); |
| accum_data_v3 = _mm512_max_ps(accum_data_v3, clamp_min_v); |
| _mm512_mask_storeu_ps(block_ptr + 3 * dst_stride, row_mask, |
| accum_data_v3); |
| accum_data_v4 = _mm512_min_ps(accum_data_v4, clamp_max_v); |
| accum_data_v4 = _mm512_max_ps(accum_data_v4, clamp_min_v); |
| _mm512_mask_storeu_ps(block_ptr + 4 * dst_stride, row_mask, |
| accum_data_v4); |
| accum_data_v5 = _mm512_min_ps(accum_data_v5, clamp_max_v); |
| accum_data_v5 = _mm512_max_ps(accum_data_v5, clamp_min_v); |
| _mm512_mask_storeu_ps(block_ptr + 5 * dst_stride, row_mask, |
| accum_data_v5); |
| accum_data_v6 = _mm512_min_ps(accum_data_v6, clamp_max_v); |
| accum_data_v6 = _mm512_max_ps(accum_data_v6, clamp_min_v); |
| _mm512_mask_storeu_ps(block_ptr + 6 * dst_stride, row_mask, |
| accum_data_v6); |
| accum_data_v7 = _mm512_min_ps(accum_data_v7, clamp_max_v); |
| accum_data_v7 = _mm512_max_ps(accum_data_v7, clamp_min_v); |
| _mm512_mask_storeu_ps(block_ptr + 7 * dst_stride, row_mask, |
| accum_data_v7); |
| } |
| } |
| } // Inner half-block loop. |
| } // Residual rows, main col-block loop. |
| } // End col-block loop. |
| |
| if (col < end_col) { |
| RUY_DCHECK_GE(end_col - col, 0); |
| RUY_DCHECK_LT(end_col - col, 16); |
| |
| __m512 accum_data_v[8]; |
| |
| const float* rhs_col_ptr = adj_rhs_col_ptr + col * rhs_stride; |
| float* dst_col_ptr = adj_dst_col_ptr + col * dst_stride; |
| |
| for (int row = params.start_row; row < end_row; row += 16) { |
| const int residual_rows = std::min(end_row - row, 16); |
| |
| const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; |
| float* dst_ptr = dst_col_ptr + row; |
| const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; |
| |
| // Initialize with bias. |
| const __mmask16 row_mask = |
| (static_cast<std::uint32_t>(1) << residual_rows) - 1; |
| const __m512 initial_accum_data = |
| _mm512_maskz_loadu_ps(row_mask, bias_ptr); |
| |
| // Process block in two halves, split by columns. |
| for (int mmm = 0; mmm < 2; ++mmm) { |
| for (int j = 0; j < 8; ++j) { |
| accum_data_v[j] = initial_accum_data; |
| } |
| |
| const float* lhs_ptr = lhs_col_ptr; |
| const float* rhs_ptr = rhs_col_ptr + 8 * mmm; |
| for (int d = 0; d < params.depth; ++d) { |
| const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); |
| const float* rhs_data = rhs_ptr; |
| |
| for (int j = 0; j < 8; ++j) { |
| const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data[j]); |
| accum_data_v[j] = |
| _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v[j]); |
| } |
| lhs_ptr += 16; |
| rhs_ptr += 16; |
| } |
| |
| const int residual_cols = std::min(end_col - col - 8 * mmm, 8); |
| |
| if (residual_rows == 16) { |
| if (residual_cols == 8) { |
| for (int j = 0; j < 8; ++j) { |
| float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride; |
| accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v); |
| accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v); |
| _mm512_storeu_ps(block_ptr, accum_data_v[j]); |
| } |
| } else { |
| for (int j = 0; j < residual_cols; ++j) { |
| float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride; |
| accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v); |
| accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v); |
| _mm512_storeu_ps(block_ptr, accum_data_v[j]); |
| } |
| } |
| } else { |
| for (int j = 0; j < residual_cols; ++j) { |
| float* block_ptr = dst_ptr + (mmm * 8 + j) * dst_stride; |
| accum_data_v[j] = _mm512_min_ps(accum_data_v[j], clamp_max_v); |
| accum_data_v[j] = _mm512_max_ps(accum_data_v[j], clamp_min_v); |
| _mm512_mask_storeu_ps(block_ptr, row_mask, accum_data_v[j]); |
| } |
| } |
| } // Inner half-block loop. |
| } // End row-block loop. |
| } // Residual cols. |
| } |
| |
| void KernelFloatAvx512SingleCol(const KernelParamsFloat<16, 16>& params) { |
| profiler::ScopeLabel label("Kernel kAvx512 float GEMV"); |
| |
| RUY_DCHECK_EQ(params.dst_cols, 1); |
| RUY_DCHECK_EQ(params.last_col, 0); |
| RUY_DCHECK_EQ(params.start_col, 0); |
| |
| // As parameters are defined, we need to scale by sizeof(float). |
| const std::int64_t lhs_stride = params.lhs_stride >> 2; |
| |
| int bias_ptr_block_increment = params.flags & RUY_ASM_FLAG_HAS_BIAS ? 1 : 0; |
| const int end_row = std::min(params.dst_rows, params.last_row + 16); |
| |
| float* adj_dst_col_ptr = params.dst_base_ptr - params.start_row; |
| const float* adj_lhs_col_ptr = |
| params.lhs_base_ptr - params.start_row * lhs_stride; |
| const float* bias_col_ptr = params.bias; |
| |
| const __m512 clamp_max_v = _mm512_set1_ps(params.clamp_max); |
| const __m512 clamp_min_v = _mm512_set1_ps(params.clamp_min); |
| |
| __m512 accum_data_v; |
| |
| const float* rhs_col_ptr = params.rhs_base_ptr; |
| float* dst_col_ptr = adj_dst_col_ptr; |
| |
| int row = params.start_row; |
| for (; row <= end_row - 16; row += 16) { |
| const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; |
| float* dst_ptr = dst_col_ptr + row; |
| const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; |
| |
| // Initialize with bias. |
| accum_data_v = _mm512_loadu_ps(bias_ptr); |
| |
| const float* lhs_ptr = lhs_col_ptr; |
| const float* rhs_ptr = rhs_col_ptr; |
| for (int d = 0; d < params.depth; ++d) { |
| const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); |
| const float rhs_data = *rhs_ptr; |
| |
| const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data); |
| accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v); |
| lhs_ptr += 16; |
| rhs_ptr += 16; |
| } |
| |
| accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v); |
| accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v); |
| _mm512_storeu_ps(dst_ptr, accum_data_v); |
| } // End row-block loop. |
| |
| if (row < end_row) { |
| const int residual_rows = end_row - row; |
| RUY_CHECK_GE(residual_rows, 1); |
| RUY_CHECK_LT(residual_rows, 16); |
| |
| const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride; |
| float* dst_ptr = dst_col_ptr + row; |
| const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment; |
| |
| // Initialize with bias. |
| const __mmask16 row_mask = |
| (static_cast<std::uint32_t>(1) << residual_rows) - 1; |
| accum_data_v = _mm512_maskz_loadu_ps(row_mask, bias_ptr); |
| |
| const float* lhs_ptr = lhs_col_ptr; |
| const float* rhs_ptr = rhs_col_ptr; |
| for (int d = 0; d < params.depth; ++d) { |
| const __m512 lhs_data = _mm512_loadu_ps(lhs_ptr); |
| const float rhs_data = *rhs_ptr; |
| |
| const __m512 dup_rhs_element_j = _mm512_set1_ps(rhs_data); |
| accum_data_v = _mm512_fmadd_ps(lhs_data, dup_rhs_element_j, accum_data_v); |
| lhs_ptr += 16; |
| rhs_ptr += 16; |
| } |
| |
| accum_data_v = _mm512_min_ps(accum_data_v, clamp_max_v); |
| accum_data_v = _mm512_max_ps(accum_data_v, clamp_min_v); |
| _mm512_mask_storeu_ps(dst_ptr, row_mask, accum_data_v); |
| } // End handling of residual rows. |
| } |
| |
| #endif // RUY_PLATFORM_AVX512 && RUY_OPT(ASM) |
| |
| } // namespace ruy |