Let cpu_backend_gemm support all storage order combinations, unconditionally using ruy as the backend for combinations other than RowMajor*ColMajor->ColMajor, which were previously unsupported. Ruy differs from the other backends in that it supports all combinations as runtime parameters without a code-size increase.

PiperOrigin-RevId: 323013778
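
As a rough illustration of what "all combinations as runtime parameters" means on the ruy side (a minimal sketch, not code from this change; the function name, shapes, and the particular order choices are made up for the example), the caller describes each matrix's storage order through its layout, and ruy dispatches on those orders at runtime rather than through compile-time template instantiations:

  #include "ruy/ruy.h"

  // Hypothetical helper: multiply a row-major LHS by a column-major RHS into
  // a row-major destination. Each ruy::Order below is a plain runtime value,
  // so any of the eight order combinations is expressed the same way.
  void GemmAnyOrder(const float* lhs_data, const float* rhs_data,
                    float* dst_data, int m, int k, int n) {
    ruy::Context context;

    ruy::Matrix<float> lhs;
    ruy::MakeSimpleLayout(m, k, ruy::Order::kRowMajor, lhs.mutable_layout());
    lhs.set_data(lhs_data);

    ruy::Matrix<float> rhs;
    ruy::MakeSimpleLayout(k, n, ruy::Order::kColMajor, rhs.mutable_layout());
    rhs.set_data(rhs_data);

    ruy::Matrix<float> dst;
    ruy::MakeSimpleLayout(m, n, ruy::Order::kRowMajor, dst.mutable_layout());
    dst.set_data(dst_data);

    // Accumulate in float, produce float; no quantization parameters needed.
    ruy::MulParams<float, float> mul_params;
    ruy::Mul(lhs, rhs, mul_params, &context, &dst);
  }

The diff below is the ruy-side cleanup that accompanies this: it drops helpers that are no longer needed and folds the per-column multiplier handling into the x86 kernels directly instead of transposing around the per-row path.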
diff --git a/ruy/kernel_avx2_fma.cc b/ruy/kernel_avx2_fma.cc
index 463bb4b..4ba73f5 100644
--- a/ruy/kernel_avx2_fma.cc
+++ b/ruy/kernel_avx2_fma.cc
@@ -319,39 +319,6 @@
   return float_val;
 }
 
-inline __m256 mm256_n_loadu_ps(int i, const float* src) {
-  switch (i) {
-    case 0:
-      return _mm256_setzero_ps();
-    case 1:
-      return _mm256_setr_m128(_mm_setr_ps(src[0], .0f, .0f, .0f),
-                              _mm_setzero_ps());
-    case 2:
-      return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], .0f, .0f),
-                              _mm_setzero_ps());
-    case 3:
-      return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], src[2], .0f),
-                              _mm_setzero_ps());
-    case 4:
-      return _mm256_setr_m128(_mm_setr_ps(src[0], src[1], src[2], src[3]),
-                              _mm_setzero_ps());
-    case 5:
-      return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], .0f, .0f,
-                            .0f);
-    case 6:
-      return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], src[5], .0f,
-                            .0f);
-    case 7:
-      return _mm256_setr_ps(src[0], src[1], src[2], src[3], src[4], src[5],
-                            src[6], .0f);
-    case 8:
-      return _mm256_loadu_ps(src);
-    default:
-      RUY_DCHECK_LT(i, 9);
-      return _mm256_setzero_ps();
-  }
-}
-
 inline void mm256_n_storeu_ps(float* dst, int residual_rows, const __m256 v) {
   for (int i = 0; i < residual_rows; ++i) {
     dst[i] = intrin_utils::mm256_get1_ps(v, i);
@@ -589,126 +556,26 @@
         // Take bytes 2, 3, 6, 7, 10, 11, ... expanded to 16-bit.
         const __m256i lhs_16_bit_high = _mm256_permute2x128_si256(
             lhs_data_split_expand_bottom, lhs_data_split_expand_top, 0x31);
-        // Accumulate for column 0.
-        {
-          const std::int32_t low_rhs_value = rhs_data[0];
-          const std::int32_t high_rhs_value = rhs_data[1];
+        auto process_column = [=](int col, __m256i& accum) {
+          const std::int32_t low_rhs_value = rhs_data[col * 2];
+          const std::int32_t high_rhs_value = rhs_data[col * 2 + 1];
 
           const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
           const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
 
-          accum_data_v0 = _mm256_add_epi32(
-              accum_data_v0,
-              _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
-          accum_data_v0 = _mm256_add_epi32(
-              accum_data_v0,
-              _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
-        }
-        // Accumulate for column 1.
-        {
-          const std::int32_t low_rhs_value = rhs_data[2];
-          const std::int32_t high_rhs_value = rhs_data[3];
-
-          const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
-          const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
-
-          accum_data_v1 = _mm256_add_epi32(
-              accum_data_v1,
-              _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
-          accum_data_v1 = _mm256_add_epi32(
-              accum_data_v1,
-              _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
-        }
-        // Accumulate for column 2.
-        {
-          const std::int32_t low_rhs_value = rhs_data[4];
-          const std::int32_t high_rhs_value = rhs_data[5];
-
-          const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
-          const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
-
-          accum_data_v2 = _mm256_add_epi32(
-              accum_data_v2,
-              _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
-          accum_data_v2 = _mm256_add_epi32(
-              accum_data_v2,
-              _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
-        }
-        // Accumulate for column 3.
-        {
-          const std::int32_t low_rhs_value = rhs_data[6];
-          const std::int32_t high_rhs_value = rhs_data[7];
-
-          const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
-          const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
-
-          accum_data_v3 = _mm256_add_epi32(
-              accum_data_v3,
-              _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
-          accum_data_v3 = _mm256_add_epi32(
-              accum_data_v3,
-              _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
-        }
-        // Accumulate for column 4.
-        {
-          const std::int32_t low_rhs_value = rhs_data[8];
-          const std::int32_t high_rhs_value = rhs_data[9];
-
-          const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
-          const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
-
-          accum_data_v4 = _mm256_add_epi32(
-              accum_data_v4,
-              _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
-          accum_data_v4 = _mm256_add_epi32(
-              accum_data_v4,
-              _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
-        }
-        // Accumulate for column 5.
-        {
-          const std::int32_t low_rhs_value = rhs_data[10];
-          const std::int32_t high_rhs_value = rhs_data[11];
-
-          const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
-          const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
-
-          accum_data_v5 = _mm256_add_epi32(
-              accum_data_v5,
-              _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
-          accum_data_v5 = _mm256_add_epi32(
-              accum_data_v5,
-              _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
-        }
-        // Accumulate for column 6.
-        {
-          const std::int32_t low_rhs_value = rhs_data[12];
-          const std::int32_t high_rhs_value = rhs_data[13];
-
-          const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
-          const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
-
-          accum_data_v6 = _mm256_add_epi32(
-              accum_data_v6,
-              _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
-          accum_data_v6 = _mm256_add_epi32(
-              accum_data_v6,
-              _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
-        }
-        // Accumulate for column 7.
-        {
-          const std::int32_t low_rhs_value = rhs_data[14];
-          const std::int32_t high_rhs_value = rhs_data[15];
-
-          const __m256i rhs_16_bit_dup_low = _mm256_set1_epi32(low_rhs_value);
-          const __m256i rhs_16_bit_dup_high = _mm256_set1_epi32(high_rhs_value);
-
-          accum_data_v7 = _mm256_add_epi32(
-              accum_data_v7,
-              _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
-          accum_data_v7 = _mm256_add_epi32(
-              accum_data_v7,
-              _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
-        }
+          accum = _mm256_add_epi32(
+              accum, _mm256_madd_epi16(lhs_16_bit_low, rhs_16_bit_dup_low));
+          accum = _mm256_add_epi32(
+              accum, _mm256_madd_epi16(lhs_16_bit_high, rhs_16_bit_dup_high));
+        };
+        process_column(0, accum_data_v0);
+        process_column(1, accum_data_v1);
+        process_column(2, accum_data_v2);
+        process_column(3, accum_data_v3);
+        process_column(4, accum_data_v4);
+        process_column(5, accum_data_v5);
+        process_column(6, accum_data_v6);
+        process_column(7, accum_data_v7);
 
         lhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
         rhs_ptr += kAvx8bitBlockSize * kAvx8bitInnerSize;
@@ -844,8 +711,8 @@
               &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
               &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7);
         }
-        {
-          __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v0, left_shift);
+        auto apply_multiplier = [=](__m256i& accum) {
+          __m256i shifted_accum = _mm256_sllv_epi32(accum, left_shift);
           // Apply the fixed-point part of the multiplier.
           __m256i scaled_v_low = _mm256_mul_epi32(
               _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
@@ -866,176 +733,16 @@
               _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
           results = _mm256_permutevar8x32_epi32(results, repack_perm);
 
-          accum_data_v0 = _mm256_sub_epi32(results, post_scaling_offset);
-        }
-        {
-          __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v1, left_shift);
-          // Apply the fixed-point part of the multiplier.
-          __m256i scaled_v_low = _mm256_mul_epi32(
-              _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
-              m_64bit_low);
-          __m256i scaled_v_high = _mm256_mul_epi32(
-              _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
-              m_64bit_high);
-
-          scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
-          scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
-
-          scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
-          scaled_v_high =
-              _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
-
-          scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
-          __m256i results =
-              _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
-          results = _mm256_permutevar8x32_epi32(results, repack_perm);
-
-          accum_data_v1 = _mm256_sub_epi32(results, post_scaling_offset);
-        }
-        {
-          __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v2, left_shift);
-          // Apply the fixed-point part of the multiplier.
-          __m256i scaled_v_low = _mm256_mul_epi32(
-              _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
-              m_64bit_low);
-          __m256i scaled_v_high = _mm256_mul_epi32(
-              _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
-              m_64bit_high);
-
-          scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
-          scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
-
-          scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
-          scaled_v_high =
-              _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
-
-          scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
-          __m256i results =
-              _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
-          results = _mm256_permutevar8x32_epi32(results, repack_perm);
-
-          accum_data_v2 = _mm256_sub_epi32(results, post_scaling_offset);
-        }
-        {
-          __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v3, left_shift);
-          // Apply the fixed-point part of the multiplier.
-          __m256i scaled_v_low = _mm256_mul_epi32(
-              _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
-              m_64bit_low);
-          __m256i scaled_v_high = _mm256_mul_epi32(
-              _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
-              m_64bit_high);
-
-          scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
-          scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
-
-          scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
-          scaled_v_high =
-              _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
-
-          scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
-          __m256i results =
-              _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
-          results = _mm256_permutevar8x32_epi32(results, repack_perm);
-
-          accum_data_v3 = _mm256_sub_epi32(results, post_scaling_offset);
-        }
-        {
-          __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v4, left_shift);
-          // Apply the fixed-point part of the multiplier.
-          __m256i scaled_v_low = _mm256_mul_epi32(
-              _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
-              m_64bit_low);
-          __m256i scaled_v_high = _mm256_mul_epi32(
-              _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
-              m_64bit_high);
-
-          scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
-          scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
-
-          scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
-          scaled_v_high =
-              _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
-
-          scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
-          __m256i results =
-              _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
-          results = _mm256_permutevar8x32_epi32(results, repack_perm);
-
-          accum_data_v4 = _mm256_sub_epi32(results, post_scaling_offset);
-        }
-        {
-          __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v5, left_shift);
-          // Apply the fixed-point part of the multiplier.
-          __m256i scaled_v_low = _mm256_mul_epi32(
-              _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
-              m_64bit_low);
-          __m256i scaled_v_high = _mm256_mul_epi32(
-              _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
-              m_64bit_high);
-
-          scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
-          scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
-
-          scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
-          scaled_v_high =
-              _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
-
-          scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
-          __m256i results =
-              _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
-          results = _mm256_permutevar8x32_epi32(results, repack_perm);
-
-          accum_data_v5 = _mm256_sub_epi32(results, post_scaling_offset);
-        }
-        {
-          __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v6, left_shift);
-          // Apply the fixed-point part of the multiplier.
-          __m256i scaled_v_low = _mm256_mul_epi32(
-              _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
-              m_64bit_low);
-          __m256i scaled_v_high = _mm256_mul_epi32(
-              _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
-              m_64bit_high);
-
-          scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
-          scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
-
-          scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
-          scaled_v_high =
-              _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
-
-          scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
-          __m256i results =
-              _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
-          results = _mm256_permutevar8x32_epi32(results, repack_perm);
-
-          accum_data_v6 = _mm256_sub_epi32(results, post_scaling_offset);
-        }
-        {
-          __m256i shifted_accum = _mm256_sllv_epi32(accum_data_v7, left_shift);
-          // Apply the fixed-point part of the multiplier.
-          __m256i scaled_v_low = _mm256_mul_epi32(
-              _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 0)),
-              m_64bit_low);
-          __m256i scaled_v_high = _mm256_mul_epi32(
-              _mm256_cvtepi32_epi64(_mm256_extracti128_si256(shifted_accum, 1)),
-              m_64bit_high);
-
-          scaled_v_low = _mm256_add_epi64(scaled_v_low, offset_vector_low);
-          scaled_v_high = _mm256_add_epi64(scaled_v_high, offset_vector_high);
-
-          scaled_v_low = _mm256_srlv_epi64(scaled_v_low, final_right_shift_low);
-          scaled_v_high =
-              _mm256_srlv_epi64(scaled_v_high, final_right_shift_high);
-
-          scaled_v_high = _mm256_slli_epi64(scaled_v_high, 32);
-          __m256i results =
-              _mm256_blend_epi32(scaled_v_low, scaled_v_high, 0xaa);
-          results = _mm256_permutevar8x32_epi32(results, repack_perm);
-
-          accum_data_v7 = _mm256_sub_epi32(results, post_scaling_offset);
-        }
+          accum = _mm256_sub_epi32(results, post_scaling_offset);
+        };
+        apply_multiplier(accum_data_v0);
+        apply_multiplier(accum_data_v1);
+        apply_multiplier(accum_data_v2);
+        apply_multiplier(accum_data_v3);
+        apply_multiplier(accum_data_v4);
+        apply_multiplier(accum_data_v5);
+        apply_multiplier(accum_data_v6);
+        apply_multiplier(accum_data_v7);
         // See above comment: here we transpose again to undo the transposition
         // of the 8x8 block of accumulators used to implement the
         // channels-are-columns case.
@@ -1549,8 +1256,7 @@
         }
       } else {
         const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment;
-        const __m256 initial_accum_data =
-            intrin_utils::mm256_n_loadu_ps(residual_rows, bias_elem_ptr);
+        const __m256 initial_accum_data = _mm256_loadu_ps(bias_elem_ptr);
 
         for (int j = 0; j < 8; ++j) {
           accum_data_v[j] = initial_accum_data;
@@ -1620,8 +1326,7 @@
         }
       } else {
         const float* bias_elem_ptr = bias_ptr + row * bias_ptr_block_increment;
-        const __m256 initial_accum_data =
-            intrin_utils::mm256_n_loadu_ps(residual_rows, bias_elem_ptr);
+        const __m256 initial_accum_data = _mm256_loadu_ps(bias_elem_ptr);
 
         for (int j = 0; j < 8; ++j) {
           accum_data_v[j] = initial_accum_data;
@@ -1739,7 +1444,7 @@
     const float* bias_ptr = bias_col_ptr + row * bias_ptr_block_increment;
 
     // Initialize with bias.
-    accum_data_v = intrin_utils::mm256_n_loadu_ps(residual_rows, bias_ptr);
+    accum_data_v = _mm256_loadu_ps(bias_ptr);
 
     const float* lhs_ptr = lhs_col_ptr;
     const float* rhs_ptr = rhs_col_ptr;
diff --git a/ruy/kernel_avx512.cc b/ruy/kernel_avx512.cc
index 3d36516..d2d1f4c 100644
--- a/ruy/kernel_avx512.cc
+++ b/ruy/kernel_avx512.cc
@@ -52,88 +52,6 @@
 
 #else  // RUY_PLATFORM_AVX512 && RUY_OPT(ASM)
 
-namespace {
-namespace intrin_utils {
-
-// Transpose a 8x8 matrix of int32's.
-void mm512_transpose16x16_epi32(__m512i* v0, __m512i* v1, __m512i* v2,
-                                __m512i* v3, __m512i* v4, __m512i* v5,
-                                __m512i* v6, __m512i* v7, __m512i* v8,
-                                __m512i* v9, __m512i* va, __m512i* vb,
-                                __m512i* vc, __m512i* vd, __m512i* ve,
-                                __m512i* vf) {
-  __m512i t2x2_0 = _mm512_unpacklo_epi32(*v0, *v1);
-  __m512i t2x2_1 = _mm512_unpackhi_epi32(*v0, *v1);
-  __m512i t2x2_2 = _mm512_unpacklo_epi32(*v2, *v3);
-  __m512i t2x2_3 = _mm512_unpackhi_epi32(*v2, *v3);
-  __m512i t2x2_4 = _mm512_unpacklo_epi32(*v4, *v5);
-  __m512i t2x2_5 = _mm512_unpackhi_epi32(*v4, *v5);
-  __m512i t2x2_6 = _mm512_unpacklo_epi32(*v6, *v7);
-  __m512i t2x2_7 = _mm512_unpackhi_epi32(*v6, *v7);
-  __m512i t2x2_8 = _mm512_unpacklo_epi32(*v8, *v9);
-  __m512i t2x2_9 = _mm512_unpackhi_epi32(*v8, *v9);
-  __m512i t2x2_a = _mm512_unpacklo_epi32(*va, *vb);
-  __m512i t2x2_b = _mm512_unpackhi_epi32(*va, *vb);
-  __m512i t2x2_c = _mm512_unpacklo_epi32(*vc, *vd);
-  __m512i t2x2_d = _mm512_unpackhi_epi32(*vc, *vd);
-  __m512i t2x2_e = _mm512_unpacklo_epi32(*ve, *vf);
-  __m512i t2x2_f = _mm512_unpackhi_epi32(*ve, *vf);
-
-  __m512i t4x4_0 = _mm512_unpacklo_epi64(t2x2_0, t2x2_2);
-  __m512i t4x4_1 = _mm512_unpackhi_epi64(t2x2_0, t2x2_2);
-  __m512i t4x4_2 = _mm512_unpacklo_epi64(t2x2_1, t2x2_3);
-  __m512i t4x4_3 = _mm512_unpackhi_epi64(t2x2_1, t2x2_3);
-  __m512i t4x4_4 = _mm512_unpacklo_epi64(t2x2_4, t2x2_6);
-  __m512i t4x4_5 = _mm512_unpackhi_epi64(t2x2_4, t2x2_6);
-  __m512i t4x4_6 = _mm512_unpacklo_epi64(t2x2_5, t2x2_7);
-  __m512i t4x4_7 = _mm512_unpackhi_epi64(t2x2_5, t2x2_7);
-  __m512i t4x4_8 = _mm512_unpacklo_epi64(t2x2_8, t2x2_a);
-  __m512i t4x4_9 = _mm512_unpackhi_epi64(t2x2_8, t2x2_a);
-  __m512i t4x4_a = _mm512_unpacklo_epi64(t2x2_9, t2x2_b);
-  __m512i t4x4_b = _mm512_unpackhi_epi64(t2x2_9, t2x2_b);
-  __m512i t4x4_c = _mm512_unpacklo_epi64(t2x2_c, t2x2_e);
-  __m512i t4x4_d = _mm512_unpackhi_epi64(t2x2_c, t2x2_e);
-  __m512i t4x4_e = _mm512_unpacklo_epi64(t2x2_d, t2x2_f);
-  __m512i t4x4_f = _mm512_unpackhi_epi64(t2x2_d, t2x2_f);
-
-  __m512i t8x8_0 = _mm512_shuffle_i32x4(t4x4_0, t4x4_4, 0x88);
-  __m512i t8x8_1 = _mm512_shuffle_i32x4(t4x4_1, t4x4_5, 0x88);
-  __m512i t8x8_2 = _mm512_shuffle_i32x4(t4x4_2, t4x4_6, 0x88);
-  __m512i t8x8_3 = _mm512_shuffle_i32x4(t4x4_3, t4x4_7, 0x88);
-  __m512i t8x8_4 = _mm512_shuffle_i32x4(t4x4_0, t4x4_4, 0xdd);
-  __m512i t8x8_5 = _mm512_shuffle_i32x4(t4x4_1, t4x4_5, 0xdd);
-  __m512i t8x8_6 = _mm512_shuffle_i32x4(t4x4_2, t4x4_6, 0xdd);
-  __m512i t8x8_7 = _mm512_shuffle_i32x4(t4x4_3, t4x4_7, 0xdd);
-  __m512i t8x8_8 = _mm512_shuffle_i32x4(t4x4_8, t4x4_c, 0x88);
-  __m512i t8x8_9 = _mm512_shuffle_i32x4(t4x4_9, t4x4_d, 0x88);
-  __m512i t8x8_a = _mm512_shuffle_i32x4(t4x4_a, t4x4_e, 0x88);
-  __m512i t8x8_b = _mm512_shuffle_i32x4(t4x4_b, t4x4_f, 0x88);
-  __m512i t8x8_c = _mm512_shuffle_i32x4(t4x4_8, t4x4_c, 0xdd);
-  __m512i t8x8_d = _mm512_shuffle_i32x4(t4x4_9, t4x4_d, 0xdd);
-  __m512i t8x8_e = _mm512_shuffle_i32x4(t4x4_a, t4x4_e, 0xdd);
-  __m512i t8x8_f = _mm512_shuffle_i32x4(t4x4_b, t4x4_f, 0xdd);
-
-  *v0 = _mm512_shuffle_i32x4(t8x8_0, t8x8_8, 0x88);
-  *v1 = _mm512_shuffle_i32x4(t8x8_1, t8x8_9, 0x88);
-  *v2 = _mm512_shuffle_i32x4(t8x8_2, t8x8_a, 0x88);
-  *v3 = _mm512_shuffle_i32x4(t8x8_3, t8x8_b, 0x88);
-  *v4 = _mm512_shuffle_i32x4(t8x8_4, t8x8_c, 0x88);
-  *v5 = _mm512_shuffle_i32x4(t8x8_5, t8x8_d, 0x88);
-  *v6 = _mm512_shuffle_i32x4(t8x8_6, t8x8_e, 0x88);
-  *v7 = _mm512_shuffle_i32x4(t8x8_7, t8x8_f, 0x88);
-  *v8 = _mm512_shuffle_i32x4(t8x8_0, t8x8_8, 0xdd);
-  *v9 = _mm512_shuffle_i32x4(t8x8_1, t8x8_9, 0xdd);
-  *va = _mm512_shuffle_i32x4(t8x8_2, t8x8_a, 0xdd);
-  *vb = _mm512_shuffle_i32x4(t8x8_3, t8x8_b, 0xdd);
-  *vc = _mm512_shuffle_i32x4(t8x8_4, t8x8_c, 0xdd);
-  *vd = _mm512_shuffle_i32x4(t8x8_5, t8x8_d, 0xdd);
-  *ve = _mm512_shuffle_i32x4(t8x8_6, t8x8_e, 0xdd);
-  *vf = _mm512_shuffle_i32x4(t8x8_7, t8x8_f, 0xdd);
-}
-
-}  // namespace intrin_utils
-}  // namespace
-
 void Kernel8bitAvx512(const KernelParams8bit<16, 16>& params) {
   profiler::ScopeLabel label("Kernel kAvx512 8-bit");
 
@@ -391,6 +309,13 @@
       }
 
       if (params.dst_type_id != DstTypeId<std::int32_t>::kValue) {
+        // The non-per-channel case could equivalently be handled in the per-row
+        // or per-column code path. The per-row code path is slightly more
+        // efficient so we handle it there.
+        const bool per_column_multiplier =
+            (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) &&
+            (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL);
+
         __m512i m_vector;
         __m512i e_vector;
         // Does not make use of RUY_ASM_FLAG_NEEDS_LEFT_SHIFT.
@@ -426,72 +351,96 @@
             offset_vector,
             _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(right_shift, 1)));
 
-        // This multiplier code is complex and expensive enough on x86, that
-        // we prefer to implement the channels-are-columns case by transposing
-        // around it, rather than duplicate it (which would also require
-        // duplicating the above code computing the multiplier constants).
-        // This is one instance where channels-are-columns has lower performance
-        // than channels-are-rows.
-        const bool transpose_around_multiplier =
-            (params.flags & RUY_ASM_FLAG_HAS_PERCHANNEL) &&
-            (params.flags & RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL);
-        if (transpose_around_multiplier) {
-          // Transpose the 16x16 accumulators block. Will be un-transposed below
-          // after the multplier implementation.
-          intrin_utils::mm512_transpose16x16_epi32(
-              &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
-              &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7,
-              &accum_data_v8, &accum_data_v9, &accum_data_va, &accum_data_vb,
-              &accum_data_vc, &accum_data_vd, &accum_data_ve, &accum_data_vf);
-        }
+        if (per_column_multiplier) {
+          auto apply_multiplier = [=](__m512i& accum, int col) {
+            __m512i perm_64bit_vals = _mm512_set1_epi64(col % 8);
+            // Apply the fixed-point part of the multiplier.
+            __m512i left_shift_val =
+                _mm512_permutexvar_epi32(_mm512_set1_epi32(col), left_shift);
+            __m512i m_64bit_val = _mm512_permutexvar_epi64(
+                perm_64bit_vals, col < 8 ? m_64bit_low : m_64bit_high);
+            __m512i offset_vector_val = _mm512_permutexvar_epi64(
+                perm_64bit_vals,
+                col < 8 ? offset_vector_low : offset_vector_high);
+            __m512i final_right_shift_val = _mm512_permutexvar_epi64(
+                perm_64bit_vals,
+                col < 8 ? final_right_shift_low : final_right_shift_high);
 
-        auto apply_multiplier = [=](__m512i& accum) {
-          accum = _mm512_sllv_epi32(accum, left_shift);
-          // Apply the fixed-point part of the multiplier.
-          __m512i scaled_v_low = _mm512_mul_epi32(
-              _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 0)),
-              m_64bit_low);
-          __m512i scaled_v_high = _mm512_mul_epi32(
-              _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 1)),
-              m_64bit_high);
+            accum = _mm512_sllv_epi32(accum, left_shift_val);
+            __m512i scaled_v_low = _mm512_mul_epi32(
+                _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 0)),
+                m_64bit_val);
+            __m512i scaled_v_high = _mm512_mul_epi32(
+                _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 1)),
+                m_64bit_val);
 
-          scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
-          scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
+            scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_val);
+            scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_val);
 
-          scaled_v_low = _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
-          scaled_v_high =
-              _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
+            scaled_v_low =
+                _mm512_srav_epi64(scaled_v_low, final_right_shift_val);
+            scaled_v_high =
+                _mm512_srav_epi64(scaled_v_high, final_right_shift_val);
 
-          accum = _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
-          accum = _mm512_inserti32x8(accum,
-                                     _mm512_cvtepi64_epi32(scaled_v_high), 1);
-        };
-        apply_multiplier(accum_data_v0);
-        apply_multiplier(accum_data_v1);
-        apply_multiplier(accum_data_v2);
-        apply_multiplier(accum_data_v3);
-        apply_multiplier(accum_data_v4);
-        apply_multiplier(accum_data_v5);
-        apply_multiplier(accum_data_v6);
-        apply_multiplier(accum_data_v7);
-        apply_multiplier(accum_data_v8);
-        apply_multiplier(accum_data_v9);
-        apply_multiplier(accum_data_va);
-        apply_multiplier(accum_data_vb);
-        apply_multiplier(accum_data_vc);
-        apply_multiplier(accum_data_vd);
-        apply_multiplier(accum_data_ve);
-        apply_multiplier(accum_data_vf);
+            accum = _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+            accum = _mm512_inserti32x8(accum,
+                                       _mm512_cvtepi64_epi32(scaled_v_high), 1);
+          };
+          apply_multiplier(accum_data_v0, 0);
+          apply_multiplier(accum_data_v1, 1);
+          apply_multiplier(accum_data_v2, 2);
+          apply_multiplier(accum_data_v3, 3);
+          apply_multiplier(accum_data_v4, 4);
+          apply_multiplier(accum_data_v5, 5);
+          apply_multiplier(accum_data_v6, 6);
+          apply_multiplier(accum_data_v7, 7);
+          apply_multiplier(accum_data_v8, 8);
+          apply_multiplier(accum_data_v9, 9);
+          apply_multiplier(accum_data_va, 10);
+          apply_multiplier(accum_data_vb, 11);
+          apply_multiplier(accum_data_vc, 12);
+          apply_multiplier(accum_data_vd, 13);
+          apply_multiplier(accum_data_ve, 14);
+          apply_multiplier(accum_data_vf, 15);
+        } else {  // not per-column, so per-row
+          auto apply_multiplier = [=](__m512i& accum) {
+            accum = _mm512_sllv_epi32(accum, left_shift);
+            // Apply the fixed-point part of the multiplier.
+            __m512i scaled_v_low = _mm512_mul_epi32(
+                _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 0)),
+                m_64bit_low);
+            __m512i scaled_v_high = _mm512_mul_epi32(
+                _mm512_cvtepi32_epi64(_mm512_extracti32x8_epi32(accum, 1)),
+                m_64bit_high);
 
-        if (transpose_around_multiplier) {
-          // See above comment: here we transpose again to undo the
-          // transposition of the 16x16 block of accumulators used to implement
-          // the channels-are-columns case.
-          intrin_utils::mm512_transpose16x16_epi32(
-              &accum_data_v0, &accum_data_v1, &accum_data_v2, &accum_data_v3,
-              &accum_data_v4, &accum_data_v5, &accum_data_v6, &accum_data_v7,
-              &accum_data_v8, &accum_data_v9, &accum_data_va, &accum_data_vb,
-              &accum_data_vc, &accum_data_vd, &accum_data_ve, &accum_data_vf);
+            scaled_v_low = _mm512_add_epi64(scaled_v_low, offset_vector_low);
+            scaled_v_high = _mm512_add_epi64(scaled_v_high, offset_vector_high);
+
+            scaled_v_low =
+                _mm512_srav_epi64(scaled_v_low, final_right_shift_low);
+            scaled_v_high =
+                _mm512_srav_epi64(scaled_v_high, final_right_shift_high);
+
+            accum = _mm512_castsi256_si512(_mm512_cvtepi64_epi32(scaled_v_low));
+            accum = _mm512_inserti32x8(accum,
+                                       _mm512_cvtepi64_epi32(scaled_v_high), 1);
+          };
+          apply_multiplier(accum_data_v0);
+          apply_multiplier(accum_data_v1);
+          apply_multiplier(accum_data_v2);
+          apply_multiplier(accum_data_v3);
+          apply_multiplier(accum_data_v4);
+          apply_multiplier(accum_data_v5);
+          apply_multiplier(accum_data_v6);
+          apply_multiplier(accum_data_v7);
+          apply_multiplier(accum_data_v8);
+          apply_multiplier(accum_data_v9);
+          apply_multiplier(accum_data_va);
+          apply_multiplier(accum_data_vb);
+          apply_multiplier(accum_data_vc);
+          apply_multiplier(accum_data_vd);
+          apply_multiplier(accum_data_ve);
+          apply_multiplier(accum_data_vf);
         }
 
         if (params.dst_zero_point != 0) {