//===- Spmdization.cpp ------------------------------------------*- C++ -*-===//
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| |
| #include "mlir/Dialect/Mesh/Transforms/Spmdization.h" |
| |
| #include "mlir/Dialect/Func/IR/FuncOps.h" |
| #include "mlir/Dialect/Mesh/IR/MeshDialect.h" |
| #include "mlir/Dialect/Mesh/IR/MeshOps.h" |
| #include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h" |
| #include "mlir/Dialect/Tensor/IR/Tensor.h" |
| #include "mlir/IR/Builders.h" |
| #include "mlir/IR/BuiltinAttributes.h" |
| #include "mlir/IR/BuiltinTypeInterfaces.h" |
| #include "mlir/IR/BuiltinTypes.h" |
| #include "mlir/IR/Diagnostics.h" |
| #include "mlir/IR/IRMapping.h" |
| #include "mlir/IR/ImplicitLocOpBuilder.h" |
| #include "mlir/IR/Location.h" |
| #include "mlir/IR/MLIRContext.h" |
| #include "mlir/IR/SymbolTable.h" |
| #include "mlir/IR/Value.h" |
| #include "mlir/Interfaces/ControlFlowInterfaces.h" |
| #include "mlir/Interfaces/FunctionInterfaces.h" |
| #include "mlir/Pass/Pass.h" |
| #include "mlir/Support/LLVM.h" |
| #include "mlir/Support/LogicalResult.h" |
| #include "llvm/ADT/APInt.h" |
| #include "llvm/ADT/DenseSet.h" |
| #include "llvm/ADT/STLExtras.h" |
| #include "llvm/ADT/SmallVector.h" |
| #include "llvm/Support/Casting.h" |
| #include <iterator> |
| #include <optional> |
| #include <tuple> |
| #include <type_traits> |
| |
| namespace mlir::mesh { |
| |
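// Returns true if the target sharding's partial axes are a subset of the
// source sharding's partial axes, i.e. resharding only ever needs to reduce
// partial axes away, never introduce new ones.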
| template <typename SourceAxes, typename TargetAxes> |
| static bool arePartialAxesCompatible(const SourceAxes &sourceAxes, |
| const TargetAxes &targetAxes) { |
| return llvm::all_of(targetAxes, [&sourceAxes](auto &targetAxis) { |
| return sourceAxes.contains(targetAxis); |
| }); |
| } |
| |
// Return the reduced value and its corresponding sharding.
// Example:
//   sourceSharding = <@mesh_1d, [[0]], partial = sum[0]>
//   targetSharding = <@mesh_1d, [[]]>
// Here an all-reduce over mesh axis 0 is applied to the source value, which
// is then returned with the sharding <@mesh_1d, [[0]]>. Any remaining
// split-axes mismatch is resolved by a later resharding step.
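// Illustrative IR for the example above (a sketch; the exact assembly format
// and the tensor shapes are assumptions):
//   %reduced = mesh.all_reduce %source_shard on @mesh_1d mesh_axes = [0]
//       : tensor<2xf64> -> tensor<2xf64>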
| static std::tuple<TypedValue<ShapedType>, MeshShardingAttr> |
| handlePartialAxesDuringResharding(OpBuilder &builder, |
| MeshShardingAttr sourceSharding, |
| MeshShardingAttr targetSharding, |
| TypedValue<ShapedType> sourceShard) { |
| if (sourceSharding.getPartialAxes().empty() && |
| targetSharding.getPartialAxes().empty()) { |
| return {sourceShard, sourceSharding}; |
| } |
| assert(targetSharding.getPartialAxes().empty() || |
| (!sourceSharding.getPartialAxes().empty() && |
| sourceSharding.getPartialType() == targetSharding.getPartialType())); |
| using Axis = std::decay_t<decltype(sourceSharding.getPartialAxes().front())>; |
| using AxisSet = llvm::SmallDenseSet<Axis>; |
| AxisSet sourceShardingPartialAxesSet(sourceSharding.getPartialAxes().begin(), |
| sourceSharding.getPartialAxes().end()); |
| AxisSet targetShardingPartialAxesSet(targetSharding.getPartialAxes().begin(), |
| targetSharding.getPartialAxes().end()); |
| assert(arePartialAxesCompatible(sourceShardingPartialAxesSet, |
| targetShardingPartialAxesSet)); |
| llvm::SmallVector<MeshAxis> allReduceMeshAxes; |
| llvm::copy_if(sourceShardingPartialAxesSet, |
| std::back_inserter(allReduceMeshAxes), |
| [&targetShardingPartialAxesSet](Axis a) { |
| return !targetShardingPartialAxesSet.contains(a); |
| }); |
| if (allReduceMeshAxes.empty()) { |
| return {sourceShard, sourceSharding}; |
| } |
| |
| builder.setInsertionPointAfterValue(sourceShard); |
| TypedValue<ShapedType> resultValue = cast<TypedValue<ShapedType>>( |
| builder |
| .create<AllReduceOp>(sourceShard.getLoc(), sourceShard.getType(), |
| sourceSharding.getMesh().getLeafReference(), |
| allReduceMeshAxes, sourceShard, |
| sourceSharding.getPartialType()) |
| .getResult()); |
| |
  // Of the source's partial axes, keep those that remain partial in the
  // target sharding.
  llvm::SmallVector<MeshAxis> remainingPartialAxes;
  llvm::copy_if(sourceShardingPartialAxesSet,
                std::back_inserter(remainingPartialAxes),
| [&targetShardingPartialAxesSet](Axis a) { |
| return targetShardingPartialAxesSet.contains(a); |
| }); |
| MeshShardingAttr resultSharding = |
| MeshShardingAttr::get(builder.getContext(), sourceSharding.getMesh(), |
| sourceSharding.getSplitAxes(), remainingPartialAxes, |
| sourceSharding.getPartialType()); |
| return {resultValue, resultSharding}; |
| } |
| |
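// Compute the sharding that results from appending `splitMeshAxis` to the
// split axes of tensor dimension `splitTensorAxis`, extending the split-axes
// list with empty entries as needed.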
| static MeshShardingAttr |
| targetShardingInSplitLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding, |
| int64_t splitTensorAxis, MeshAxis splitMeshAxis) { |
| SmallVector<MeshAxesAttr> targetShardingSplitAxes = |
| llvm::to_vector(sourceSharding.getSplitAxes()); |
| while (static_cast<int64_t>(targetShardingSplitAxes.size()) <= |
| splitTensorAxis) { |
| targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {})); |
| } |
| auto targetSplitAxes = |
| llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef()); |
| targetSplitAxes.push_back(splitMeshAxis); |
| targetShardingSplitAxes[splitTensorAxis] = |
| MeshAxesAttr::get(ctx, targetSplitAxes); |
| return MeshShardingAttr::get( |
| ctx, sourceSharding.getMesh(), targetShardingSplitAxes, |
| sourceSharding.getPartialAxes(), sourceSharding.getPartialType()); |
| } |
| |
// Split the tensor dimension `splitTensorAxis` along one additional mesh
// axis, e.g. [[0, 1]] -> [[0, 1, 2]].
// Returns the spmdized target value with its sharding.
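// Illustrative IR for a split along mesh axis 2 (a sketch; the exact assembly
// format and the tensor shapes are assumptions):
//   %target_shard = mesh.all_slice %source_shard on @mesh mesh_axes = [2]
//       slice_axis = 0 : tensor<4x8xf32> -> tensor<2x8xf32>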
| static std::tuple<TypedValue<ShapedType>, MeshShardingAttr> |
| splitLastAxisInResharding(ImplicitLocOpBuilder &builder, |
| MeshShardingAttr sourceSharding, |
| TypedValue<ShapedType> sourceShard, MeshOp mesh, |
| int64_t splitTensorAxis, MeshAxis splitMeshAxis) { |
| TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>( |
| builder |
| .create<AllSliceOp>(sourceShard, mesh, |
| ArrayRef<MeshAxis>(splitMeshAxis), |
| splitTensorAxis) |
| .getResult()); |
| MeshShardingAttr targetSharding = targetShardingInSplitLastAxis( |
| builder.getContext(), sourceSharding, splitTensorAxis, splitMeshAxis); |
| return {targetShard, targetSharding}; |
| } |
| |
| // Detect if the resharding is of type e.g. |
| // [[0, 1]] -> [[0, 1, 2]]. |
// If detected, returns the corresponding (tensor axis, mesh axis) pair.
| // Does not detect insertions like |
| // [[0, 1]] -> [[0, 2, 1]]. |
| static std::optional<std::tuple<int64_t, MeshAxis>> |
| detectSplitLastAxisInResharding(MeshShardingAttr sourceSharding, |
| MeshShardingAttr targetSharding) { |
| for (size_t tensorAxis = 0; tensorAxis < targetSharding.getSplitAxes().size(); |
| ++tensorAxis) { |
| if (sourceSharding.getSplitAxes().size() > tensorAxis) { |
| if (sourceSharding.getSplitAxes()[tensorAxis].size() + 1 != |
| targetSharding.getSplitAxes()[tensorAxis].size()) { |
| continue; |
| } |
| if (!llvm::equal( |
| sourceSharding.getSplitAxes()[tensorAxis].asArrayRef(), |
| llvm::make_range( |
| targetSharding.getSplitAxes()[tensorAxis] |
| .asArrayRef() |
| .begin(), |
| targetSharding.getSplitAxes()[tensorAxis].asArrayRef().end() - |
| 1))) { |
| continue; |
| } |
| } else { |
| if (targetSharding.getSplitAxes()[tensorAxis].size() != 1) { |
| continue; |
| } |
| } |
| return std::make_tuple( |
| tensorAxis, |
| targetSharding.getSplitAxes()[tensorAxis].asArrayRef().back()); |
| } |
| return std::nullopt; |
| } |
| |
| static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>> |
| trySplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, |
| MeshShardingAttr sourceSharding, |
| MeshShardingAttr targetSharding, |
| TypedValue<ShapedType> sourceShard) { |
| if (auto detectRes = |
| detectSplitLastAxisInResharding(sourceSharding, targetSharding)) { |
| auto [tensorAxis, meshAxis] = detectRes.value(); |
| return splitLastAxisInResharding(builder, sourceSharding, sourceShard, mesh, |
| tensorAxis, meshAxis); |
| } |
| |
| return std::nullopt; |
| } |
| |
| // Detect if the resharding is of type e.g. |
| // [[0, 1, 2]] -> [[0, 1]]. |
// If detected, returns the corresponding (tensor axis, mesh axis) pair.
| static std::optional<std::tuple<int64_t, MeshAxis>> |
| detectUnsplitLastAxisInResharding(MeshShardingAttr sourceSharding, |
| MeshShardingAttr targetSharding) { |
| for (size_t tensorAxis = 0; tensorAxis < sourceSharding.getSplitAxes().size(); |
| ++tensorAxis) { |
| if (targetSharding.getSplitAxes().size() > tensorAxis) { |
| if (sourceSharding.getSplitAxes()[tensorAxis].size() != |
| targetSharding.getSplitAxes()[tensorAxis].size() + 1) |
| continue; |
| if (!llvm::equal( |
| llvm::make_range( |
| sourceSharding.getSplitAxes()[tensorAxis] |
| .asArrayRef() |
| .begin(), |
| sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().end() - |
| 1), |
| targetSharding.getSplitAxes()[tensorAxis].asArrayRef())) |
| continue; |
| } else { |
| if (sourceSharding.getSplitAxes()[tensorAxis].size() != 1) |
| continue; |
| } |
| return std::make_tuple( |
| tensorAxis, |
| sourceSharding.getSplitAxes()[tensorAxis].asArrayRef().back()); |
| } |
| return std::nullopt; |
| } |
| |
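// Compute the sharding that results from dropping the last split mesh axis
// of tensor dimension `splitTensorAxis`.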
| static MeshShardingAttr |
| targetShardingInUnsplitLastAxis(MLIRContext *ctx, |
| MeshShardingAttr sourceSharding, |
| int64_t splitTensorAxis) { |
| SmallVector<MeshAxesAttr> targetShardingSplitAxes = |
| llvm::to_vector(sourceSharding.getSplitAxes()); |
| assert(static_cast<int64_t>(targetShardingSplitAxes.size()) > |
| splitTensorAxis); |
| auto targetSplitAxes = |
| llvm::to_vector(targetShardingSplitAxes[splitTensorAxis].asArrayRef()); |
| |
| targetSplitAxes.pop_back(); |
| targetShardingSplitAxes[splitTensorAxis] = |
| MeshAxesAttr::get(ctx, targetSplitAxes); |
| return MeshShardingAttr::get( |
| ctx, sourceSharding.getMesh(), targetShardingSplitAxes, |
| sourceSharding.getPartialAxes(), sourceSharding.getPartialType()); |
| } |
| |
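// Shape of the all-gather result: dimension `splitTensorAxis` grows by a
// factor of `splitCount`. For example, a tensor<2x8xf32> shard gathered with
// splitCount = 4 along dimension 0 yields tensor<8x8xf32> (a worked example
// for static sizes; gatherDimension also handles dynamic dimensions).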
| static ShapedType allGatherResultShapeInUnsplitLastAxis( |
| ShapedType sourceShape, int64_t splitCount, int64_t splitTensorAxis) { |
| SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape()); |
| targetShape[splitTensorAxis] = |
| gatherDimension(targetShape[splitTensorAxis], splitCount); |
| return sourceShape.cloneWith(targetShape, sourceShape.getElementType()); |
| } |
| |
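// Illustrative IR for un-splitting along mesh axis 0 of size 4 (a sketch;
// the exact assembly format and the tensor shapes are assumptions):
//   %gathered = mesh.all_gather %shard on @mesh mesh_axes = [0]
//       gather_axis = 0 : tensor<2x8xf32> -> tensor<8x8xf32>
// The trailing tensor.cast reconciles this shape with the type computed from
// the unsharded tensor shape, which may differ for dynamic dimensions.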
| static std::tuple<TypedValue<ShapedType>, MeshShardingAttr> |
| unsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, |
| MeshShardingAttr sourceSharding, |
| ShapedType sourceUnshardedShape, |
| TypedValue<ShapedType> sourceShard, MeshOp mesh, |
| int64_t splitTensorAxis, MeshAxis splitMeshAxis) { |
| MLIRContext *ctx = builder.getContext(); |
| builder.setInsertionPointAfterValue(sourceShard); |
| |
  MeshShardingAttr targetSharding =
      targetShardingInUnsplitLastAxis(ctx, sourceSharding, splitTensorAxis);
| ShapedType allGatherResultShape = allGatherResultShapeInUnsplitLastAxis( |
| sourceShard.getType(), mesh.getShape()[splitMeshAxis], splitTensorAxis); |
| Value allGatherResult = builder.create<AllGatherOp>( |
| RankedTensorType::get(allGatherResultShape.getShape(), |
| allGatherResultShape.getElementType()), |
| mesh.getSymName(), SmallVector<MeshAxis>({splitMeshAxis}), sourceShard, |
| APInt(64, splitTensorAxis)); |
| ShapedType targetShape = |
| shardShapedType(sourceUnshardedShape, mesh, targetSharding); |
| TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>( |
| builder.create<tensor::CastOp>(targetShape, allGatherResult).getResult()); |
| return {targetShard, targetSharding}; |
| } |
| |
| static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>> |
| tryUnsplitLastAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, |
| MeshShardingAttr sourceSharding, |
| MeshShardingAttr targetSharding, |
| ShapedType sourceUnshardedShape, |
| TypedValue<ShapedType> sourceShard) { |
| if (auto detectRes = |
| detectUnsplitLastAxisInResharding(sourceSharding, targetSharding)) { |
| auto [tensorAxis, meshAxis] = detectRes.value(); |
| return unsplitLastAxisInResharding(builder, sourceSharding, |
| sourceUnshardedShape, sourceShard, mesh, |
| tensorAxis, meshAxis); |
| } |
| |
| return std::nullopt; |
| } |
| |
| // Detect if the resharding is of type e.g. |
| // [[0, 1], [2]] -> [[0], [1, 2]]. |
// Only a move of the last split axis is detected.
| // If detected, returns the corresponding (source_tensor_axis, |
| // target_tensor_axis, mesh_axis) tuple. |
| static std::optional<std::tuple<int64_t, int64_t, MeshAxis>> |
| detectMoveLastSplitAxisInResharding(MeshShardingAttr sourceSharding, |
| MeshShardingAttr targetSharding) { |
| for (size_t sourceTensorAxis = 0; |
| sourceTensorAxis < sourceSharding.getSplitAxes().size(); |
| ++sourceTensorAxis) { |
| for (size_t targetTensorAxis = 0; |
| targetTensorAxis < targetSharding.getSplitAxes().size(); |
| ++targetTensorAxis) { |
| if (sourceTensorAxis == targetTensorAxis) |
| continue; |
| if (sourceSharding.getSplitAxes()[sourceTensorAxis].empty() || |
| targetSharding.getSplitAxes()[targetTensorAxis].empty() || |
| sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back() != |
| targetSharding.getSplitAxes()[targetTensorAxis] |
| .asArrayRef() |
| .back()) |
| continue; |
| if (!llvm::equal( |
| llvm::make_range(sourceSharding.getSplitAxes()[sourceTensorAxis] |
| .asArrayRef() |
| .begin(), |
| sourceSharding.getSplitAxes()[sourceTensorAxis] |
| .asArrayRef() |
| .end() - |
| 1), |
| llvm::make_range(targetSharding.getSplitAxes()[targetTensorAxis] |
| .asArrayRef() |
| .begin(), |
| targetSharding.getSplitAxes()[targetTensorAxis] |
| .asArrayRef() |
| .end() - |
| 1))) |
| continue; |
| return std::make_tuple( |
| sourceTensorAxis, targetTensorAxis, |
| sourceSharding.getSplitAxes()[sourceTensorAxis].asArrayRef().back()); |
| } |
| } |
| return std::nullopt; |
| } |
| |
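// Compute the sharding that results from moving the last split mesh axis of
// tensor dimension `sourceTensorAxis` to the end of the split axes of tensor
// dimension `targetTensorAxis`.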
| static MeshShardingAttr |
| targetShardingInMoveLastAxis(MLIRContext *ctx, MeshShardingAttr sourceSharding, |
| int64_t sourceTensorAxis, |
| int64_t targetTensorAxis) { |
| SmallVector<MeshAxesAttr> targetShardingSplitAxes = |
| llvm::to_vector(sourceSharding.getSplitAxes()); |
| while (static_cast<int64_t>(targetShardingSplitAxes.size()) <= |
| targetTensorAxis) { |
| targetShardingSplitAxes.push_back(MeshAxesAttr::get(ctx, {})); |
| } |
| |
| auto sourceSplitAxes = |
| llvm::to_vector(targetShardingSplitAxes[sourceTensorAxis].asArrayRef()); |
| assert(!sourceSplitAxes.empty()); |
| auto meshAxis = sourceSplitAxes.back(); |
| sourceSplitAxes.pop_back(); |
| targetShardingSplitAxes[sourceTensorAxis] = |
| MeshAxesAttr::get(ctx, sourceSplitAxes); |
| |
| auto targetSplitAxes = |
| llvm::to_vector(targetShardingSplitAxes[targetTensorAxis].asArrayRef()); |
| targetSplitAxes.push_back(meshAxis); |
| targetShardingSplitAxes[targetTensorAxis] = |
| MeshAxesAttr::get(ctx, targetSplitAxes); |
| |
| return MeshShardingAttr::get( |
| ctx, sourceSharding.getMesh(), targetShardingSplitAxes, |
| sourceSharding.getPartialAxes(), sourceSharding.getPartialType()); |
| } |
| |
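// Shape of the all-to-all result: dimension `sourceTensorAxis` grows by a
// factor of `splitCount` while dimension `targetTensorAxis` shrinks by the
// same factor. For example, tensor<2x8xf32> with splitCount = 2, source
// dimension 0 and target dimension 1 yields tensor<4x4xf32> (a worked
// example for static sizes).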
| static ShapedType allToAllResultShapeInMoveLastAxis(ShapedType sourceShape, |
| int64_t splitCount, |
| int64_t sourceTensorAxis, |
| int64_t targetTensorAxis) { |
| SmallVector<int64_t> targetShape = llvm::to_vector(sourceShape.getShape()); |
| targetShape[sourceTensorAxis] = |
| gatherDimension(targetShape[sourceTensorAxis], splitCount); |
| targetShape[targetTensorAxis] = |
| shardDimension(targetShape[targetTensorAxis], splitCount); |
| return sourceShape.cloneWith(targetShape, sourceShape.getElementType()); |
| } |
| |
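// Illustrative IR for moving the split from dimension 0 to dimension 1 over
// a mesh axis of size 2 (a sketch; the exact assembly format and the tensor
// shapes are assumptions):
//   %moved = mesh.all_to_all %shard on @mesh mesh_axes = [0] split_axis = 1
//       concat_axis = 0 : tensor<2x8xf32> -> tensor<4x4xf32>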
| static std::tuple<TypedValue<ShapedType>, MeshShardingAttr> |
| moveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, |
| MeshShardingAttr sourceSharding, |
| ShapedType sourceUnshardedShape, |
| TypedValue<ShapedType> sourceShard, |
| int64_t sourceTensorAxis, |
| int64_t targetTensorAxis, MeshAxis meshAxis) { |
| MLIRContext *ctx = builder.getContext(); |
| builder.setInsertionPointAfterValue(sourceShard); |
| |
| MeshShardingAttr targetSharding = targetShardingInMoveLastAxis( |
| ctx, sourceSharding, sourceTensorAxis, targetTensorAxis); |
| ShapedType allToAllResultShape = allToAllResultShapeInMoveLastAxis( |
| sourceShard.getType(), mesh.getShape()[meshAxis], sourceTensorAxis, |
| targetTensorAxis); |
| Value allToAllResult = builder.create<AllToAllOp>( |
| RankedTensorType::get(allToAllResultShape.getShape(), |
| allToAllResultShape.getElementType()), |
| mesh.getSymName(), SmallVector<MeshAxis>({meshAxis}), sourceShard, |
| APInt(64, targetTensorAxis), APInt(64, sourceTensorAxis)); |
| ShapedType targetShape = |
| shardShapedType(sourceUnshardedShape, mesh, targetSharding); |
| TypedValue<ShapedType> targetShard = cast<TypedValue<ShapedType>>( |
| builder.create<tensor::CastOp>(targetShape, allToAllResult).getResult()); |
| return {targetShard, targetSharding}; |
| } |
| |
| static std::optional<std::tuple<TypedValue<ShapedType>, MeshShardingAttr>> |
| tryMoveLastSplitAxisInResharding(ImplicitLocOpBuilder &builder, MeshOp mesh, |
| MeshShardingAttr sourceSharding, |
| MeshShardingAttr targetSharding, |
| ShapedType sourceUnshardedShape, |
| TypedValue<ShapedType> sourceShard) { |
| if (auto detectRes = |
| detectMoveLastSplitAxisInResharding(sourceSharding, targetSharding)) { |
| auto [sourceTensorAxis, targetTensorAxis, meshAxis] = detectRes.value(); |
| return moveLastSplitAxisInResharding( |
| builder, mesh, sourceSharding, sourceUnshardedShape, sourceShard, |
| sourceTensorAxis, targetTensorAxis, meshAxis); |
| } |
| |
| return std::nullopt; |
| } |
| |
// Handles only resharding on a 1D mesh.
// Currently the sizes of the sharded tensor axes must be exactly divisible by
// the size of the single mesh axis.
| static TypedValue<ShapedType> |
| reshardOn1DMesh(ImplicitLocOpBuilder &builder, MeshOp mesh, |
| MeshShardingAttr sourceSharding, |
| MeshShardingAttr targetSharding, |
| TypedValue<ShapedType> sourceUnshardedValue, |
| TypedValue<ShapedType> sourceShard) { |
| assert(sourceShard.getType() == |
| shardShapedType(sourceUnshardedValue.getType(), mesh, sourceSharding)); |
| [[maybe_unused]] ShapedType targetShardType = |
| shardShapedType(sourceUnshardedValue.getType(), mesh, targetSharding); |
| assert(sourceShard.getType().getRank() == targetShardType.getRank()); |
| assert(mesh.getRank() == 1 && "Only 1D meshes are currently supported."); |
| |
| auto [reducedSourceShard, reducedSourceSharding] = |
| handlePartialAxesDuringResharding(builder, sourceSharding, targetSharding, |
| sourceShard); |
| |
| if (reducedSourceSharding == targetSharding) { |
| return reducedSourceShard; |
| } |
| |
| TypedValue<ShapedType> targetShard; |
| MeshShardingAttr actualTargetSharding; |
| if (auto tryRes = tryMoveLastSplitAxisInResharding( |
| builder, mesh, reducedSourceSharding, targetSharding, |
| sourceUnshardedValue.getType(), reducedSourceShard)) { |
| std::tie(targetShard, actualTargetSharding) = tryRes.value(); |
| } else if (auto tryRes = trySplitLastAxisInResharding( |
| builder, mesh, reducedSourceSharding, targetSharding, |
| reducedSourceShard)) { |
| std::tie(targetShard, actualTargetSharding) = tryRes.value(); |
| } else if (auto tryRes = tryUnsplitLastAxisInResharding( |
| builder, mesh, reducedSourceSharding, targetSharding, |
| sourceUnshardedValue.getType(), reducedSourceShard)) { |
| std::tie(targetShard, actualTargetSharding) = tryRes.value(); |
| } else { |
| assert(false && "Did not find any pattern to apply."); |
| } |
| |
| assert(actualTargetSharding == targetSharding); |
| assert(targetShard.getType() == targetShardType); |
| return targetShard; |
| } |
| |
| TypedValue<ShapedType> reshard(ImplicitLocOpBuilder &builder, MeshOp mesh, |
| MeshShardingAttr sourceSharding, |
| MeshShardingAttr targetSharding, |
| TypedValue<ShapedType> sourceUnshardedValue, |
| TypedValue<ShapedType> sourceShard) { |
  // Resort to handling only 1D meshes, since the general case is complicated
  // if it needs to be communication efficient, i.e. to minimize the data
  // transferred between devices.
| return reshardOn1DMesh(builder, mesh, sourceSharding, targetSharding, |
| sourceUnshardedValue, sourceShard); |
| } |
| |
| TypedValue<ShapedType> reshard(OpBuilder &builder, MeshOp mesh, ShardOp source, |
| ShardOp target, |
| TypedValue<ShapedType> sourceShardValue) { |
| assert(!source.getAnnotateForUsers()); |
| assert(target.getAnnotateForUsers()); |
| assert(source.getResult() == target.getOperand()); |
| ImplicitLocOpBuilder implicitLocOpBuilder(target->getLoc(), builder); |
| return reshard( |
| implicitLocOpBuilder, mesh, source.getShard(), target.getShard(), |
| cast<TypedValue<ShapedType>>(source.getSrc()), sourceShardValue); |
| } |
| |
| TypedValue<ShapedType> reshard(OpBuilder &builder, ShardOp source, |
| ShardOp target, |
| TypedValue<ShapedType> sourceShardValue, |
| SymbolTableCollection &symbolTableCollection) { |
| MeshOp srcMesh = getMesh(source, symbolTableCollection); |
| assert(srcMesh && srcMesh == getMesh(target, symbolTableCollection)); |
| return reshard(builder, srcMesh, source, target, sourceShardValue); |
| } |
| |
| void reshardingRegisterDependentDialects(DialectRegistry ®istry) { |
| registry.insert<mesh::MeshDialect, tensor::TensorDialect>(); |
| } |
| |
| #define GEN_PASS_DEF_SPMDIZATION |
| #include "mlir/Dialect/Mesh/Transforms/Passes.h.inc" |
| |
| using UnshardedToShardedValueMap = DenseMap<Value, Value>; |
| |
| // Get the types of block arguments for an spmdized block. |
| // Reads the sharding annotations of the arguments to deduce the sharded types. |
| // Types that are not ranked tensors are left unchanged. |
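// For example (a sketch, shapes assumed): an argument of type tensor<4x8xf32>
// whose single use is a mesh.shard op with sharding <@mesh_1d, [[0]]> on a
// mesh of size 2 maps to tensor<2x8xf32>.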
| SmallVector<Type> |
| shardedBlockArgumentTypes(Block &block, |
| SymbolTableCollection &symbolTableCollection) { |
| SmallVector<Type> res; |
| llvm::transform( |
| block.getArguments(), std::back_inserter(res), |
| [&symbolTableCollection](BlockArgument arg) { |
| auto rankedTensorArg = dyn_cast<TypedValue<RankedTensorType>>(arg); |
| if (!rankedTensorArg) { |
| return arg.getType(); |
| } |
| |
| assert(rankedTensorArg.hasOneUse()); |
| Operation *useOp = *rankedTensorArg.getUsers().begin(); |
| ShardOp shardOp = llvm::dyn_cast<ShardOp>(useOp); |
| assert(shardOp); |
| MeshOp mesh = getMesh(shardOp, symbolTableCollection); |
| return cast<Type>(shardShapedType(rankedTensorArg.getType(), mesh, |
| shardOp.getShardAttr())); |
| }); |
| return res; |
| } |
| |
| static LogicalResult spmdizeOperation( |
| Operation &op, ArrayRef<Value> spmdizedOperands, |
| ArrayRef<MeshShardingAttr> operandShardings, |
| ArrayRef<MeshShardingAttr> resultShardings, IRMapping &spmdizationMap, |
| SymbolTableCollection &symbolTableCollection, OpBuilder &builder) { |
| ShardingInterface shardingInterface = llvm::dyn_cast<ShardingInterface>(op); |
| if (!shardingInterface) { |
| // If there is no sharding interface we are conservative and assume that |
    // the op should be fully replicated on all devices.
| spmdizeFullyReplicatedOperation(op, spmdizedOperands, operandShardings, |
| resultShardings, spmdizationMap, |
| symbolTableCollection, builder); |
| } else { |
| if (failed(shardingInterface.spmdize(spmdizedOperands, operandShardings, |
| resultShardings, spmdizationMap, |
| symbolTableCollection, builder))) { |
| return failure(); |
| } |
| } |
| |
| assert(llvm::all_of(op.getResults(), [&spmdizationMap](OpResult result) { |
| return spmdizationMap.contains(result); |
| })); |
| |
| return success(); |
| } |
| |
| // Retrieve the sharding annotations for the operands of the given operation. |
// If the type is not a ranked tensor the operand is not required to have a
// sharding annotation.
| static SmallVector<MeshShardingAttr> getOperandShardings(Operation &op) { |
| SmallVector<MeshShardingAttr> res; |
| res.reserve(op.getNumOperands()); |
| llvm::transform(op.getOperands(), std::back_inserter(res), [](Value operand) { |
| TypedValue<RankedTensorType> rankedTensor = |
| dyn_cast<TypedValue<RankedTensorType>>(operand); |
| if (!rankedTensor) { |
| return MeshShardingAttr(); |
| } |
| |
| Operation *definingOp = operand.getDefiningOp(); |
| assert(definingOp); |
| ShardOp shardOp = llvm::cast<ShardOp>(definingOp); |
| return shardOp.getShard(); |
| }); |
| return res; |
| } |
| |
| // Retrieve the sharding annotations for the results of the given operation. |
// If the type is not a ranked tensor the result is not required to have a
// sharding annotation.
| static SmallVector<MeshShardingAttr> getResultShardings(Operation &op) { |
| SmallVector<MeshShardingAttr> res; |
| res.reserve(op.getNumResults()); |
| llvm::transform(op.getResults(), std::back_inserter(res), |
| [](OpResult result) { |
| TypedValue<RankedTensorType> rankedTensor = |
| dyn_cast<TypedValue<RankedTensorType>>(result); |
| if (!rankedTensor) { |
| return MeshShardingAttr(); |
| } |
| |
| assert(result.hasOneUse()); |
| Operation *userOp = *result.getUsers().begin(); |
| ShardOp shardOp = llvm::cast<ShardOp>(userOp); |
| return shardOp.getShard(); |
| }); |
| return res; |
| } |
| |
| static LogicalResult |
| spmdizeOperation(ShardOp shardOp, IRMapping &spmdizationMap, |
| SymbolTableCollection &symbolTableCollection, |
| OpBuilder &builder) { |
| Value targetSpmdValue; |
| |
  // Check if two shard ops are chained. If not, there is no need for
  // resharding as the source and target share the same sharding.
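  // Illustrative chained annotation that triggers resharding (a sketch; the
  // exact assembly format is an assumption):
  //   %0 = mesh.shard %arg0 to <@mesh, [[0]]> : tensor<4x8xf32>
  //   %1 = mesh.shard %0 to <@mesh, [[]]> annotate_for_users : tensor<4x8xf32>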
| ShardOp srcShardOp = |
| dyn_cast_or_null<ShardOp>(shardOp.getOperand().getDefiningOp()); |
| if (!srcShardOp) { |
| targetSpmdValue = spmdizationMap.lookup(shardOp.getOperand()); |
| } else { |
| // Insert resharding. |
| assert(!srcShardOp.getAnnotateForUsers() && shardOp.getAnnotateForUsers()); |
| TypedValue<ShapedType> srcSpmdValue = cast<TypedValue<ShapedType>>( |
| spmdizationMap.lookup(srcShardOp.getOperand())); |
| targetSpmdValue = reshard(builder, srcShardOp, shardOp, srcSpmdValue, |
| symbolTableCollection); |
| } |
| |
| assert(!spmdizationMap.contains(shardOp.getResult())); |
| spmdizationMap.map(shardOp.getResult(), targetSpmdValue); |
| return success(); |
| } |
| |
| static LogicalResult |
| spmdizeOperation(Operation &op, IRMapping &spmdizationMap, |
| SymbolTableCollection &symbolTableCollection, |
| OpBuilder &builder) { |
| ShardOp shardOp = llvm::dyn_cast<ShardOp>(op); |
| if (shardOp) { |
| return spmdizeOperation(shardOp, spmdizationMap, symbolTableCollection, |
| builder); |
| } |
| |
| SmallVector<Value> spmdizedOperands; |
| llvm::transform(op.getOperands(), std::back_inserter(spmdizedOperands), |
| [&spmdizationMap](Value operand) { |
| assert(spmdizationMap.contains(operand)); |
| return spmdizationMap.lookup(operand); |
| }); |
| return spmdizeOperation(op, spmdizedOperands, getOperandShardings(op), |
| getResultShardings(op), spmdizationMap, |
| symbolTableCollection, builder); |
| } |
| |
| static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap, |
| SymbolTableCollection &symbolTableCollection, |
| OpBuilder &builder) { |
| SmallVector<Location> argLocations; |
| llvm::transform(block.getArguments(), std::back_inserter(argLocations), |
| [](BlockArgument arg) { return arg.getLoc(); }); |
| Block *newBlock = builder.createBlock( |
| block.getParent(), {}, |
| shardedBlockArgumentTypes(block, symbolTableCollection), argLocations); |
| for (auto [unshardedBlockArg, spmdizedBlockArg] : |
| llvm::zip(block.getArguments(), newBlock->getArguments())) { |
| spmdizationMap.map(unshardedBlockArg, spmdizedBlockArg); |
| } |
| |
| OpBuilder::InsertionGuard insertionGuard(builder); |
| builder.setInsertionPointToEnd(newBlock); |
| for (Operation &op : block.getOperations()) { |
| if (failed(spmdizeOperation(op, spmdizationMap, symbolTableCollection, |
| builder))) { |
| return failure(); |
| } |
| } |
| |
| return success(); |
| } |
| |
| static LogicalResult |
| spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap, |
| SymbolTableCollection &symbolTableCollection) { |
| OpBuilder builder(op.getFunctionBody()); |
| |
  // Snapshot the original blocks so that the iteration is not invalidated by
  // the newly created blocks.
| SmallVector<Block *> originalBlocks; |
| llvm::transform(op.getBlocks(), std::back_inserter(originalBlocks), |
| [](Block &b) { return &b; }); |
| |
| for (Block *block : originalBlocks) { |
| if (failed(spmdizeBlock(*block, spmdizationMap, symbolTableCollection, |
| builder))) { |
| return failure(); |
| } |
| } |
| |
| for (Block *block : originalBlocks) { |
| block->erase(); |
| } |
| |
  // Find a return op and update the function result types to match its
  // operand types.
| Operation *returnOp = nullptr; |
| for (Block &block : op.getFunctionBody()) { |
| if (block.empty()) { |
| continue; |
| } |
| |
| if (block.back().hasTrait<OpTrait::ReturnLike>()) { |
| returnOp = &block.back(); |
| break; |
| } |
| } |
| assert(returnOp); |
| op.setType(FunctionType::get(op->getContext(), |
| op.getFunctionBody().front().getArgumentTypes(), |
| returnOp->getOperandTypes())); |
| |
| return success(); |
| } |
| |
| namespace { |
| |
| struct Spmdization : public impl::SpmdizationBase<Spmdization> { |
| void runOnOperation() override { |
| IRMapping spmdizationMap; |
| SymbolTableCollection symbolTableCollection; |
| if (failed(spmdizeFuncOp(getOperation(), spmdizationMap, |
| symbolTableCollection))) { |
| return signalPassFailure(); |
| } |
| } |
| |
| void getDependentDialects(DialectRegistry ®istry) const override { |
| reshardingRegisterDependentDialects(registry); |
| registry.insert<mesh::MeshDialect>(); |
| } |
| }; |
| |
| } // namespace |
| |
| } // namespace mlir::mesh |