Efficient support for any channel_dimension for float kernels on AVX2.
Also, restricting x86 single-column kernels to the case of channel_dimension==kRow, similar to what we did for ARM kernels.

PiperOrigin-RevId: 320973990
diff --git a/ruy/create_trmul_params.cc b/ruy/create_trmul_params.cc
index 032b357..619c064 100644
--- a/ruy/create_trmul_params.cc
+++ b/ruy/create_trmul_params.cc
@@ -58,6 +58,13 @@
   return false;
 #endif
 
+#if RUY_PLATFORM_X86
+  if (src[Side::kLhs].data_type == Type::Create<float>() &&
+      path == Path::kAvx2) {
+    return false;
+  }
+#endif
+
   // Ruy's optimized kernels currently only support the channel_dimension==kRow
   // case.
   if (channel_dimension != ChannelDimension::kRow) {
diff --git a/ruy/kernel_avx2.cc b/ruy/kernel_avx2.cc
index 5d34e6b..60c6e6b 100644
--- a/ruy/kernel_avx2.cc
+++ b/ruy/kernel_avx2.cc
@@ -18,6 +18,7 @@
 #include <cstring>
 
 #include "ruy/check_macros.h"
+#include "ruy/kernel_common.h"
 #include "ruy/kernel_x86.h"
 #include "ruy/opt_set.h"
 #include "ruy/platform.h"
@@ -1467,10 +1468,12 @@
       params.dst_base_ptr - params.start_col * dst_stride - params.start_row;
   const float* adj_lhs_col_ptr =
       params.lhs_base_ptr - params.start_row * lhs_stride;
-  const float* bias_col_ptr = params.bias;
+  const float* bias_ptr = params.bias;
 
   const __m256 clamp_max_v = _mm256_set1_ps(params.clamp_max);
   const __m256 clamp_min_v = _mm256_set1_ps(params.clamp_min);
+  const bool channel_dimension_is_col =
+      params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL;
 
   int col = params.start_col;
   // Loop through cols by float block size, leaving incomplete remainder
@@ -1485,14 +1488,21 @@
 
       const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
       float* dst_ptr = dst_col_ptr + row;
-      const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
 
       // Initialize with bias.
-      const __m256 initial_accum_data =
-          intrin_utils::mm256_n_loadu_ps(residual_rows, bias_ptr);
+      if (channel_dimension_is_col) {
+        const float* bias_elem_ptr = bias_ptr + col * bias_ptr_block_increment;
+        for (int j = 0; j < 8; ++j) {
+          accum_data_v[j] = _mm256_broadcast_ss(bias_elem_ptr + j);
+        }
+      } else {
+        const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment;
+        const __m256 initial_accum_data =
+            intrin_utils::mm256_n_loadu_ps(residual_rows, bias_elem_ptr);
 
-      for (int j = 0; j < 8; ++j) {
-        accum_data_v[j] = initial_accum_data;
+        for (int j = 0; j < 8; ++j) {
+          accum_data_v[j] = initial_accum_data;
+        }
       }
 
       const float* lhs_ptr = lhs_col_ptr;
@@ -1549,14 +1559,21 @@
 
       const float* lhs_col_ptr = adj_lhs_col_ptr + row * lhs_stride;
       float* dst_ptr = dst_col_ptr + row;
-      const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
 
       // Initialize with bias.
-      const __m256 initial_accum_data =
-          intrin_utils::mm256_n_loadu_ps(residual_rows, bias_ptr);
+      if (channel_dimension_is_col) {
+        const float* bias_elem_ptr = bias_ptr + col * bias_ptr_block_increment;
+        for (int j = 0; j < 8; ++j) {
+          accum_data_v[j] = _mm256_broadcast_ss(bias_elem_ptr + j);
+        }
+      } else {
+        const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment;
+        const __m256 initial_accum_data =
+            intrin_utils::mm256_n_loadu_ps(residual_rows, bias_elem_ptr);
 
-      for (int j = 0; j < 8; ++j) {
-        accum_data_v[j] = initial_accum_data;
+        for (int j = 0; j < 8; ++j) {
+          accum_data_v[j] = initial_accum_data;
+        }
       }
 
       const float* lhs_ptr = lhs_col_ptr;
diff --git a/ruy/kernel_x86.h b/ruy/kernel_x86.h
index 84e9e85..2def91f 100644
--- a/ruy/kernel_x86.h
+++ b/ruy/kernel_x86.h
@@ -49,7 +49,8 @@
     KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
     MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
                          end_col, dst, &params);
-    if (dst->layout.cols == 1) {
+    if (dst->layout.cols == 1 &&
+        mul_params.channel_dimension() == ChannelDimension::kRow) {
       Kernel8bitAvx512SingleCol(params);
     } else {
       Kernel8bitAvx512(params);
@@ -73,7 +74,8 @@
     KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
     MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
                           end_col, dst, &params);
-    if (dst->layout.cols == 1) {
+    if (dst->layout.cols == 1 &&
+        mul_params.channel_dimension() == ChannelDimension::kRow) {
       KernelFloatAvx512SingleCol(params);
     } else {
       KernelFloatAvx512(params);
@@ -97,7 +99,8 @@
     KernelParams8bit<LhsLayout::kCols, RhsLayout::kCols> params;
     MakeKernelParams8bit(lhs, rhs, mul_params, start_row, start_col, end_row,
                          end_col, dst, &params);
-    if (dst->layout.cols == 1) {
+    if (dst->layout.cols == 1 &&
+        mul_params.channel_dimension() == ChannelDimension::kRow) {
       Kernel8bitAvx2SingleCol(params);
     } else {
       Kernel8bitAvx2(params);
@@ -121,7 +124,8 @@
     KernelParamsFloat<LhsLayout::kCols, RhsLayout::kCols> params;
     MakeKernelParamsFloat(lhs, rhs, mul_params, start_row, start_col, end_row,
                           end_col, dst, &params);
-    if (dst->layout.cols == 1) {
+    if (dst->layout.cols == 1 &&
+        mul_params.channel_dimension() == ChannelDimension::kRow) {
       KernelFloatAvx2SingleCol(params);
     } else {
       KernelFloatAvx2(params);