Resharding Spmdization Examples

Reshard 2x3 tensor from sharding [[0, 1]] to sharding [[1, 0]] on a 2x3 mesh.

unsharded 2x3 tensor

11 12 13
21 22 23

sharded on a 2x3 mesh

sharding = [[0, 1]]

mesh contents:

mesh axis 1
----------->
+----+----+----+ mesh axis 0 |
| 11 | 12 | 13 |             |
+----+----+----+             |
| 21 | 22 | 23 |             |
+----+----+----+             ↓

Transform into sharding = [[1, 0]]

mesh axis 1
----------->
+----+----+----+ mesh axis 0 |
| 11 | 13 | 22 |             |
+----+----+----+             |
| 12 | 21 | 23 |             |
+----+----+----+             ↓

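Algorithm: Permute the shards between devices. Each device receives the shard held by the device whose linear index in the source sharding equals its own linear index in the target sharding.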
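A minimal Python sketch of this step. It assumes, as the pictures suggest, that a device's shard number is the mixed-radix linear index of its mesh coordinates over the listed mesh axes, with the first listed axis most significant. The helpers `linear_index` and `reorder_axes` are hypothetical names used only for illustration.

```python
from itertools import product

MESH = (2, 3)  # sizes of mesh axis 0 and mesh axis 1

def linear_index(device, axes, mesh=MESH):
    """Mixed-radix index of `device` over `axes`; the first listed axis is most significant."""
    index = 0
    for axis in axes:
        index = index * mesh[axis] + device[axis]
    return index

def reorder_axes(contents, src_axes, dst_axes, mesh=MESH):
    """Move shards so that device d ends up with shard number linear_index(d, dst_axes)."""
    devices = list(product(*(range(n) for n in mesh)))
    # Which source device holds each shard number.
    holder = {linear_index(d, src_axes, mesh): d for d in devices}
    return {d: contents[holder[linear_index(d, dst_axes, mesh)]] for d in devices}

# Source sharding [[0, 1]]: device (m0, m1) holds element number m0 * 3 + m1.
elements = [11, 12, 13, 21, 22, 23]
src = {d: elements[linear_index(d, [0, 1])] for d in product(range(2), range(3))}
dst = reorder_axes(src, [0, 1], [1, 0])
print([[dst[(m0, m1)] for m1 in range(3)] for m0 in range(2)])
# [[11, 13, 22], [12, 21, 23]]
```

Here device (0, 1) takes its new element from (0, 2), which takes its element from (1, 1), and so on, so in general the exchange is a permutation rather than a set of pairwise swaps.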


Reshard 2x3 tensor from sharding [[0, 1]] to sharding [[1]] on a 2x3 mesh.

unsharded 2x3 tensor

11 12 13
21 22 23

sharded on a 2x3 mesh

sharding = [[0, 1]]

mesh contents:

mesh axis 1
----------->
+----+----+----+ mesh axis 0 |
| 11 | 12 | 13 |             |
+----+----+----+             |
| 21 | 22 | 23 |             |
+----+----+----+             ↓

Transform into sharding = [[1]]

mesh axis 1
----------->
+----+----+----+ mesh axis 0 |
| 11 | 12 | 13 |             |
| 21 | 22 | 23 |             |
+----+----+----+             |
| 11 | 12 | 13 |             |
| 21 | 22 | 23 |             |
+----+----+----+             ↓

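Algorithm: All-gather along mesh axis 0.

Note that the elements gathered on each device (for example 11 and 21) are not adjacent in the element order 11 12 13 21 22 23. Compare the 4x8 example below: removing a mesh axis that is not the last one in the list takes more than a single all-gather.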


Reshard 2x6 tensor from sharding [[], [0, 1]] to sharding [[], [0]] on a 2x3 mesh.

unsharded 2x6 tensor

11 12 13 14 15 16
21 22 23 24 25 26

sharded on a 2x3 mesh

sharding = [[], [0, 1]]

mesh contents:

mesh axis 1
----------->
+----+----+----+ mesh axis 0 |
| 11 | 12 | 13 |             |
| 21 | 22 | 23 |             |
+----+----+----+             |
| 14 | 15 | 16 |             |
| 24 | 25 | 26 |             |
+----+----+----+             ↓

Transform into sharding = [[], [0]]

mesh axis 1
----------->
+----------+----------+ mesh axis 0 |
| 11 12 13 | 11 12 13 |             |
| 21 22 23 | 21 22 23 |             |
+----------+----------+             |
| 14 15 16 | 14 15 16 |             |
| 24 25 26 | 24 25 26 |             |
+----------+----------+             ↓

Algorithm: All-gather along mesh axis 1.
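The all-gather can be checked with a small Python sketch. It assumes the same shard numbering as the picture (device (m0, m1) holds column m0 * 3 + m1) and plain nested lists for tensors; `cols` and `hconcat` are hypothetical helpers.

```python
tensor = [[11, 12, 13, 14, 15, 16],
          [21, 22, 23, 24, 25, 26]]

def cols(t, start, stop):
    """Columns [start, stop) of a 2-D list."""
    return [row[start:stop] for row in t]

def hconcat(blocks):
    """Concatenate 2-D blocks with the same number of rows along tensor axis 1."""
    return [sum((block[i] for block in blocks), []) for i in range(len(blocks[0]))]

# Source sharding [[], [0, 1]]: device (m0, m1) holds column m0 * 3 + m1.
src = {(m0, m1): cols(tensor, m0 * 3 + m1, m0 * 3 + m1 + 1)
       for m0 in range(2) for m1 in range(3)}

# All-gather along mesh axis 1: each device concatenates the blocks of the
# devices that share its mesh axis 0 coordinate, in mesh axis 1 order.
dst = {(m0, m1): hconcat([src[(m0, k)] for k in range(3)]) for (m0, m1) in src}

print(dst[(1, 2)])  # [[14, 15, 16], [24, 25, 26]]
```

This works because mesh axis 1 is the last, least significant axis in [0, 1]: the three devices that differ only in mesh axis 1 hold adjacent columns.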


Reshard 4x8 tensor from sharding [[0], [1, 2]] to sharding [[0], [2]] on a 2x2x2 mesh.

unsharded 4x8 tensor

11 12 13 14 15 16 17 18
21 22 23 24 25 26 27 28
31 32 33 34 35 36 37 38
41 42 43 44 45 46 47 48

sharded on a 2x2x2 mesh

sharding = [[0], [1, 2]]

mesh contents:

mesh axis 2
----------->
+-------+-------+ mesh axis 1 | mesh axis 0 |
| 11 12 | 13 14 |             |             |
| 21 22 | 23 24 |             |             |
+-------+-------+             |             |
| 15 16 | 17 18 |             |             |
| 25 26 | 27 28 |             |             |
+-------+-------+             ↓             |
+-------+-------+                           |
| 31 32 | 33 34 |                           |
| 41 42 | 43 44 |                           |
+-------+-------+                           |
| 35 36 | 37 38 |                           |
| 45 46 | 47 48 |                           |
+-------+-------+                           ↓

Transform into sharding = [[0], [2]]

mesh axis 2
----------->
+-------------+-------------+ mesh axis 1 | mesh axis 0 |
| 11 12 13 14 | 15 16 17 18 |             |             |
| 21 22 23 24 | 25 26 27 28 |             |             |
+-------------+-------------+             |             |
| 11 12 13 14 | 15 16 17 18 |             |             |
| 21 22 23 24 | 25 26 27 28 |             |             |
+-------------+-------------+             ↓             |
+-------------+-------------+                           |
| 31 32 33 34 | 35 36 37 38 |                           |
| 41 42 43 44 | 45 46 47 48 |                           |
+-------------+-------------+                           |
| 31 32 33 34 | 35 36 37 38 |                           |
| 41 42 43 44 | 45 46 47 48 |                           |
+-------------+-------------+                           ↓

Algorithm:

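An all-gather along mesh axis 1 alone is not enough: the two shards it gathers on a device (for example columns 13 14 and 17 18) are not adjacent. It can be handled by a sequence of reshardings, [[0], [1, 2]] -> [[0], [2, 1]] -> [[0], [2]], i.e. first reorder the mesh axes, then all-gather along mesh axis 1.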
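A Python sketch that tracks only the column indices on each device, assuming the mixed-radix shard numbering used in the pictures (first listed mesh axis most significant). It shows that gathering along mesh axis 1 from [[0], [1, 2]] yields non-adjacent columns, while gathering after reordering to [[0], [2, 1]] yields the target shard. `shard_cols` and `all_gather_cols` are hypothetical helpers.

```python
from itertools import product

MESH = (2, 2, 2)  # sizes of mesh axes 0, 1, 2
COLS = 8          # size of tensor axis 1

def shard_cols(device, axes):
    """Columns held by `device` when tensor axis 1 is split over `axes`
    (the first listed mesh axis is the most significant)."""
    count, index = 1, 0
    for axis in axes:
        count *= MESH[axis]
        index = index * MESH[axis] + device[axis]
    size = COLS // count
    return list(range(index * size, (index + 1) * size))

def all_gather_cols(axes, gather_axis, device):
    """Columns on `device` after all-gathering along `gather_axis`, peers in mesh order."""
    result = []
    for k in range(MESH[gather_axis]):
        peer = list(device)
        peer[gather_axis] = k
        result += shard_cols(tuple(peer), axes)
    return result

devices = list(product(range(2), repeat=3))
print(shard_cols((0, 0, 1), [2]))             # target [[0], [2]]: [4, 5, 6, 7]
print(all_gather_cols([1, 2], 1, (0, 0, 1)))  # from [[0], [1, 2]]: [2, 3, 6, 7]
print(all(all_gather_cols([2, 1], 1, d) == shard_cols(d, [2]) for d in devices))  # True
print(all(all_gather_cols([1, 2], 1, d) == shard_cols(d, [2]) for d in devices))  # False
```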


Reshard 6x6 tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x3 mesh.

unsharded 6x6 tensor

11 12 13 14 15 16
21 22 23 24 25 26
31 32 33 34 35 36
41 42 43 44 45 46
51 52 53 54 55 56
61 62 63 64 65 66

sharded on a 2x3 mesh

sharding = [[0], [1]]

mesh axis 1
----------->
+-------+-------+-------+ mesh axis 0 |
| 11 12 | 13 14 | 15 16 |             |
| 21 22 | 23 24 | 25 26 |             |
| 31 32 | 33 34 | 35 36 |             |
+-------+-------+-------+             |
| 41 42 | 43 44 | 45 46 |             |
| 51 52 | 53 54 | 55 56 |             |
| 61 62 | 63 64 | 65 66 |             |
+-------+-------+-------+             ↓

transform to sharding = [[1], [0]]

mesh axis 1
----------->
+----------+----------+----------+ mesh axis 0 |
| 11 12 13 | 31 32 33 | 51 52 53 |             |
| 21 22 23 | 41 42 43 | 61 62 63 |             |
+----------+----------+----------+             |
| 14 15 16 | 34 35 36 | 54 55 56 |             |
| 24 25 26 | 44 45 46 | 64 65 66 |             |
+----------+----------+----------+             ↓

The same target layout, drawn with mesh axis 0 horizontal and mesh axis 1 vertical:

mesh axis 0
----------->
+----------+----------+ mesh axis 1 |
| 11 12 13 | 14 15 16 |             |
| 21 22 23 | 24 25 26 |             |
+----------+----------+             |
| 31 32 33 | 34 35 36 |             |
| 41 42 43 | 44 45 46 |             |
+----------+----------+             |
| 51 52 53 | 54 55 56 |             |
| 61 62 63 | 64 65 66 |             |
+----------+----------+             ↓

Algorithm: TODO
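The algorithm is left open here, but the source and target shard of any device can be computed directly, which is useful for checking a candidate algorithm. A Python sketch, assuming even division and the mixed-radix shard numbering used in the pictures; `shard` is a hypothetical helper, not an MLIR API.

```python
from math import prod

def shard(tensor, mesh, sharding, device):
    """Block of the 2-D list `tensor` held by `device` under `sharding`.

    sharding[t] lists the mesh axes that split tensor axis t; the first listed
    mesh axis is the most significant digit of the shard index. Assumes every
    tensor dimension is divisible by its number of shards.
    """
    bounds = []
    for size, axes in zip((len(tensor), len(tensor[0])), sharding):
        index = 0
        for axis in axes:
            index = index * mesh[axis] + device[axis]
        block = size // prod(mesh[axis] for axis in axes)
        bounds.append((index * block, (index + 1) * block))
    (r0, r1), (c0, c1) = bounds
    return [row[c0:c1] for row in tensor[r0:r1]]

tensor = [[10 * i + j for j in range(1, 7)] for i in range(1, 7)]  # entries 11..66
mesh = (2, 3)
print(shard(tensor, mesh, [[0], [1]], (1, 2)))  # [[45, 46], [55, 56], [65, 66]]
print(shard(tensor, mesh, [[1], [0]], (1, 2)))  # [[54, 55, 56], [64, 65, 66]]
```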


Reshard 6x6 tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x6 mesh.

unsharded 6x6 tensor

11 12 13 14 15 16
21 22 23 24 25 26
31 32 33 34 35 36
41 42 43 44 45 46
51 52 53 54 55 56
61 62 63 64 65 66

sharded on a 2x6 mesh

sharding = [[0], [1]]

mesh axis 1
----------->
+----+----+----+----+----+----+ mesh axis 0 |
| 11 | 12 | 13 ‖ 14 | 15 | 16 |             |
| 21 | 22 | 23 ‖ 24 | 25 | 26 |             |
| 31 | 32 | 33 ‖ 34 | 35 | 36 |             |
+----+----+----+----+----+----+             |
| 41 | 42 | 43 ‖ 44 | 45 | 46 |             |
| 51 | 52 | 53 ‖ 54 | 55 | 56 |             |
| 61 | 62 | 63 ‖ 64 | 65 | 66 |             |
+----+----+----+----+----+----+             ↓

transform to sharding = [[1], [0]]

mesh axis 0
----------->
+----------+----------+ mesh axis 1 |
| 11 12 13 | 14 15 16 |             |
+----------+----------+             |
| 21 22 23 | 24 25 26 |             |
+----------+----------+             |
| 31 32 33 | 34 35 36 |             |
+==========+==========+             |
| 41 42 43 | 44 45 46 |             |
+----------+----------+             |
| 51 52 53 | 54 55 56 |             |
+----------+----------+             |
| 61 62 63 | 64 65 66 |             |
+----------+----------+             ↓

Algorithm: TODO


Reshard KxL tensor from sharding [[0], [1]] to sharding [[1], [0]] on an MxN mesh.

Let the mesh be M x N, let t be the K x L tensor, and let d(m, n) be the tensor on device (m, n).

sharding = [[0], [1]]: the tensor shard on each device has size (K ceildiv M, L ceildiv N).

d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]

substitute

i <- m * (K ceildiv M) + k
j <- n * (L ceildiv N) + l
m -> i floordiv (K ceildiv M)
n -> j floordiv (L ceildiv N)
k -> i % (K ceildiv M)
l -> j % (L ceildiv N)

For the inverse map we get

t[i, j] -> d(
    i floordiv (K ceildiv M), j floordiv (L ceildiv N)
)[
    i % (K ceildiv M), j % (L ceildiv N)
]

Check:

i = 13, j = 17, M = 3, N = 4, K = 16, L = 23
t[13, 17] = d(
    13 floordiv (16 ceildiv 3),
    17 floordiv (23 ceildiv 4)
)[
    13 % (16 ceildiv 3),
    17 % (23 ceildiv 4)
]
= d(
    13 floordiv 6,
    17 floordiv 6
)[
    13 % 6,
    17 % 6
]
= d(2, 2)[1, 5]
= t[
    2 * (16 ceildiv 3) + 1,
    2 * (23 ceildiv 4) + 5
]
= t[
    2 * 6 + 1,
    2 * 6 + 5
]
= t[13, 17]
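The two maps and the check above can be reproduced in Python. This sketch uses `//` for floordiv, `%` for the remainder and `-(-a // b)` for ceildiv; the function names are hypothetical. The last line verifies the round trip for every element of the 16x23 tensor.

```python
def ceildiv(a, b):
    return -(-a // b)

def to_tensor(m, n, k, l, M, N, K, L):
    """d(m, n)[k, l] -> (i, j) for sharding [[0], [1]]."""
    return m * ceildiv(K, M) + k, n * ceildiv(L, N) + l

def to_device(i, j, M, N, K, L):
    """(i, j) -> (m, n, k, l) with t[i, j] = d(m, n)[k, l] for sharding [[0], [1]]."""
    rows, cols = ceildiv(K, M), ceildiv(L, N)  # shard size
    return i // rows, j // cols, i % rows, j % cols

M, N, K, L = 3, 4, 16, 23
print(to_device(13, 17, M, N, K, L))      # (2, 2, 1, 5)
print(to_tensor(2, 2, 1, 5, M, N, K, L))  # (13, 17)
print(all(to_tensor(*to_device(i, j, M, N, K, L), M, N, K, L) == (i, j)
          for i in range(K) for j in range(L)))  # True
```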

sharding = [[1], [0]]: the tensor shard on each device has size (K ceildiv N, L ceildiv M).

d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]

substitute

i <- n * (K ceildiv N) + k
j <- m * (L ceildiv M) + l
m -> j floordiv (L ceildiv M)
n -> i floordiv (K ceildiv N)
k -> i % (K ceildiv N)
l -> j % (L ceildiv M)

For the inverse map we get

t[i, j] -> d(
    j floordiv (L ceildiv M), i floordiv (K ceildiv N)
)[
    i % (K ceildiv N), j % (L ceildiv M)
]

Check:

i = 9, j = 13, M = 5, N = 2, K = 27, L = 14
t[9, 13] = d(
    13 floordiv (14 ceildiv 5),
    9 floordiv (27 ceildiv 2)
)[
    9 % (27 ceildiv 2),
    13 % (14 ceildiv 5)
]
= d(
    13 floordiv 3,
    9 floordiv 14
)[
    9 % 14,
    13 % 3
]
= d(4, 0)[9, 1]
= t[
    0 * (27 ceildiv 2) + 9,
    4 * (14 ceildiv 5) + 1
]
= t[
    0 * 14 + 9,
    4 * 3 + 1
]
= t[9, 13]
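The same kind of Python sketch for sharding [[1], [0]], with M = 5, N = 2, K = 27, L = 14 as in the check above (function names are hypothetical). Note the swapped roles: the device coordinate along mesh axis 0 comes from the column index j. The final check confirms that every in-range index (i < K, j < L) round-trips and lands on a device inside the 5x2 mesh; an index with j >= L would not.

```python
def ceildiv(a, b):
    return -(-a // b)

def to_tensor(m, n, k, l, M, N, K, L):
    """d(m, n)[k, l] -> (i, j) for sharding [[1], [0]]."""
    return n * ceildiv(K, N) + k, m * ceildiv(L, M) + l

def to_device(i, j, M, N, K, L):
    """(i, j) -> (m, n, k, l) with t[i, j] = d(m, n)[k, l] for sharding [[1], [0]]."""
    rows, cols = ceildiv(K, N), ceildiv(L, M)  # shard size
    return j // cols, i // rows, i % rows, j % cols

M, N, K, L = 5, 2, 27, 14
print(to_device(9, 13, M, N, K, L))       # (4, 0, 9, 1)
print(to_tensor(4, 0, 9, 1, M, N, K, L))  # (9, 13)

# Every in-range index round-trips and lands on a device inside the 5x2 mesh.
ok = all(
    to_tensor(*to_device(i, j, M, N, K, L), M, N, K, L) == (i, j)
    and to_device(i, j, M, N, K, L)[0] < M
    and to_device(i, j, M, N, K, L)[1] < N
    for i in range(K) for j in range(L)
)
print(ok)  # True
```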

sharding = [[0], [1]]

d(m, n)[k, l] -> t[m * (K ceildiv M) + k, n * (L ceildiv N) + l]
t[i, j] -> d(i floordiv (K ceildiv M), j floordiv (L ceildiv N))[i % (K ceildiv M), j % (L ceildiv N)]

sharding = [[1], [0]]

d(m, n)[k, l] -> t[n * (K ceildiv N) + k, m * (L ceildiv M) + l]
t[i, j] -> d(j floordiv (L ceildiv M), i floordiv (K ceildiv N))[i % (K ceildiv N), j % (L ceildiv M)]

For the resharding [[0], [1]] -> [[1], [0]], let d1(m, n) be the tensor on device (m, n) under sharding [[0], [1]], and d2(m, n) the tensor on device (m, n) under sharding [[1], [0]].

d1(m, n)[k, l] ->
t[m * (K ceildiv M) + k, n * (L ceildiv N) + l] ->
d2(
    (n * (L ceildiv N) + l) floordiv (L ceildiv M),
    (m * (K ceildiv M) + k) floordiv (K ceildiv N)
)[
    (m * (K ceildiv M) + k) % (K ceildiv N),
    (n * (L ceildiv N) + l) % (L ceildiv M)
]
= d2(p, q)[u, v]
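The composed map can be checked numerically. This Python sketch substitutes the forward map of [[0], [1]] into the inverse map of [[1], [0]] from the summary above, then compares elements on the 6x6 tensor and 2x3 mesh from the earlier example; `reshard_index`, `d1` and `d2` are hypothetical helpers.

```python
def ceildiv(a, b):
    return -(-a // b)

def reshard_index(m, n, k, l, M, N, K, L):
    """Map d1(m, n)[k, l] under [[0], [1]] to (p, q, u, v) with d2(p, q)[u, v] under [[1], [0]]."""
    i = m * ceildiv(K, M) + k                   # global row (source forward map)
    j = n * ceildiv(L, N) + l                   # global column (source forward map)
    rows, cols = ceildiv(K, N), ceildiv(L, M)   # target shard size
    return j // cols, i // rows, i % rows, j % cols  # target inverse map

# 6x6 tensor with entries 11..66 on a 2x3 mesh.
M, N, K, L = 2, 3, 6, 6
t = [[10 * i + j for j in range(1, 7)] for i in range(1, 7)]

def d1(m, n, k, l):  # element [k, l] of the 3x2 source shard on device (m, n)
    return t[m * 3 + k][n * 2 + l]

def d2(p, q, u, v):  # element [u, v] of the 2x3 target shard on device (p, q)
    return t[q * 2 + u][p * 3 + v]

print(reshard_index(1, 2, 0, 1, M, N, K, L))  # (1, 1, 1, 2)
print(all(d1(m, n, k, l) == d2(*reshard_index(m, n, k, l, M, N, K, L))
          for m in range(M) for n in range(N) for k in range(3) for l in range(2)))  # True
```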

We want to copy the data between devices in slices/tiles. What are the source and target tile coordinates? For a fixed (m, n, p, q), what is the range of (k, l, u, v)? TODO
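One possible answer uses interval intersection: the tile sent from d1(m, n) to d2(p, q) is the intersection of the global rows and columns the two shards cover. A Python sketch of that approach (not taken from the source; names are hypothetical), returning local coordinates on both sides. Clamping with `min(..., K)` and `min(..., L)` handles a shorter last shard when the sizes don't divide evenly.

```python
def ceildiv(a, b):
    return -(-a // b)

def overlap(lo1, hi1, lo2, hi2):
    """Intersection of the half-open intervals [lo1, hi1) and [lo2, hi2), or None."""
    lo, hi = max(lo1, lo2), min(hi1, hi2)
    return (lo, hi) if lo < hi else None

def tile(m, n, p, q, M, N, K, L):
    """Part of d1(m, n) (sharding [[0], [1]]) that goes to d2(p, q) (sharding [[1], [0]]).

    Returns (((k0, k1), (l0, l1)), ((u0, u1), (v0, v1))): half-open ranges in the
    local coordinates of d1 and of d2, or None if nothing is sent.
    """
    sr, sc = ceildiv(K, M), ceildiv(L, N)  # source shard rows, columns
    tr, tc = ceildiv(K, N), ceildiv(L, M)  # target shard rows, columns
    rows = overlap(m * sr, min((m + 1) * sr, K), q * tr, min((q + 1) * tr, K))
    cols = overlap(n * sc, min((n + 1) * sc, L), p * tc, min((p + 1) * tc, L))
    if rows is None or cols is None:
        return None
    (i0, i1), (j0, j1) = rows, cols
    src = ((i0 - m * sr, i1 - m * sr), (j0 - n * sc, j1 - n * sc))
    dst = ((i0 - q * tr, i1 - q * tr), (j0 - p * tc, j1 - p * tc))
    return src, dst

# 6x6 tensor on a 2x3 mesh: element t[2, 2] is d1(0, 1)[2, 0] and d2(0, 1)[0, 2].
print(tile(0, 1, 0, 1, M=2, N=3, K=6, L=6))  # (((2, 3), (0, 1)), ((0, 1), (2, 3)))

# Every element of d1(0, 1) is sent to exactly one target device.
sent = 0
for p in range(2):
    for q in range(3):
        result = tile(0, 1, p, q, M=2, N=3, K=6, L=6)
        if result:
            (k0, k1), (l0, l1) = result[0]
            sent += (k1 - k0) * (l1 - l0)
print(sent)  # 6, the size of the 3x2 source shard
```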


Reshard KxL tensor from sharding [[0], [1]] to sharding [[1], [0]] on a 2x3 mesh.

Device placement on a 2x3 mesh

11 12 13  <- devices
21 22 23

sharding [[0], [1]]

tensor axis 1
----------->
+----+----+----+ tensor axis 0 |
| 11 | 12 | 13 |               |
+----+----+----+               |
| 21 | 22 | 23 |               |
+----+----+----+               ↓

transform to sharding [[1], [0]]

tensor axis 1
----------->
+----+----+ tensor axis 0 |
| 11 | 21 |               |
+----+----+               |
| 12 | 22 |               |
+----+----+               |
| 13 | 23 |               |
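+----+----+               ↓

The same two shardings drawn on a common grid: the source shards, the target shards, and their overlay. Each overlay tile is labelled with its source device followed by its target device.
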
+-----------------+--------+--------+-----------------+
|                 |                 |                 |
+                 +                 +                 +
|       11        |        12       |        13       |
+                 +                 +                 +
|                 |                 |                 |
+-----------------+--------+--------+-----------------+
|                 |                 |                 |
+                 +                 +                 +
|       21        |        22       |        23       |
+                 +                 +                 +
|                 |                 |                 |
+-----------------+--------+--------+-----------------+

+-----------------+--------+--------+-----------------+
|                          |                          |
+          11              +               21         +
|                          |                          |
+-----------------+--------+--------+-----------------+
|                          |                          |
+          12              +               22         +
|                          |                          |
+-----------------+--------+--------+-----------------+
|                          |                          |
+          13              +               23         +
|                          |                          |
+-----------------+--------+--------+-----------------+

+-----------------+--------+--------+-----------------+
|                 |        |        |                 |
+     11  11      + 12  11 + 12  21 +     13  21      +
|                 |        |        |                 |
+-----------------+--------+--------+-----------------+
|     11  12      | 12  12 | 12  22 |     13  22      |
+-----------------+--------+--------+-----------------+
|     21  12      | 22  12 | 22  22 |     23  22      |
+-----------------+--------+--------+-----------------+
|                 |        |        |                 |
+     21  13      + 22  13 + 22  23 +     23  23      +
|                 |        |        |                 |
+-----------------+--------+--------+-----------------+

If S and T are the source and target shard sizes along some tensor axis, the cut pattern along that axis repeats with period (S*T)/gcd(S, T), the least common multiple of S and T. TODO
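A small Python sketch of the cut pattern, listing the positions along one tensor axis where a source or target shard boundary falls (helper names are hypothetical). With source shards of size 3 and target shards of size 2, as for the rows in the grids above, the pattern 0, 2, 3, 4 repeats every 6 elements.

```python
from math import gcd

def cuts(size, S, T):
    """Positions along an axis of length `size` where a source (S) or target (T) shard boundary falls."""
    return sorted(set(range(0, size, S)) | set(range(0, size, T)) | {size})

def period(S, T):
    """Length after which the combined cut pattern repeats: lcm(S, T)."""
    return S * T // gcd(S, T)

print(cuts(6, 3, 2))   # [0, 2, 3, 4, 6]
print(period(3, 2))    # 6
print(cuts(12, 3, 2))  # [0, 2, 3, 4, 6, 8, 9, 10, 12]
```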


Reshard 6x6 tensor from sharding [[0], []] to sharding [[], [0]] on a 3 mesh.

unsharded 6x6 tensor

11 12 13 14 15 16
21 22 23 24 25 26
31 32 33 34 35 36
41 42 43 44 45 46
51 52 53 54 55 56
61 62 63 64 65 66

sharded on a 3 mesh

sharding = [[0], []]

+-------------------+ mesh axis 0 |
| 11 12 13 14 15 16 |             |
| 21 22 23 24 25 26 |             |
+-------------------+             |
| 31 32 33 34 35 36 |             |
| 41 42 43 44 45 46 |             |
+-------------------+             |
| 51 52 53 54 55 56 |             |
| 61 62 63 64 65 66 |             |
+-------------------+             ↓

transform to sharding = [[], [0]]

mesh axis 0
----------->
+-------+-------+-------+
| 11 12 | 13 14 | 15 16 |
| 21 22 | 23 24 | 25 26 |
| 31 32 | 33 34 | 35 36 |
| 41 42 | 43 44 | 45 46 |
| 51 52 | 53 54 | 55 56 |
| 61 62 | 63 64 | 65 66 |
+-------+-------+-------+

Algorithm:

%1 = all_to_all %0 on @mesh mesh_axes = [0] split_axis = 1 concat_axis = 0 : tensor<2x6xi8> -> tensor<6x2xi8>
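A Python sketch of what this all-to-all does on plain nested lists. The semantics are inferred from the tensor types in the op above (split_axis is the axis that shrinks, concat_axis the one that grows) and are an assumption, not the MLIR implementation; `all_to_all` here is a hypothetical helper.

```python
# Unsharded 6x6 tensor with entries 11..66.
tensor = [[10 * i + j for j in range(1, 7)] for i in range(1, 7)]
DEVICES = 3

# Sharding [[0], []]: device r holds rows 2r and 2r + 1 (a 2x6 block).
src = [tensor[2 * r:2 * r + 2] for r in range(DEVICES)]

def all_to_all(blocks, split_axis, concat_axis):
    """Device r splits its block into len(blocks) equal pieces along split_axis,
    sends piece j to device j, and concatenates the pieces it receives, in
    sender order, along concat_axis."""
    n = len(blocks)

    def piece(block, j):
        if split_axis == 0:
            step = len(block) // n
            return block[j * step:(j + 1) * step]
        step = len(block[0]) // n
        return [row[j * step:(j + 1) * step] for row in block]

    result = []
    for r in range(n):
        received = [piece(blocks[sender], r) for sender in range(n)]
        if concat_axis == 0:
            result.append([row for p in received for row in p])
        else:
            result.append([sum((p[i] for p in received), []) for i in range(len(received[0]))])
    return result

dst = all_to_all(src, split_axis=1, concat_axis=0)  # tensor<2x6> -> tensor<6x2> per device
print(dst[1])  # [[13, 14], [23, 24], [33, 34], [43, 44], [53, 54], [63, 64]]
```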

Reshard 4x4 tensor from sharding [[0], [1, 2]] to sharding [[0, 1], [2]] on a 2x2x2 mesh.

unsharded 4x4 tensor

11 12 13 14
21 22 23 24
31 32 33 34
41 42 43 44

sharded on a 2x2x2 mesh

sharding = [[0], [1, 2]]

mesh axis 2
----------->
+----+----+ mesh axis 1 | mesh axis 0 |
| 11 | 12 |             |             |
| 21 | 22 |             |             |
+----+----+             |             |
| 13 | 14 |             |             |
| 23 | 24 |             |             |
+----+----+             ↓             |
+----+----+                           |
| 31 | 32 |                           |
| 41 | 42 |                           |
+----+----+                           |
| 33 | 34 |                           |
| 43 | 44 |                           |
+----+----+                           ↓

transform to sharding = [[0, 1], [2]]

mesh axis 2
----------->
+-------+-------+ mesh axis 1 | mesh axis 0 |
| 11 12 | 13 14 |             |             |
+-------+-------+             |             |
| 21 22 | 23 24 |             |             |
+-------+-------+             ↓             |
+-------+-------+                           |
| 31 32 | 33 34 |                           |
+-------+-------+                           |
| 41 42 | 43 44 |                           |
+-------+-------+                           ↓

Algorithm:

%1 = all_to_all %0 on @mesh mesh_axes = [2] split_axis = 0 concat_axis = 1 : tensor<2x1xi8> -> tensor<1x2xi8>

is not enough.

It can be decomposed into

[[0], [1, 2]] -> [[0], [2, 1]] -> [[0, 1], [2]]
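The two-step decomposition can be simulated in Python, assuming the mixed-radix shard numbering from the pictures and all-to-all semantics in which piece number m1 of each peer goes to device m1. `shard` and `all_to_all_axis1` are hypothetical helpers. Step 1 is the mesh-axis reordering; step 2 is an all-to-all along mesh axis 1 that splits tensor axis 0 and concatenates along tensor axis 1.

```python
from itertools import product
from math import prod

MESH = (2, 2, 2)
tensor = [[10 * i + j for j in range(1, 5)] for i in range(1, 5)]  # entries 11..44
devices = list(product(range(2), repeat=3))

def shard(sharding, device):
    """Block of `tensor` on `device`; the first listed mesh axis is the most significant."""
    bounds = []
    for size, axes in zip((4, 4), sharding):
        index = 0
        for axis in axes:
            index = index * MESH[axis] + device[axis]
        block = size // prod(MESH[axis] for axis in axes)
        bounds.append((index * block, (index + 1) * block))
    (r0, r1), (c0, c1) = bounds
    return [row[c0:c1] for row in tensor[r0:r1]]

def all_to_all_axis1(blocks):
    """All-to-all along mesh axis 1: split tensor axis 0, concatenate along tensor axis 1."""
    out = {}
    for m0, m1, m2 in devices:
        pieces = []
        for k in range(MESH[1]):  # peers along mesh axis 1, in order
            block = blocks[(m0, k, m2)]
            step = len(block) // MESH[1]
            pieces.append(block[m1 * step:(m1 + 1) * step])  # piece m1 is sent to device m1
        out[(m0, m1, m2)] = [sum((p[i] for p in pieces), []) for i in range(len(pieces[0]))]
    return out

state = {d: shard([[0], [1, 2]], d) for d in devices}
# Step 1, [[0], [1, 2]] -> [[0], [2, 1]]: device (m0, m1, m2) takes the shard of (m0, m2, m1).
state = {(m0, m1, m2): state[(m0, m2, m1)] for m0, m1, m2 in devices}
print(all(state[d] == shard([[0], [2, 1]], d) for d in devices))  # True
# Step 2, [[0], [2, 1]] -> [[0, 1], [2]].
state = all_to_all_axis1(state)
print(all(state[d] == shard([[0, 1], [2]], d) for d in devices))  # True
print(state[(0, 0, 1)])  # [[13, 14]]
```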

Decomposition into a basis of reshardings

We can decompose each resharding into a sequence of basis reshardings. This is not communication-efficient: it does not minimize the amount of data communicated between devices. An efficient approach would be more complicated to implement; in it, each device would receive at most as much data as the size of its target shard.


Basis:

  • From replicate to split.

    [[]] -> [[1]]
    

    Extract slices without communication.

  • From split to replicate.

    [[0]] -> [[]]
    [[0, 1]] -> [[0]]

    All-gather along the last mesh axis of the tensor axis (mesh axis 0 in the first case, mesh axis 1 in the second).

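  • Swap the order of mesh axes assigned to the same tensor axis.

    [[0, 1]] -> [[1, 0]]

    Permute shards between devices: each device receives the shard of the device whose linear index in the source sharding equals its own linear index in the target sharding.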

  • Move the last mesh axis of one tensor dimension to the end of another tensor dimension.

    [[0], []] -> [[], [0]]
    

    All-to-all.


Example decomposition of

[[0], [1]] -> [[1], [0]]

into

[[0], [1]] -> all-gather along mesh axis 1    ->
[[0], []]  -> all-to-all along mesh axis 0    ->
[[], [0]]  -> extract slice along mesh axis 1 ->
[[1], [0]]
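The three steps can be simulated in Python on the 6x6 tensor and 2x3 mesh from the earlier example, assuming the same block layout as in the pictures (helper names are hypothetical). The result matches the target layout drawn there.

```python
from itertools import product

M0, M1 = 2, 3  # mesh shape
tensor = [[10 * i + j for j in range(1, 7)] for i in range(1, 7)]  # entries 11..66
devices = list(product(range(M0), range(M1)))

def block(r0, r1, c0, c1):
    """Rows [r0, r1) and columns [c0, c1) of the unsharded tensor."""
    return [row[c0:c1] for row in tensor[r0:r1]]

def hconcat(blocks):
    """Concatenate 2-D blocks with the same number of rows along tensor axis 1."""
    return [sum((b[i] for b in blocks), []) for i in range(len(blocks[0]))]

# Start: [[0], [1]], device (m0, m1) holds a 3x2 block.
s = {(m0, m1): block(3 * m0, 3 * m0 + 3, 2 * m1, 2 * m1 + 2) for m0, m1 in devices}

# 1. All-gather along mesh axis 1 -> [[0], []]: concatenate the column blocks.
s = {(m0, m1): hconcat([s[(m0, k)] for k in range(M1)]) for m0, m1 in devices}

# 2. All-to-all along mesh axis 0 -> [[], [0]]: split tensor axis 1 into M0
#    pieces; piece m0 goes to device m0, which stacks them along tensor axis 0.
s = {(m0, m1): [row[3 * m0:3 * m0 + 3] for k in range(M0) for row in s[(k, m1)]]
     for m0, m1 in devices}

# 3. Extract slice along mesh axis 1 -> [[1], [0]]: keep rows 2 * m1 and 2 * m1 + 1.
s = {(m0, m1): s[(m0, m1)][2 * m1:2 * m1 + 2] for m0, m1 in devices}

expected = {(m0, m1): block(2 * m1, 2 * m1 + 2, 3 * m0, 3 * m0 + 3) for m0, m1 in devices}
print(s == expected)  # True
print(s[(1, 2)])      # [[54, 55, 56], [64, 65, 66]]
```

After step 1 each device holds a 3x6 block (18 elements) although its source and target shards have only 6 elements each, which illustrates why this decomposition is not communication-efficient.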

Example decomposition of

[[3, 2], [], [0, 1]] -> [[0], [1, 2], []]

into

[[3, 2], [], [0, 1]] -> all-to-all along mesh axis 1 ->
[[3, 2], [1], [0]]   -> all-to-all along mesh axis 2 ->
[[3], [1, 2], [0]]   -> all-gather along mesh axis 3 ->
[[], [1, 2], [0]]    -> all-to-all along mesh axis 0 ->
[[0], [1, 2], []]
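A Python sketch that names the basis resharding performed by a single step, so that a chain like the one above can be checked mechanically. It assumes the restricted form the examples use: only the last mesh axis of a tensor dimension is gathered or moved, and a moved axis is appended at the end of its new tensor dimension. `classify` is a hypothetical helper.

```python
def classify(before, after):
    """Name the basis resharding that turns `before` into `after`, or return None."""
    changed = [t for t in range(len(before)) if before[t] != after[t]]
    if len(changed) == 1:
        b, a = before[changed[0]], after[changed[0]]
        if a and a[:-1] == b:
            return f"extract slice along mesh axis {a[-1]}"
        if b and b[:-1] == a:
            return f"all-gather along mesh axis {b[-1]}"
        if sorted(a) == sorted(b):
            return "reorder mesh axes (permute shards)"
    if len(changed) == 2:
        for src, dst in (changed, changed[::-1]):
            moved = before[src][-1:]  # last mesh axis of the source tensor dimension, if any
            if moved and after[src] == before[src][:-1] and after[dst] == before[dst] + moved:
                return f"all-to-all along mesh axis {moved[0]}"
    return None

chain = [[[3, 2], [], [0, 1]], [[3, 2], [1], [0]], [[3], [1, 2], [0]],
         [[], [1, 2], [0]], [[0], [1, 2], []]]
for before, after in zip(chain, chain[1:]):
    print(before, "->", after, ":", classify(before, after))
print(classify([[0], [1, 2]], [[0], [2]]))  # None: not a single basis step
```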