|
1 | 1 | .. role:: hidden |
2 | 2 | :class: hidden-section |
3 | 3 |
|
4 | | -Tensor Parallelism |
5 | | -======================== |
6 | | -.. py:module:: torch.distributed.tensor.parallel |
| 4 | +Tensor Parallelism - torch.distributed.tensor.parallel |
| 5 | +====================================================== |
| 6 | + |
| 7 | +We built Tensor Parallelism (TP) on top of DistributedTensor (DTensor) and
| 8 | +provide several parallelism styles: Rowwise, Colwise, and Pairwise Parallelism.
| 9 | + |
| 10 | +.. warning::
| 11 | + Tensor Parallelism is experimental and subject to change. |
| 12 | +
|
| 13 | +The entrypoint to parallelize your module using Tensor Parallelism is:
| 14 | + |
| 15 | +.. automodule:: torch.distributed.tensor.parallel |
| 16 | + |
7 | 17 | .. currentmodule:: torch.distributed.tensor.parallel |
| 18 | + |
| 19 | +.. autofunction:: parallelize_module |
| 20 | + |
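
As a rough orientation for the entrypoint above, here is a minimal sketch that parallelizes a toy two-layer MLP with ``PairwiseParallel``. The module name ``MLP`` and its ``net1``/``net2`` layers are made up for illustration; the sketch assumes the experimental ``DeviceMesh`` lives in ``torch.distributed._tensor``, that the default process group is already initialized (e.g. via ``torchrun``) with the right CUDA device selected per rank, and that feature dimensions are divisible by the number of TP ranks::

    import torch
    import torch.distributed as dist
    import torch.nn as nn
    from torch.distributed._tensor import DeviceMesh  # assumed location of the experimental DeviceMesh
    from torch.distributed.tensor.parallel import parallelize_module
    from torch.distributed.tensor.parallel.style import PairwiseParallel


    class MLP(nn.Module):
        """Toy block: PairwiseParallel shards net1 column-wise and net2 row-wise."""

        def __init__(self):
            super().__init__()
            self.net1 = nn.Linear(16, 32)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(32, 16)

        def forward(self, x):
            return self.net2(self.relu(self.net1(x)))


    # A 1-D device mesh spanning every rank that participates in tensor parallelism.
    mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))

    model = parallelize_module(MLP().cuda(), mesh, PairwiseParallel())
    out = model(torch.rand(8, 16, device="cuda"))
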
| 21 | +Tensor Parallelism supports the following parallel styles: |
| 22 | + |
| 23 | +.. autoclass:: torch.distributed.tensor.parallel.style.RowwiseParallel |
| 24 | + :members: |
| 25 | + |
| 26 | +.. autoclass:: torch.distributed.tensor.parallel.style.ColwiseParallel |
| 27 | + :members: |
| 28 | + |
| 29 | +.. autoclass:: torch.distributed.tensor.parallel.style.PairwiseParallel |
| 30 | + :members: |
| 31 | + |
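
The Rowwise and Colwise styles can also be assigned per submodule by passing a dict plan to ``parallelize_module``. The sketch below continues the toy ``MLP`` and ``mesh`` from the previous example; the ``device_mesh`` and ``parallelize_plan`` keyword names follow the entrypoint above, and the sharding comments describe commonly assumed default placements rather than guaranteed behavior::

    from torch.distributed.tensor.parallel import parallelize_module
    from torch.distributed.tensor.parallel.style import ColwiseParallel, RowwiseParallel

    # Map fully qualified child-module names to a parallel style each,
    # instead of applying PairwiseParallel to the whole block.
    plan = {
        "net1": ColwiseParallel(),  # shard net1's weight along its output dimension
        "net2": RowwiseParallel(),  # shard net2's weight along its input dimension
    }
    sharded_mlp = parallelize_module(MLP().cuda(), device_mesh=mesh, parallelize_plan=plan)
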
| 32 | +Because we use DTensor within Tensor Parallelism, we need to specify the
| 33 | +input and output placements of the module with DTensors so that it interacts
| 34 | +as expected with the modules before and after it. The following functions are
| 35 | +used for input/output preparation:
| 36 | + |
| 37 | + |
| 38 | +.. currentmodule:: torch.distributed.tensor.parallel.style |
| 39 | + |
| 40 | +.. autofunction:: make_input_replicate_1d |
| 41 | +.. autofunction:: make_input_shard_1d |
| 42 | +.. autofunction:: make_input_shard_1d_last_dim |
| 43 | +.. autofunction:: make_output_replicate_1d |
| 44 | +.. autofunction:: make_output_tensor |
| 45 | +.. autofunction:: make_output_shard_1d |
| 46 | + |
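
As a hedged sketch of what these helpers do: the ``make_input_*`` functions turn a regular ``torch.Tensor`` (or an existing DTensor) into a DTensor with the desired placement on a 1-D device mesh, and the ``make_output_*`` functions convert a module's DTensor output back into the placement (or plain tensor) that downstream code expects. The keyword names (``device_mesh``, ``dim``) are assumptions about this experimental API, and ``mesh`` is the device mesh from the first sketch::

    import torch
    from torch.distributed.tensor.parallel.style import (
        make_input_replicate_1d,
        make_input_shard_1d,
        make_output_tensor,
    )

    x = torch.rand(8, 16, device="cuda")

    # Turn a plain local tensor into a DTensor replicated across the 1-D TP mesh.
    replicated = make_input_replicate_1d(x, device_mesh=mesh)

    # Or build a DTensor that is sharded along dim 0 over the mesh instead.
    sharded = make_input_shard_1d(x, device_mesh=mesh, dim=0)

    # Convert a DTensor coming out of a parallelized module back into a
    # regular torch.Tensor for downstream code.
    plain = make_output_tensor(replicated, device_mesh=mesh)

In practice these helpers are usually supplied to a parallel style as its input/output preparation hooks rather than called by hand; the constructor parameters for doing so are version dependent and not shown here.
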
| 47 | +Currently, there are some constraints that make it hard for ``nn.MultiheadAttention``
| 48 | +to work out of the box with Tensor Parallelism, so we provide a custom multihead
| 49 | +attention module for Tensor Parallelism users. Also, ``parallelize_module`` automatically
| 50 | +swaps ``nn.MultiheadAttention`` for this custom module when ``PairwiseParallel`` is specified.
| 51 | + |
| 52 | +.. autoclass:: torch.distributed.tensor.parallel.multihead_attention_tp.TensorParallelMultiheadAttention |
| 53 | + :members: |
| 54 | + |
| 55 | +We also support 2D parallelism, which composes Tensor Parallelism with
| 56 | +``FullyShardedDataParallel``. Users just need to call the following API explicitly:
| 57 | + |
| 58 | + |
| 59 | +.. currentmodule:: torch.distributed.tensor.parallel.fsdp |
| 60 | +.. autofunction:: is_available |
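
A hedged sketch of the intended composition, assuming ``model`` is the tensor-parallelized module from the sketches above and that ``is_available`` reports (and enables) the DTensor-aware FSDP support; a real 2-D setup would additionally hand FSDP the data-parallel process group via its ``process_group`` argument, which is elided here::

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.tensor.parallel.fsdp import is_available

    # Call the 2D-parallelism hook before wrapping the TP-parallelized model
    # with FSDP for the data-parallel dimension.
    if is_available():
        model = FSDP(model)
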