| /* Copyright 2020 Google LLC. All Rights Reserved. |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| ==============================================================================*/ |
| |
| // Implementation of CreateTrMulParams, see function comment. |
| |
| #ifndef RUY_RUY_CREATE_TRMUL_PARAMS_H_ |
| #define RUY_RUY_CREATE_TRMUL_PARAMS_H_ |
| |
#include <cstdint>
#include <cstring>
#include <type_traits>

#include "ruy/ctx.h"
#include "ruy/kernel.h"
#include "ruy/mat.h"
#include "ruy/mul_params.h"
#include "ruy/pack.h"
#include "ruy/path.h"
#include "ruy/trmul_params.h"
| |
| namespace ruy { |
| // While the only entry point to this file is CreateTrMulParams, its templatized |
| // nature requires putting more code in this header than we would like. This |
| // internal implementation code is enclosed in namespace 'detail'. |
| namespace detail { |
| |
| void CreatePackedLayout(const MatLayout& src, const Type& scalar, |
| const KernelLayout& kernel_layout, |
| PMatLayout* packed_layout); |
| |
| template <typename Scalar, typename PackedScalar> |
| void CreatePackedMatrix(Side side, const KernelLayout& kernel_layout, |
| TrMulParams* params) { |
| // Ruy always uses 32-bit signed accumulators for quantized |
| // matrix multiplication, so we would like to always use std::int32_t |
| // unconditionally for SumsType. |
| // However, for floating point types, we still need a reasonable type here to |
| // avoid tripping assertions elsewhere in the code. |
| using SumsType = |
| typename std::conditional<std::is_floating_point<Scalar>::value, Scalar, |
| std::int32_t>::type; |
| |
| const EMat& src = params->src[side]; |
| PEMat* packed_matrix = ¶ms->packed_matrix[side]; |
| packed_matrix->data_type = Type::Create<PackedScalar>(); |
| packed_matrix->sums_type = Type::Create<SumsType>(); |
| CreatePackedLayout(src.layout, packed_matrix->data_type, kernel_layout, |
| &packed_matrix->layout); |
| packed_matrix->zero_point = Pack<PackedScalar, Scalar>(src.zero_point); |
| } |
| |
| template <typename KernelType> |
| struct CheckKernelPathImpl { |
| static void Run(Path) { |
| // Do nothing. |
| // Path fallbacks are normal in general (see RUY_INHERIT_KERNEL). |
| // That is to say that one may instantiate ruy::Mul with a weird combination |
| // of types, such as LhsScalar==float and RhsScalar==double, and have it |
| // work by silently falling back to Path::kStandardCpp. Only in specific |
| // cases do we have dedicated kernels overriding that fallback, and that is |
| // what partial specializations of this template will check. |
| } |
| }; |
| |
#if RUY_DCHECK_IS_ENABLED
// Partial specialization, compiled only when RUY_DCHECK is enabled, for
// kernels whose LHS and RHS scalar types are the same type SrcScalar.
//
// We want to assert that we are using a dedicated Kernel specialization and
// not a fallback when we know we are in a case where such a kernel
// specialization exists. At the moment, in the current state of ruy's
// architecture support for ARM and x86, that is when LhsScalar==RhsScalar
// (already implied by this partial specialization) and when that type is
// either float, int8, or uint8. Indeed, we have kernels supporting float and
// int8, and we have the packing code converting uint8 to int8 (see
// PackedTypeImpl).
//
// NOTE(review): this specialization matches
//   Kernel<ThePath, SrcScalar, SrcScalar, DstScalar,
//          MulParams<AccumScalar, DstScalar>>
// whereas PopulateTrMulParams in this file instantiates
//   Kernel<ThePath, PackedLhsScalar, PackedRhsScalar, AccumScalar, DstScalar>.
// If Kernel's trailing parameters are (AccumScalar, DstScalar), this
// specialization is never selected and the check below is dead code - confirm
// against the declaration of Kernel.
template <Path ThePath, typename SrcScalar, typename AccumScalar,
          typename DstScalar>
struct CheckKernelPathImpl<Kernel<ThePath, SrcScalar, SrcScalar, DstScalar,
                                  MulParams<AccumScalar, DstScalar>>>
    final {
  using KernelType = Kernel<ThePath, SrcScalar, SrcScalar, DstScalar,
                            MulParams<AccumScalar, DstScalar>>;

  // Asserts (RUY_DCHECK_EQ) that KernelType::kPath equals `expected_path`
  // when SrcScalar is float, std::int8_t or std::uint8_t; otherwise does
  // nothing.
  static void Run(Path expected_path) {
    constexpr bool kHasDedicatedKernel =
        std::is_same<SrcScalar, float>::value ||
        std::is_same<SrcScalar, std::int8_t>::value ||
        std::is_same<SrcScalar, std::uint8_t>::value;
    if (!kHasDedicatedKernel) {
      return;
    }
    RUY_DCHECK_EQ(expected_path, KernelType::kPath);
  }
};
#endif  // RUY_DCHECK_IS_ENABLED
| |
| template <typename KernelType> |
| void CheckKernelPath(Path expected_path) { |
| CheckKernelPathImpl<KernelType>::Run(expected_path); |
| } |
| |
| template <Path ThePath, typename LhsScalar, typename RhsScalar, |
| typename AccumScalar, typename DstScalar> |
| void PopulateTrMulParams(TrMulParams* params) { |
| using PackedLhsScalar = PackedType<ThePath, LhsScalar>; |
| using PackedRhsScalar = PackedType<ThePath, RhsScalar>; |
| using Kernel = |
| Kernel<ThePath, PackedLhsScalar, PackedRhsScalar, AccumScalar, DstScalar>; |
| using LhsKernelLayout = typename Kernel::LhsLayout; |
| using RhsKernelLayout = typename Kernel::RhsLayout; |
| |
| params->path = ThePath; |
| |
| CreatePackedMatrix<LhsScalar, PackedLhsScalar>( |
| Side::kLhs, ToKernelLayout<LhsKernelLayout>(), params); |
| CreatePackedMatrix<RhsScalar, PackedRhsScalar>( |
| Side::kRhs, ToKernelLayout<RhsKernelLayout>(), params); |
| params->run_pack[Side::kLhs] = |
| &RunPack<ThePath, LhsKernelLayout, LhsScalar, PackedLhsScalar>; |
| params->run_pack[Side::kRhs] = |
| &RunPack<ThePath, RhsKernelLayout, RhsScalar, PackedRhsScalar>; |
| params->run_kernel = &RunKernel<Kernel>::Run; |
| CheckKernelPath<Kernel>(ThePath); |
| } |
| |
| // PopulateTrMulParamsAllCompiledPaths calls into one of multiple |
| // instantiations of PopulateTrMulParams. For each bit that is set in |
| // CompiledPaths, it statically instantiates PopulateTrMulParams with a Path |
| // corresponding to that single bit. The call to PopulateTrMulParams is |
| // guarded by a runtime check that it is in fact the dynamically selected path. |
| // |
| // PopulateTrMulParamsAllCompiledPaths is implemented with template |
| // metaprogramming by mutual recursion between PathSearchCountdown and |
| // PathSearchCompiledPaths. |
| // |
| // PopulateTrMulParamsAllCompiledPaths is logically implementing the following |
| // computation: |
| // |
| // template <Path CompiledPaths> |
| // void PopulateTrMulParamsAllCompiledPaths(Path the_path, |
| // TrMulParams* params) { |
| // for (int bit = 8 * sizeof(Path) - 1; bit != -1; bit--) { // [1] |
| // Path current_path = static_cast<Path>(1 << bit); |
| // if ((CompiledPaths & current_path) != Path::kNone) { // [2] |
| // if (current_path == the_path) { // [3] |
| // PopulateTrMulParams<current_path, ...>(the_path, params); |
| // return; |
| // } |
| // } |
| // } |
| // } |
| // |
| // |
| // |
| // [1] - Done by the main definition of PathSearchCountdown. The `bit--` is |
| // done in the recursion of PathSearchOnlyCompiledPaths. |
| // [2] - Done by PathSearchOnlyCompiledPaths's partial template |
| // specialization on InCompiledPaths. This is the check which necessitates |
| // doing the whole computation at C++ compile time. |
| // [3] - Done by the `if` in the main definition of |
| // PathSearchOnlyCompiledPaths. |
| // |
| // The template metaprogramming is necessary because: |
| // - In `PopulateTrMulParams<current_path, ...>`, current_path must be a C++ |
| // compile-time constant. |
| // - PopulateTrMulParamsAllCompiledPaths must not instantiate |
| // inner loops for paths that are not in CompiledPaths, since that can result in |
| // bogus instantiations which cause a compile time failure. |
| template <Path CompiledPaths, int BitNumber, typename LhsScalar, |
| typename RhsScalar, typename AccumScalar, typename DstScalar> |
| struct PathSearchCountdown; |
| |
| template <Path CompiledPaths, bool InCompiledPaths, int BitNumber, |
| typename LhsScalar, typename RhsScalar, typename AccumScalar, |
| typename DstScalar> |
| struct PathSearchOnlyCompiledPaths { |
| static constexpr Path kCurrentPath = static_cast<Path>(1 << BitNumber); |
| static void Search(Path the_path, TrMulParams* params) { |
| if (kCurrentPath == the_path) { |
| PopulateTrMulParams<kCurrentPath, LhsScalar, RhsScalar, AccumScalar, |
| DstScalar>(params); |
| return; |
| } |
| PathSearchCountdown<CompiledPaths, BitNumber - 1, LhsScalar, RhsScalar, |
| AccumScalar, DstScalar>::Search(the_path, params); |
| } |
| }; |
| |
| // Skip this iteration if CompiledPaths doesn't contain the specified path. |
| template <Path CompiledPaths, int BitNumber, typename LhsScalar, |
| typename RhsScalar, typename AccumScalar, typename DstScalar> |
| struct PathSearchOnlyCompiledPaths<CompiledPaths, false, BitNumber, LhsScalar, |
| RhsScalar, AccumScalar, DstScalar> { |
| static void Search(Path the_path, TrMulParams* params) { |
| PathSearchCountdown<CompiledPaths, BitNumber - 1, LhsScalar, RhsScalar, |
| AccumScalar, DstScalar>::Search(the_path, params); |
| } |
| }; |
| |
| template <Path CompiledPaths, int BitNumber, typename LhsScalar, |
| typename RhsScalar, typename AccumScalar, typename DstScalar> |
| struct PathSearchCountdown { |
| static constexpr Path kCurrentPath = static_cast<Path>(1 << BitNumber); |
| static void Search(Path the_path, TrMulParams* params) { |
| PathSearchOnlyCompiledPaths< |
| CompiledPaths, (CompiledPaths & kCurrentPath) != Path::kNone, BitNumber, |
| LhsScalar, RhsScalar, AccumScalar, DstScalar>::Search(the_path, params); |
| } |
| }; |
| |
| // Termination of the countdown. If the counter reaches -1, then we haven't |
| // found the specified path. |
| template <Path CompiledPaths, typename LhsScalar, typename RhsScalar, |
| typename AccumScalar, typename DstScalar> |
| struct PathSearchCountdown<CompiledPaths, -1, LhsScalar, RhsScalar, AccumScalar, |
| DstScalar> { |
| static void Search(Path, TrMulParams*) { RUY_DCHECK(false); } |
| }; |
| |
| template <Path CompiledPaths, typename LhsScalar, typename RhsScalar, |
| typename AccumScalar, typename DstScalar> |
| void PopulateTrMulParamsAllCompiledPaths(Path the_path, TrMulParams* params) { |
| return PathSearchCountdown<CompiledPaths, 8 * sizeof(Path) - 1, LhsScalar, |
| RhsScalar, AccumScalar, |
| DstScalar>::Search(the_path, params); |
| } |
| |
| bool FallBackToStandardCpp(const MatLayout& lhs_layout, |
| const MatLayout& rhs_layout, |
| ChannelDimension channel_dimension); |
| |
| template <typename AccumScalar, typename DstScalar> |
| void StoreMulParams(const MulParams<AccumScalar, DstScalar>& mul_params, |
| ChannelDimension channel_dimension, void* dst) { |
| using MulParamsType = MulParams<AccumScalar, DstScalar>; |
| static_assert(alignof(MulParamsType) <= kMaxMulParamsAlignment, ""); |
| static_assert(sizeof(MulParamsType) <= kMaxMulParamsSize, ""); |
| std::memcpy(dst, &mul_params, sizeof(MulParamsType)); |
| static_assert(sizeof(ChannelDimension) == 1, ""); |
| static_cast<MulParamsType*>(dst)->set_channel_dimension(channel_dimension); |
| } |
| |
| template <Path CompiledPaths, typename LhsScalar, typename RhsScalar, |
| typename AccumScalar, typename DstScalar> |
| void CreateTrMulParamsAssumingColMajorDst( |
| const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs, |
| const Mat<DstScalar>& dst, |
| const MulParams<AccumScalar, DstScalar>& mul_params, |
| ChannelDimension channel_dimension, Path the_path, TrMulParams* params) { |
| RUY_DCHECK(IsColMajor(dst.layout)); |
| |
| // Fill in the fields we already know. |
| params->src[Side::kLhs] = EraseType(lhs); |
| params->src[Side::kRhs] = EraseType(rhs); |
| params->dst = EraseType(dst); |
| StoreMulParams(mul_params, channel_dimension, params->mul_params_bytes); |
| |
| if (FallBackToStandardCpp(lhs.layout, rhs.layout, channel_dimension)) { |
| return PopulateTrMulParams<Path::kStandardCpp, LhsScalar, RhsScalar, |
| AccumScalar, DstScalar>(params); |
| } |
| |
| PopulateTrMulParamsAllCompiledPaths<CompiledPaths, LhsScalar, RhsScalar, |
| AccumScalar, DstScalar>(the_path, params); |
| } |
| |
| } // namespace detail |
| |
| inline ChannelDimension Transpose(ChannelDimension channel_dimension) { |
| return channel_dimension == ChannelDimension::kCol ? ChannelDimension::kRow |
| : ChannelDimension::kCol; |
| } |
| |
| // CreateTrMulParams's output is a TrMulParams object that encodes |
| // all of the input information required by the middle-end, that is, the TrMul |
| // function. |
| // |
| // CreateTrMulParams performs the following tasks: |
| // 1. Reduce to the case of column-major destination, by transposing the |
| // whole problem as needed. |
| // 2. Select the single code path to be taken, out of the set of paths |
| // described by the `CompiledPaths` template parameter, based on the |
| // runtime input parameter `the_path`. |
| // 3. Perform type-erasure, converting templatized typed input parameters |
| // to the un-typed data stored in TrMulParams. |
| template <Path CompiledPaths, typename LhsScalar, typename RhsScalar, |
| typename AccumScalar, typename DstScalar> |
| void CreateTrMulParams(const Mat<LhsScalar>& lhs, const Mat<RhsScalar>& rhs, |
| const Mat<DstScalar>& dst, |
| const MulParams<AccumScalar, DstScalar>& mul_params, |
| Path the_path, TrMulParams* params) { |
| ChannelDimension channel_dimension = mul_params.channel_dimension(); |
| if (IsColMajor(dst.layout)) { |
| detail::CreateTrMulParamsAssumingColMajorDst<CompiledPaths>( |
| lhs, rhs, dst, mul_params, channel_dimension, the_path, params); |
| } else { |
| detail::CreateTrMulParamsAssumingColMajorDst<CompiledPaths>( |
| rhs, lhs, Transpose(dst), mul_params, Transpose(channel_dimension), |
| the_path, params); |
| } |
| } |
| |
| } // namespace ruy |
| |
| #endif // RUY_RUY_CREATE_TRMUL_PARAMS_H_ |