Reshard 2x3 tensor from sharding [[0, 1]] to sharding [[1, 0]] on a 2x3 mesh.
unsharded 2x3 tensor

11 12 13
21 22 23

sharded on a 2x3 mesh

sharding = [[0, 1]]

mesh contents:

mesh axis 1
----------->
+----+----+----+ mesh axis 0 |
| 11 | 12 | 13 |             |
+----+----+----+             |
| 21 | 22 | 23 |             |
+----+----+----+             ↓
Transform into sharding = [[1, 0]]
mesh axis 1
----------->
+----+----+----+ mesh axis 0 |
| 11 | 13 | 22 |             |
+----+----+----+             |
| 12 | 21 | 23 |             |
+----+----+----+             ↓
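Algorithm: Swap contents on devices that have the same linear index in the 2 shardings. The device with linear index i under the target sharding receives the contents of the device with linear index i under the source sharding.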
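The linear-index rule can be checked with a small Python sketch. It makes three assumptions: a mesh is a tuple of axis sizes, a device is a coordinate tuple, and the sharding uses every mesh axis, so a linear index identifies exactly one device. `linear_index` and `reorder_mesh_axes` are hypothetical helpers and are not part of MLIR.

```python
from itertools import product


def linear_index(device, mesh_axes, mesh_shape):
    """Linear index of `device` over `mesh_axes`; the first listed axis is the most significant."""
    index = 0
    for axis in mesh_axes:
        index = index * mesh_shape[axis] + device[axis]
    return index


def reorder_mesh_axes(contents, src_axes, dst_axes, mesh_shape):
    """The device with linear index i over `dst_axes` receives what the device
    with linear index i over `src_axes` held.

    Assumes both orders cover every mesh axis, so an index names a single device.
    """
    devices = list(product(*(range(size) for size in mesh_shape)))
    by_src_index = {linear_index(d, src_axes, mesh_shape): contents[d] for d in devices}
    return {d: by_src_index[linear_index(d, dst_axes, mesh_shape)] for d in devices}


# The 2x3 example: device (m0, m1) holds a single element.
src = {(0, 0): 11, (0, 1): 12, (0, 2): 13, (1, 0): 21, (1, 1): 22, (1, 2): 23}
dst = reorder_mesh_axes(src, [0, 1], [1, 0], (2, 3))
for m0 in range(2):
    print([dst[(m0, m1)] for m1 in range(3)])  # [11, 13, 22] then [12, 21, 23]
```

The result is a permutation of the devices. For this mesh it is a cycle of four devices and not a set of pairwise swaps, so an implementation would use point-to-point sends and receives.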
Reshard 2x3 tensor from sharding [[0, 1]] to sharding [[1]] on a 2x3 mesh.

unsharded 2x3 tensor

11 12 13
21 22 23

sharded on a 2x3 mesh

sharding = [[0, 1]]

mesh contents:

mesh axis 1
----------->
+----+----+----+ mesh axis 0 |
| 11 | 12 | 13 |             |
+----+----+----+             |
| 21 | 22 | 23 |             |
+----+----+----+             ↓
Transform into sharding = [[1]]
mesh axis 1
----------->
+----+----+----+ mesh axis 0 |
| 11 | 12 | 13 |             |
| 21 | 22 | 23 |             |
+----+----+----+             |
| 11 | 12 | 13 |             |
| 21 | 22 | 23 |             |
+----+----+----+             ↓
Algorithm: All-gather along mesh axis 0.
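Below is a minimal Python sketch of this all-gather that reproduces the target diagram. Each device's contents are modelled as a plain value. `all_gather` is a hypothetical helper, not the MLIR op. Every device receives, in order, the values held by the devices that differ from it only along the gathered mesh axis.

```python
from itertools import product


def all_gather(contents, mesh_shape, mesh_axis):
    """Each device receives the contents of every device that differs from it
    only along `mesh_axis`, ordered by the coordinate on that axis."""
    gathered = {}
    for d in product(*(range(size) for size in mesh_shape)):
        group = [d[:mesh_axis] + (g,) + d[mesh_axis + 1:] for g in range(mesh_shape[mesh_axis])]
        gathered[d] = [contents[peer] for peer in group]
    return gathered


src = {(0, 0): 11, (0, 1): 12, (0, 2): 13, (1, 0): 21, (1, 1): 22, (1, 2): 23}
dst = all_gather(src, (2, 3), mesh_axis=0)
print(dst[(0, 0)], dst[(1, 0)], dst[(1, 2)])  # [11, 21] [11, 21] [13, 23]
```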
Reshard 2x6 tensor from sharding [[], [0, 1]] to sharding [[], [0]] on a 2x3 mesh.

unsharded 2x6 tensor

11 12 13 14 15 16
21 22 23 24 25 26

sharded on a 2x3 mesh

sharding = [[], [0, 1]]

mesh contents:

mesh axis 1
----------->
+----+----+----+ mesh axis 0 |
| 11 | 12 | 13 |             |
| 21 | 22 | 23 |             |
+----+----+----+             |
| 14 | 15 | 16 |             |
| 24 | 25 | 26 |             |
+----+----+----+             ↓
Transform into sharding = [[], [0]]
mesh axis 1
----------->
+----------+----------+----------+ mesh axis 0 |
| 11 12 13 | 11 12 13 | 11 12 13 |             |
| 21 22 23 | 21 22 23 | 21 22 23 |             |
+----------+----------+----------+             |
| 14 15 16 | 14 15 16 | 14 15 16 |             |
| 24 25 26 | 24 25 26 | 24 25 26 |             |
+----------+----------+----------+             ↓
Algorithm: All-gather along mesh axis 1.
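A single all-gather is enough here because mesh axis 1 is the trailing (minor) mesh axis of tensor axis 1. The blocks gathered along it are therefore adjacent. The sketch below checks this by tracking column indices only. It assumes each tensor dimension divides evenly over its mesh axes, and its helper names are hypothetical.

```python
def block(dim_size, mesh_axes, device, mesh_shape):
    """Indices of a tensor dim, split evenly over `mesh_axes` (major to minor), held by `device`."""
    index, count = 0, 1
    for axis in mesh_axes:
        index = index * mesh_shape[axis] + device[axis]
        count *= mesh_shape[axis]
    size = dim_size // count  # assumes dim_size is divisible by count
    return list(range(index * size, (index + 1) * size))


def after_all_gather(dim_size, mesh_axes, device, mesh_shape, mesh_axis):
    """Indices `device` holds after all-gathering along `mesh_axis`, in peer order."""
    indices = []
    for g in range(mesh_shape[mesh_axis]):
        peer = device[:mesh_axis] + (g,) + device[mesh_axis + 1:]
        indices += block(dim_size, mesh_axes, peer, mesh_shape)
    return indices


# Tensor axis 1 (6 columns) on the 2x3 mesh.
mesh_shape = (2, 3)
for device in [(m0, m1) for m0 in range(2) for m1 in range(3)]:
    gathered = after_all_gather(6, [0, 1], device, mesh_shape, mesh_axis=1)
    assert gathered == block(6, [0], device, mesh_shape)  # matches [[], [0]]
print(after_all_gather(6, [0, 1], (1, 2), mesh_shape, mesh_axis=1))  # [3, 4, 5]
```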
Reshard 4x8 tensor from sharding [[0], [1, 2]] to sharding [[0], [2]] on a 2x2x2 mesh.

unsharded 4x8 tensor

11 12 13 14 15 16 17 18
21 22 23 24 25 26 27 28
31 32 33 34 35 36 37 38
41 42 43 44 45 46 47 48

sharded on a 2x2x2 mesh

sharding = [[0], [1, 2]]

mesh contents:

mesh axis 2
----------->
+-------+-------+ mesh axis 1 | mesh axis 0 |
| 11 12 | 13 14 |             |             |
| 21 22 | 23 24 |             |             |
+-------+-------+             |             |
| 15 16 | 17 18 |             |             |
| 25 26 | 27 28 |             |             |
+-------+-------+             ↓             |
+-------+-------+                           |
| 31 32 | 33 34 |                           |
| 41 42 | 43 44 |                           |
+-------+-------+                           |
| 35 36 | 37 38 |                           |
| 45 46 | 47 48 |                           |
+-------+-------+                           ↓
Transform into sharding = [[0], [2]]
mesh axis 2
----------->
+-------------+-------------+ mesh axis 1 | mesh axis 0 |
| 11 12 13 14 | 15 16 17 18 |             |             |
| 21 22 23 24 | 25 26 27 28 |             |             |
+-------------+-------------+             |             |
| 11 12 13 14 | 15 16 17 18 |             |             |
| 21 22 23 24 | 25 26 27 28 |             |             |
+-------------+-------------+             ↓             |
+-------------+-------------+                           |
| 31 32 33 34 | 35 36 37 38 |                           |
| 41 42 43 44 | 45 46 47 48 |                           |
+-------------+-------------+                           |
| 31 32 33 34 | 35 36 37 38 |                           |
| 41 42 43 44 | 45 46 47 48 |                           |
+-------------+-------------+                           ↓
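Algorithm:

This cannot be done with just an all-gather along mesh axis 1. Mesh axis 1 is not the trailing mesh axis of tensor axis 1, so the column blocks gathered along it are not adjacent. It can be handled by a sequence of resharding transformations:

[[0], [1, 2]] -> [[0], [2, 1]] -> [[0], [2]]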
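Tracking column indices makes the failure concrete. The sketch below uses hypothetical helpers and assumes evenly divisible sizes. Gathering along mesh axis 1 straight from [[0], [1, 2]] collects two non-adjacent column blocks. After the mesh axes are reordered to [[0], [2, 1]], the same all-gather yields exactly the target block.

```python
def block(dim_size, mesh_axes, device, mesh_shape):
    """Indices of a tensor dim, split evenly over `mesh_axes` (major to minor), held by `device`."""
    index, count = 0, 1
    for axis in mesh_axes:
        index = index * mesh_shape[axis] + device[axis]
        count *= mesh_shape[axis]
    size = dim_size // count  # assumes dim_size is divisible by count
    return list(range(index * size, (index + 1) * size))


def after_all_gather(dim_size, mesh_axes, device, mesh_shape, mesh_axis):
    """Indices `device` holds after all-gathering along `mesh_axis`, in peer order."""
    indices = []
    for g in range(mesh_shape[mesh_axis]):
        peer = device[:mesh_axis] + (g,) + device[mesh_axis + 1:]
        indices += block(dim_size, mesh_axes, peer, mesh_shape)
    return indices


mesh_shape = (2, 2, 2)
device = (0, 0, 0)
# Directly from [[0], [1, 2]]: columns {0, 1} and {4, 5}, not one block.
print(after_all_gather(8, [1, 2], device, mesh_shape, mesh_axis=1))  # [0, 1, 4, 5]
# After reordering to [[0], [2, 1]]: the contiguous target block.
print(after_all_gather(8, [2, 1], device, mesh_shape, mesh_axis=1))  # [0, 1, 2, 3]
print(block(8, [2], device, mesh_shape))                             # [0, 1, 2, 3]

# The reordered route matches [[0], [2]] on every device.
for d in [(a, b, c) for a in range(2) for b in range(2) for c in range(2)]:
    assert after_all_gather(8, [2, 1], d, mesh_shape, mesh_axis=1) == block(8, [2], d, mesh_shape)
```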
Reshard 6x6 tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x3 mesh.

unsharded 6x6 tensor

11 12 13 14 15 16
21 22 23 24 25 26
31 32 33 34 35 36
41 42 43 44 45 46
51 52 53 54 55 56
61 62 63 64 65 66

sharded on a 2x3 mesh

sharding = [[0], [1]]

mesh axis 1
----------->
+-------+-------+-------+ mesh axis 0 |
| 11 12 | 13 14 | 15 16 |             |
| 21 22 | 23 24 | 25 26 |             |
| 31 32 | 33 34 | 35 36 |             |
+-------+-------+-------+             |
| 41 42 | 43 44 | 45 46 |             |
| 51 52 | 53 54 | 55 56 |             |
| 61 62 | 63 64 | 65 66 |             |
+-------+-------+-------+             ↓
Transform into sharding = [[1], [0]]

mesh axis 1
----------->
+----------+----------+----------+ mesh axis 0 |
| 11 12 13 | 31 32 33 | 51 52 53 |             |
| 21 22 23 | 41 42 43 | 61 62 63 |             |
+----------+----------+----------+             |
| 14 15 16 | 34 35 36 | 54 55 56 |             |
| 24 25 26 | 44 45 46 | 64 65 66 |             |
+----------+----------+----------+             ↓

The same layout with mesh axis 0 drawn horizontally:

mesh axis 0
----------->
+----------+----------+ mesh axis 1 |
| 11 12 13 | 14 15 16 |             |
| 21 22 23 | 24 25 26 |             |
+----------+----------+             |
| 31 32 33 | 34 35 36 |             |
| 41 42 43 | 44 45 46 |             |
+----------+----------+             |
| 51 52 53 | 54 55 56 |             |
| 61 62 63 | 64 65 66 |             |
+----------+----------+             ↓
Algorithm: TODO
Reshard 6x6 tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x6 mesh.

unsharded 6x6 tensor

11 12 13 14 15 16
21 22 23 24 25 26
31 32 33 34 35 36
41 42 43 44 45 46
51 52 53 54 55 56
61 62 63 64 65 66

sharded on a 2x6 mesh

sharding = [[0], [1]]
mesh axis 1
----------->
+----+----+----+----+----+----+ mesh axis 0 |
| 11 | 12 | 13 ‖ 14 | 15 | 16 |             |
| 21 | 22 | 23 ‖ 24 | 25 | 26 |             |
| 31 | 32 | 33 ‖ 34 | 35 | 36 |             |
+----+----+----+----+----+----+             |
| 41 | 42 | 43 ‖ 44 | 45 | 46 |             |
| 51 | 52 | 53 ‖ 54 | 55 | 56 |             |
| 61 | 62 | 63 ‖ 64 | 65 | 66 |             |
+----+----+----+----+----+----+             ↓
Transform into sharding = [[1], [0]]

mesh axis 0
----------->
+----------+----------+ mesh axis 1 |
| 11 12 13 | 14 15 16 |             |
+----------+----------+             |
| 21 22 23 | 24 25 26 |             |
+----------+----------+             |
| 31 32 33 | 34 35 36 |             |
+==========+==========+             |
| 41 42 43 | 44 45 46 |             |
+----------+----------+             |
| 51 52 53 | 54 55 56 |             |
+----------+----------+             |
| 61 62 63 | 64 65 66 |             |
+----------+----------+             ↓

The double lines (‖ in the source diagram, === in the target diagram) mark where the other sharding cuts the tensor.
Algorithm: TODO
Reshard KxL tensor from [[0], [1]] to [[1], [0]] on an MxN mesh.

The mesh is M x N and the tensor t is K x L. d(m, n) is the tensor on device (m, n).

sharding = [[0], [1]]

The tensor shard s on each device has size (K ceildiv M, L ceildiv N).

d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
Substitute

i <- m * (K ceildiv M) + k
j <- n * (L ceildiv N) + l

m -> i floordiv (K ceildiv M)
n -> j floordiv (L ceildiv N)
k -> i % (K ceildiv M)
l -> j % (L ceildiv N)

For the inverse map we get

t[i, j] -> d(i floordiv (K ceildiv M), j floordiv (L ceildiv N))[i % (K ceildiv M), j % (L ceildiv N)]
Check:

i = 13, j = 17, M = 3, N = 4, K = 16, L = 23

t[13, 17]
= d(13 floordiv (16 ceildiv 3), 17 floordiv (23 ceildiv 4))[13 % (16 ceildiv 3), 17 % (23 ceildiv 4)]
= d(13 floordiv 6, 17 floordiv 6)[13 % 6, 17 % 6]
= d(2, 2)[1, 5]
= t[2 * (16 ceildiv 3) + 1, 2 * (23 ceildiv 4) + 5]
= t[2 * 6 + 1, 2 * 6 + 5]
= t[13, 17]
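The forward and inverse maps for [[0], [1]] can be checked exhaustively with a minimal Python sketch. `shard_to_tensor` and `tensor_to_shard` are illustrative names. `ceildiv` is written as floor division on negated values. The loop confirms that every element of t maps to a valid device and back to itself.

```python
def ceildiv(a, b):
    return -(-a // b)


def shard_to_tensor(m, n, k, l, K, L, M, N):
    """Sharding [[0], [1]]: d(m, n)[k, l] -> t[i, j]."""
    return m * ceildiv(K, M) + k, n * ceildiv(L, N) + l


def tensor_to_shard(i, j, K, L, M, N):
    """Inverse map for sharding [[0], [1]]: t[i, j] -> ((m, n), (k, l))."""
    sk, sl = ceildiv(K, M), ceildiv(L, N)
    return (i // sk, j // sl), (i % sk, j % sl)


K, L, M, N = 16, 23, 3, 4
print(tensor_to_shard(13, 17, K, L, M, N))  # ((2, 2), (1, 5))

# Round trip over every element of t.
for i in range(K):
    for j in range(L):
        (m, n), (k, l) = tensor_to_shard(i, j, K, L, M, N)
        assert m < M and n < N
        assert shard_to_tensor(m, n, k, l, K, L, M, N) == (i, j)
```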
sharding = [[1], [0]]

The tensor shard s on each device has size (K ceildiv N, L ceildiv M).

d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]

Substitute

i <- n * (K ceildiv N) + k
j <- m * (L ceildiv M) + l

m -> j floordiv (L ceildiv M)
n -> i floordiv (K ceildiv N)
k -> i % (K ceildiv N)
l -> j % (L ceildiv M)

For the inverse map we get

t[i, j] -> d(j floordiv (L ceildiv M), i floordiv (K ceildiv N))[i % (K ceildiv N), j % (L ceildiv M)]
Check:

i = 9, j = 13, M = 5, N = 2, K = 27, L = 14

t[9, 13]
= d(13 floordiv (14 ceildiv 5), 9 floordiv (27 ceildiv 2))[9 % (27 ceildiv 2), 13 % (14 ceildiv 5)]
= d(13 floordiv 3, 9 floordiv 14)[9 % 14, 13 % 3]
= d(4, 0)[9, 1]
= t[0 * (27 ceildiv 2) + 9, 4 * (14 ceildiv 5) + 1]
= t[0 * 14 + 9, 4 * 3 + 1]
= t[9, 13]
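The same exhaustive check works for [[1], [0]], where the roles of the mesh axes are swapped. This sketch uses the mesh and tensor sizes from the check above (M = 5, N = 2, K = 27, L = 14). Every index satisfies 0 <= i < K and 0 <= j < L, which keeps m below M and n below N. The function names are illustrative.

```python
def ceildiv(a, b):
    return -(-a // b)


def shard_to_tensor(m, n, k, l, K, L, M, N):
    """Sharding [[1], [0]]: d(m, n)[k, l] -> t[i, j]."""
    return n * ceildiv(K, N) + k, m * ceildiv(L, M) + l


def tensor_to_shard(i, j, K, L, M, N):
    """Inverse map for sharding [[1], [0]]: t[i, j] -> ((m, n), (k, l))."""
    sk, sl = ceildiv(K, N), ceildiv(L, M)
    return (j // sl, i // sk), (i % sk, j % sl)


K, L, M, N = 27, 14, 5, 2
print(tensor_to_shard(9, 13, K, L, M, N))  # ((4, 0), (9, 1))

for i in range(K):
    for j in range(L):
        (m, n), (k, l) = tensor_to_shard(i, j, K, L, M, N)
        assert m < M and n < N
        assert shard_to_tensor(m, n, k, l, K, L, M, N) == (i, j)
```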
sharding = [[0], [1]]

d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
t[i, j] -> d(i floordiv (K ceildiv M), j floordiv (L ceildiv N))[i % (K ceildiv M), j % (L ceildiv N)]

sharding = [[1], [0]]

d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]
t[i, j] -> d(j floordiv (L ceildiv M), i floordiv (K ceildiv N))[i % (K ceildiv N), j % (L ceildiv M)]
sharding [[0], [1]] -> [[1], [0]]

Let d1(m, n) be the tensor on device (m, n) for sharding [[0], [1]], and d2(m, n) the tensor on device (m, n) for sharding [[1], [0]].

d1(m, n)[k, l]
-> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
-> d2(
     (n * (L ceildiv N) + l) floordiv (L ceildiv M),
     (m * (K ceildiv M) + k) floordiv (K ceildiv N)
   )[
     (m * (K ceildiv M) + k) % (K ceildiv N),
     (n * (L ceildiv N) + l) % (L ceildiv M)
   ]
= d2(p, q)[u, v]
We want to copy the data between devices in slices/tiles. What are the source and target tile coordinates? For a fixed (m, n, p, q), what is the range of (k, l, u, v)? TODO
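Below is one possible way to compute these ranges, sketched under one assumption: each shard is a contiguous block of ceildiv-sized pieces, and the last piece may be shorter or empty. Along each tensor axis, the tile is the intersection of the source and target index intervals. The tile is therefore a rectangle, and its corners convert to local coordinates by subtracting each shard's offset. `span` and `tile` are hypothetical helper names.

```python
def ceildiv(a, b):
    return -(-a // b)


def span(block, size, extent):
    """Global [start, end) of piece `block` when an axis of length `extent` is cut into pieces of `size`."""
    return block * size, min((block + 1) * size, extent)


def tile(m, n, p, q, K, L, M, N):
    """Tile that d1(m, n) (sharding [[0], [1]]) must send to d2(p, q) (sharding [[1], [0]]).

    Returns (((k0, k1), (l0, l1)), ((u0, u1), (v0, v1))): half-open ranges in
    d1-local and d2-local coordinates, or None when the two shards do not overlap.
    """
    sk1, sl1 = ceildiv(K, M), ceildiv(L, N)  # shard shape under [[0], [1]]
    sk2, sl2 = ceildiv(K, N), ceildiv(L, M)  # shard shape under [[1], [0]]
    src_rows, src_cols = span(m, sk1, K), span(n, sl1, L)
    dst_rows, dst_cols = span(q, sk2, K), span(p, sl2, L)
    r0, r1 = max(src_rows[0], dst_rows[0]), min(src_rows[1], dst_rows[1])
    c0, c1 = max(src_cols[0], dst_cols[0]), min(src_cols[1], dst_cols[1])
    if r0 >= r1 or c0 >= c1:
        return None
    src = ((r0 - m * sk1, r1 - m * sk1), (c0 - n * sl1, c1 - n * sl1))
    dst = ((r0 - q * sk2, r1 - q * sk2), (c0 - p * sl2, c1 - p * sl2))
    return src, dst


# 6x6 tensor on the 2x3 mesh: row 2, columns 0..1 of d1(0, 0) (31 32)
# become row 0, columns 0..1 of d2(0, 1).
print(tile(0, 0, 0, 1, K=6, L=6, M=2, N=3))  # (((2, 3), (0, 2)), ((0, 1), (0, 2)))
```

Because the source and target shards each partition the tensor, the non-empty tiles over all (m, n, p, q) cover every element exactly once.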
Reshard KxL tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x3 mesh.

Device placement on a 2x3 mesh:

11 12 13  <- devices
21 22 23

sharding [[0], [1]]

tensor axis 1
----------->
+----+----+----+ tensor axis 0 |
| 11 | 12 | 13 |               |
+----+----+----+               |
| 21 | 22 | 23 |               |
+----+----+----+               ↓
Transform into sharding [[1], [0]]

tensor axis 1
----------->
+----+----+ tensor axis 0 |
| 11 | 21 |               |
+----+----+               |
| 12 | 22 |               |
+----+----+               |
| 13 | 23 |               |
+----+----+               ↓
Source sharding [[0], [1]], drawn on the KxL tensor:

+-----------------+--------+--------+-----------------+
|                 |                 |                 |
+                 +                 +                 +
|       11        |       12        |       13        |
+                 +                 +                 +
|                 |                 |                 |
+-----------------+--------+--------+-----------------+
|                 |                 |                 |
+                 +                 +                 +
|       21        |       22        |       23        |
+                 +                 +                 +
|                 |                 |                 |
+-----------------+--------+--------+-----------------+

Target sharding [[1], [0]], drawn on the same tensor:

+-----------------+--------+--------+-----------------+
|                          |                          |
+            11            +            21            +
|                          |                          |
+-----------------+--------+--------+-----------------+
|                          |                          |
+            12            +            22            +
|                          |                          |
+-----------------+--------+--------+-----------------+
|                          |                          |
+            13            +            23            +
|                          |                          |
+-----------------+--------+--------+-----------------+

Overlay of both. Each tile is labeled with its source device followed by its target device:

+-----------------+--------+--------+-----------------+
|                 |        |        |                 |
+      11 11      + 12 11  + 12 21  +      13 21      +
|                 |        |        |                 |
+-----------------+--------+--------+-----------------+
|      11 12      | 12 12  | 12 22  |      13 22      |
+-----------------+--------+--------+-----------------+
|      21 12      | 22 12  | 22 22  |      23 22      |
+-----------------+--------+--------+-----------------+
|                 |        |        |                 |
+      21 13      + 22 13  + 22 23  +      23 23      +
|                 |        |        |                 |
+-----------------+--------+--------+-----------------+
Let S and T be the source and target shard sizes along some tensor axis. The cut pattern along that axis has a period of (S*T)/gcd(S, T), i.e. lcm(S, T). After that many elements, the pattern of cuts repeats. TODO
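Below is a small sketch of the cut pattern along one axis, assuming evenly spaced cuts every S and every T elements. `cut_points` and `period` are hypothetical helpers. In the overlay above, the source cuts tensor axis 1 into thirds and the target cuts it into halves, so on a 6-element axis S = 2 and T = 3.

```python
from math import gcd


def cut_points(extent, S, T):
    """Positions along one tensor axis where the source (shard size S) or the
    target (shard size T) sharding starts a new shard, plus the axis end."""
    cuts = set(range(0, extent, S)) | set(range(0, extent, T))
    return sorted(cuts | {extent})


def period(S, T):
    return S * T // gcd(S, T)  # lcm(S, T)


S, T = 2, 3
print(cut_points(6, S, T))  # [0, 2, 3, 4, 6]
print(period(S, T))         # 6

# On a longer axis, every period contains the same relative cut positions.
cuts, P = cut_points(24, S, T), period(S, T)
pattern = [c for c in cuts if c < P]
for start in range(0, 24, P):
    assert [c - start for c in cuts if start <= c < start + P] == pattern
```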
Reshard 6x6 tensor from sharding [[0], []] to sharding [[], [0]] on a 3 mesh.

unsharded 6x6 tensor

11 12 13 14 15 16
21 22 23 24 25 26
31 32 33 34 35 36
41 42 43 44 45 46
51 52 53 54 55 56
61 62 63 64 65 66

sharded on a 3 mesh

sharding = [[0], []]

+-------------------+ mesh axis 0 |
| 11 12 13 14 15 16 |             |
| 21 22 23 24 25 26 |             |
+-------------------+             |
| 31 32 33 34 35 36 |             |
| 41 42 43 44 45 46 |             |
+-------------------+             |
| 51 52 53 54 55 56 |             |
| 61 62 63 64 65 66 |             |
+-------------------+             ↓
Transform into sharding = [[], [0]]

mesh axis 0
----------->
+-------+-------+-------+
| 11 12 | 13 14 | 15 16 |
| 21 22 | 23 24 | 25 26 |
| 31 32 | 33 34 | 35 36 |
| 41 42 | 43 44 | 45 46 |
| 51 52 | 53 54 | 55 56 |
| 61 62 | 63 64 | 65 66 |
+-------+-------+-------+
Algorithm:
%1 = all_to_all %0 on @mesh mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<2x6xi8> -> tensor<6x2xi8>
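Below is a minimal Python simulation on nested lists that shows what the all_to_all does to the data. `all_to_all` here is a hypothetical stand-in that follows the split_axis / concat_axis description. It is not the MLIR implementation. Each device splits its shard along split_axis and sends piece h to device h. Each device then concatenates what it receives, in sender order, along concat_axis.

```python
def split(x, axis, n):
    """Split a 2-D list of rows into n equal pieces along `axis`."""
    if axis == 0:
        size = len(x) // n
        return [x[h * size:(h + 1) * size] for h in range(n)]
    size = len(x[0]) // n
    return [[row[h * size:(h + 1) * size] for row in x] for h in range(n)]


def concat(pieces, axis):
    """Concatenate 2-D lists of rows along `axis`."""
    if axis == 0:
        return [row for piece in pieces for row in piece]
    return [sum((piece[r] for piece in pieces), []) for r in range(len(pieces[0]))]


def all_to_all(shards, split_axis, concat_axis):
    """Device g sends piece h of its shard to device h; device h concatenates in sender order."""
    n = len(shards)
    pieces = [split(s, split_axis, n) for s in shards]  # pieces[sender][receiver]
    return [concat([pieces[g][h] for g in range(n)], concat_axis) for h in range(n)]


t = [[10 * i + j for j in range(1, 7)] for i in range(1, 7)]  # the 6x6 tensor above
src = [t[2 * d:2 * d + 2] for d in range(3)]                   # [[0], []]: 2x6 per device
dst = all_to_all(src, split_axis=1, concat_axis=0)             # 6x2 per device
assert dst == [[row[2 * d:2 * d + 2] for row in t] for d in range(3)]  # [[], [0]]
print(dst[1][0], dst[1][-1])  # [13, 14] [63, 64]
```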
Reshard 4x4 tensor from sharding [[0], [1, 2]] to sharding [[0, 1], [2]] on a 2x2x2 mesh.

unsharded 4x4 tensor

11 12 13 14
21 22 23 24
31 32 33 34
41 42 43 44

sharded on a 2x2x2 mesh

sharding = [[0], [1, 2]]

mesh axis 2
----------->
+----+----+ mesh axis 1 | mesh axis 0 |
| 11 | 12 |             |             |
| 21 | 22 |             |             |
+----+----+             |             |
| 13 | 14 |             |             |
| 23 | 24 |             |             |
+----+----+             ↓             |
+----+----+                           |
| 31 | 32 |                           |
| 41 | 42 |                           |
+----+----+                           |
| 33 | 34 |                           |
| 43 | 44 |                           |
+----+----+                           ↓
Transform into sharding = [[0, 1], [2]]

mesh axis 2
----------->
+-------+-------+ mesh axis 1 | mesh axis 0 |
| 11 12 | 13 14 |             |             |
+-------+-------+             |             |
| 21 22 | 23 24 |             |             |
+-------+-------+             ↓             |
+-------+-------+                           |
| 31 32 | 33 34 |                           |
+-------+-------+                           |
| 41 42 | 43 44 |                           |
+-------+-------+                           ↓
Algorithm:

%1 = all_to_all %0 on @mesh mesh_axes = [2] split_axis = 0 concat_axis = 1 : tensor<2x1xi8> -> tensor<1x2xi8>

is not enough. On its own it produces sharding [[0, 2], [1]] instead of [[0, 1], [2]].
It can be decomposed into

[[0], [1, 2]] -> [[0], [2, 1]] -> [[0, 1], [2]]

First swap the order of mesh axes 1 and 2, then do an all-to-all along mesh axis 1.
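The decomposition can be simulated on the 4x4 example. This Python sketch uses hypothetical helpers, stores shards as nested lists, and models all_to_all as three steps: split along split_axis, exchange within the mesh-axis group, and concatenate along concat_axis. It confirms that one all-to-all along mesh axis 2 lands on [[0, 2], [1]]. Reordering the mesh axes first and then running an all-to-all along mesh axis 1 reaches the target.

```python
from itertools import product

MESH = (2, 2, 2)
T = [[10 * i + j for j in range(1, 5)] for i in range(1, 5)]  # the 4x4 tensor above
DEVICES = list(product(*(range(s) for s in MESH)))


def block(dim_size, mesh_axes, dev):
    """[start, end) of a tensor dim split evenly over `mesh_axes` (major to minor)."""
    index, count = 0, 1
    for a in mesh_axes:
        index = index * MESH[a] + dev[a]
        count *= MESH[a]
    size = dim_size // count
    return index * size, (index + 1) * size


def shard(sharding, dev):
    """Local 2-D shard of T held by `dev` under `sharding`."""
    (r0, r1), (c0, c1) = block(len(T), sharding[0], dev), block(len(T[0]), sharding[1], dev)
    return [row[c0:c1] for row in T[r0:r1]]


def split(x, axis, n):
    if axis == 0:
        s = len(x) // n
        return [x[h * s:(h + 1) * s] for h in range(n)]
    s = len(x[0]) // n
    return [[row[h * s:(h + 1) * s] for row in x] for h in range(n)]


def concat(pieces, axis):
    if axis == 0:
        return [row for p in pieces for row in p]
    return [sum((p[r] for p in pieces), []) for r in range(len(pieces[0]))]


def all_to_all(shards, mesh_axis, split_axis, concat_axis):
    """all_to_all inside every group of devices that differ only along `mesh_axis`."""
    n = MESH[mesh_axis]
    out = {}
    for dev in DEVICES:
        senders = [dev[:mesh_axis] + (g,) + dev[mesh_axis + 1:] for g in range(n)]
        received = [split(shards[s], split_axis, n)[dev[mesh_axis]] for s in senders]
        out[dev] = concat(received, concat_axis)
    return out


src = {d: shard([[0], [1, 2]], d) for d in DEVICES}
target = {d: shard([[0, 1], [2]], d) for d in DEVICES}

# A single all_to_all along mesh axis 2 lands on [[0, 2], [1]], not on the target.
direct = all_to_all(src, 2, split_axis=0, concat_axis=1)
assert direct == {d: shard([[0, 2], [1]], d) for d in DEVICES}
assert direct != target

# Step 1: [[0], [1, 2]] -> [[0], [2, 1]]. Moving each shard to the device with the same
# linear index swaps the coordinates on mesh axes 1 and 2 (both have size 2 here).
permuted = {(a, b, c): src[(a, c, b)] for (a, b, c) in DEVICES}
assert permuted == {d: shard([[0], [2, 1]], d) for d in DEVICES}

# Step 2: all_to_all along mesh axis 1 reaches [[0, 1], [2]].
assert all_to_all(permuted, 1, split_axis=0, concat_axis=1) == target
```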
Every resharding can be decomposed into a sequence of basis reshardings. This is not communication-efficient, because it does not minimize the data exchanged between devices. An efficient approach, in which each device receives at most as much data as the size of its target shard, would be more complicated to implement.
Basis:

From replicate to split.
[[]] -> [[1]]
Extract slices without communication.

From split to replicate.
[[0]] -> [[]]
[[0, 1]] -> [[1]]
All-gather along mesh axis 0.

Swap the order of mesh axes assigned to the same tensor axis.
[[0, 1]] -> [[1, 0]]
Swap contents on devices with the same linear index.

Move a mesh axis to a different tensor dimension.
[[0], []] -> [[], [0]]
All-to-all.
Example decomposition of

[[0], [1]] -> [[1], [0]]

into

[[0], [1]]
-> all-gather along mesh axis 1 -> [[0], []]
-> all-to-all along mesh axis 0 -> [[], [0]]
-> extract slice along mesh axis 1 -> [[1], [0]]

Example decomposition of

[[3, 2], [], [0, 1]] -> [[0], [1, 2], []]

into

[[3, 2], [], [0, 1]]
-> all-to-all along mesh axis 1 -> [[3, 2], [1], [0]]
-> all-to-all along mesh axis 2 -> [[3], [1, 2], [0]]
-> all-gather along mesh axis 3 -> [[], [1, 2], [0]]
-> all-to-all along mesh axis 0 -> [[0], [1, 2], []]
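Both decompositions can be checked on the sharding specs alone. The sketch below models each basis step as a function on the list of mesh axes per tensor dimension. The function names are assumptions for illustration, and so is the explicit target `dim` argument of `all_to_all`, since the text only names the mesh axis. A step raises an error if its precondition does not hold. For example, an all-gather must remove the trailing mesh axis of a tensor dimension.

```python
def all_gather(sharding, mesh_axis):
    """Split -> replicate: drop `mesh_axis`, which must be the trailing mesh axis of some tensor dim."""
    out = [list(axes) for axes in sharding]
    for axes in out:
        if axes and axes[-1] == mesh_axis:
            axes.pop()
            return out
    raise ValueError(f"mesh axis {mesh_axis} is not trailing in {sharding}")


def extract_slice(sharding, mesh_axis, dim):
    """Replicate -> split: append an unused `mesh_axis` to tensor dim `dim`."""
    if any(mesh_axis in axes for axes in sharding):
        raise ValueError(f"mesh axis {mesh_axis} is already used in {sharding}")
    out = [list(axes) for axes in sharding]
    out[dim].append(mesh_axis)
    return out


def all_to_all(sharding, mesh_axis, dim):
    """Move trailing `mesh_axis` to the end of tensor dim `dim`. On the spec this equals an
    all-gather followed by an extract-slice; on the data it is a single all-to-all."""
    return extract_slice(all_gather(sharding, mesh_axis), mesh_axis, dim)


s = [[0], [1]]
s = all_gather(s, 1)        # [[0], []]
s = all_to_all(s, 0, 1)     # [[], [0]]
s = extract_slice(s, 1, 0)  # [[1], [0]]
assert s == [[1], [0]]

s = [[3, 2], [], [0, 1]]
s = all_to_all(s, 1, 1)     # [[3, 2], [1], [0]]
s = all_to_all(s, 2, 1)     # [[3], [1, 2], [0]]
s = all_gather(s, 3)        # [[], [1, 2], [0]]
s = all_to_all(s, 0, 0)     # [[0], [1, 2], []]
assert s == [[0], [1, 2], []]
```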