Efficient support for any channel_dimension for float kernels on AVX2.
Also, restricting x86 single-column kernels to the case of channel_dimension==kRow, similar to what we did for ARM kernels.
PiperOrigin-RevId: 320973990
diff --git a/ruy/create_trmul_params.cc b/ruy/create_trmul_params.cc
index 032b357..619c064 100644
--- a/ruy/create_trmul_params.cc
+++ b/ruy/create_trmul_params.cc
@@ -58,6 +58,13 @@
return false;
#endif
+#if RUY_PLATFORM_X86
+ if (src[Side::kLhs].data_type == Type::Create<float>() &&
+ path == Path::kAvx2) {
+ return false;
+ }
+#endif
+
// Ruy's optimized kernels currently only support the channel_dimension==kRow
// case.
if (channel_dimension != ChannelDimension::kRow) {
diff --git a/ruy/kernel_avx2.cc b/ruy/kernel_avx2.cc
index 5d34e6b..60c6e6b 100644
--- a/ruy/kernel_avx2.cc
+++ b/ruy/kernel_avx2.cc
@@ -18,6 +18,7 @@
#include <cstring>
#include "ruy/check_macros.h"
+#include "ruy/kernel_common.h"
#include "ruy/kernel_x86.h"
#include "ruy/opt_set.h"
#include "ruy/platform.h"
@@ -1467,10 +1468,12 @@
params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
const float* adj_lhs_col_ptr =
params.lhs_base_ptr - params.start_row * lhs_stride;
- const float* bias_col_ptr = params.bias;
+ const float* bias_ptr = params.bias;
const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max);
const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min);
+ const bool channel_dimension_is_col =
+ params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;
int col = params.start_col;
// Loop through cols by float block size, leaving incomplete remainder
@@ -1485,14 +1488,21 @@
const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
float* dst_ptr = dst_col_ptr + row;
- const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
// Initialize with bias.
- const __m256 initial_accum_data =
- intrin_utils::mm256_n_loadu_ps(residual_rows, bias_ptr);
+ if (channel_dimension_is_col) {
+ const float* bias_elem_ptr = bias_ptr + col * bias_ptr_block_increment;
+ for (int j = 0; j < 8; ++j) {
+ accum_data_v[j] = _mm256_broadcast_ss(bias_elem_ptr + j);
+ }
+ } else {
+ const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment;
+ const __m256 initial_accum_data =
+ intrin_utils::mm256_n_loadu_ps(residual_rows, bias_elem_ptr);
- for (int j = 0; j < 8; ++j) {
- accum_data_v[j] = initial_accum_data;
+ for (int j = 0; j < 8; ++j) {
+ accum_data_v[j] = initial_accum_data;
+ }
}
const float* lhs_ptr = lhs_col_ptr;
@@ -1549,14 +1559,21 @@
const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
float* dst_ptr = dst_col_ptr + row;
- const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
// Initialize with bias.
- const __m256 initial_accum_data =
- intrin_utils::mm256_n_loadu_ps(residual_rows, bias_ptr);
+ if (channel_dimension_is_col) {
+ const float* bias_elem_ptr = bias_ptr + col * bias_ptr_block_increment;
+ for (int j = 0; j < 8; ++j) {
+ accum_data_v[j] = _mm256_broadcast_ss(bias_elem_ptr + j);
+ }
+ } else {
+ const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment;
+ const __m256 initial_accum_data =
+ intrin_utils::mm256_n_loadu_ps(residual_rows, bias_elem_ptr);
- for (int j = 0; j < 8; ++j) {
- accum_data_v[j] = initial_accum_data;
+ for (int j = 0; j < 8; ++j) {
+ accum_data_v[j] = initial_accum_data;
+ }
}
const float* lhs_ptr = lhs_col_ptr;
diff --git a/ruy/kernel_x86.h b/ruy/kernel_x86.h
index 84e9e85..2def91f 100644
--- a/ruy/kernel_x86.h
+++ b/ruy/kernel_x86.h
@@ -49,7 +49,8 @@
KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
end_col, dst, ¶ms);
- if (dst->layout.cols == 1) {
+ if (dst->layout.cols == 1 &&
+ mul_params.channel_dimension() == ChannelDimension::kRow) {
Kernel8bitAvx512SingleCol(params);
} else {
Kernel8bitAvx512(params);
@@ -73,7 +74,8 @@
KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
end_col, dst, ¶ms);
- if (dst->layout.cols == 1) {
+ if (dst->layout.cols == 1 &&
+ mul_params.channel_dimension() == ChannelDimension::kRow) {
KernelFloatAvx512SingleCol(params);
} else {
KernelFloatAvx512(params);
@@ -97,7 +99,8 @@
KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
end_col, dst, ¶ms);
- if (dst->layout.cols == 1) {
+ if (dst->layout.cols == 1 &&
+ mul_params.channel_dimension() == ChannelDimension::kRow) {
Kernel8bitAvx2SingleCol(params);
} else {
Kernel8bitAvx2(params);
@@ -121,7 +124,8 @@
KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
end_col, dst, ¶ms);
- if (dst->layout.cols == 1) {
+ if (dst->layout.cols == 1 &&
+ mul_params.channel_dimension() == ChannelDimension::kRow) {
KernelFloatAvx2SingleCol(params);
} else {
KernelFloatAvx2(params);