| //===- Spmdization.h - Mesh Spmdization -------------------------*- C++ -*-===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #ifndef MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H |
| #define MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H |
| |
| #include "mlir/Dialect/Mesh/IR/MeshOps.h" |
| #include "mlir/IR/DialectRegistry.h" |
| #include "mlir/Support/LogicalResult.h" |
| |
| namespace mlir { |
| namespace mesh { |
| |
| // Emits the resharding (spmdization) of `sourceShardValue`, converting it from |
| // the sharding described by `source` to the one described by `target`. |
| // `sourceShardValue` must already be the shard that corresponds to `source`. |
| // Returns a `TypedValue<ShapedType>`; presumably the value resharded according |
| // to `target` (confirm against the definition). |
| // |
| // For instance, given |
| // |
| // ```mlir |
| // mesh.mesh @mesh_1d(shape = 2) |
| // ... |
| // %1 = mesh.shard %0 to <@mesh_1d, [[0]]> : tensor<2xi8> |
| // %2 = mesh.shard %1 to <@mesh_1d, [[]]> annotate_for_users: tensor<2xi8> |
| // ``` |
| // |
| // the resharding produces |
| // |
| // ```mlir |
| // %1 = mesh.all_gather %0 on @mesh_1d mesh_axes = [0] gather_axis = 0 : |
| // tensor<1xi8> -> tensor<2xi8> |
| // ``` |
| // |
| // Overload that receives the mesh explicitly as `mesh`. |
| TypedValue<ShapedType> |
| reshard(OpBuilder &builder, MeshOp mesh, ShardOp source, ShardOp target, |
|         TypedValue<ShapedType> sourceShardValue); |
| // Overload that takes a `SymbolTableCollection` in place of the mesh |
| // (presumably used to look the mesh up by symbol; confirm against the |
| // definition). |
| TypedValue<ShapedType> |
| reshard(OpBuilder &builder, ShardOp source, ShardOp target, |
|         TypedValue<ShapedType> sourceShardValue, |
|         SymbolTableCollection &symbolTableCollection); |
| |
| // Registers in `registry` the dialects that resharding depends on (inferred |
| // from the name; the definition lives outside this header). |
| void reshardingRegisterDependentDialects(DialectRegistry &registry); |
| |
| } // namespace mesh |
| } // namespace mlir |
| |
| #endif // MLIR_DIALECT_MESH_TRANSFORMS_SPMDIZATION_H |