Commit 41e3189

fduwjj authored and pytorchmergebot committed

[PT-D][Tensor parallelism] Add documentations for TP (#94421)

This is far from complete and we will definitely polish it down the road.

Pull Request resolved: #94421
Approved by: https://github.com/wz337
1 parent 5b8e485 commit 41e3189

File tree: 3 files changed, +61 −8 lines changed
docs/source/distributed.tensor.parallel.rst

Lines changed: 56 additions & 3 deletions

@@ -1,7 +1,60 @@
 .. role:: hidden
     :class: hidden-section

-Tensor Parallelism
-========================
-.. py:module:: torch.distributed.tensor.parallel
+Tensor Parallelism - torch.distributed.tensor.parallel
+======================================================
+
+We built Tensor Parallelism (TP) on top of DistributedTensor (DTensor) and
+provide several parallelism styles: Rowwise, Colwise, and Pairwise Parallelism.
+
+.. warning::
+    Tensor Parallelism is experimental and subject to change.
+
+The entrypoint to parallelize your module and use Tensor Parallelism is:
+
+.. automodule:: torch.distributed.tensor.parallel
 .. currentmodule:: torch.distributed.tensor.parallel
+
+.. autofunction:: parallelize_module
+
+Tensor Parallelism supports the following parallel styles:
+
+.. autoclass:: torch.distributed.tensor.parallel.style.RowwiseParallel
+  :members:
+
+.. autoclass:: torch.distributed.tensor.parallel.style.ColwiseParallel
+  :members:
+
+.. autoclass:: torch.distributed.tensor.parallel.style.PairwiseParallel
+  :members:
+
+Because we use DTensor within Tensor Parallelism, we need to specify the
+DTensor placements of the module's inputs and outputs so that it interacts
+as expected with the modules before and after it. The following functions
+are used for input/output preparation:
+
+.. currentmodule:: torch.distributed.tensor.parallel.style
+
+.. autofunction:: make_input_replicate_1d
+.. autofunction:: make_input_shard_1d
+.. autofunction:: make_input_shard_1d_last_dim
+.. autofunction:: make_output_replicate_1d
+.. autofunction:: make_output_tensor
+.. autofunction:: make_output_shard_1d
+
+Currently, there are some constraints that make it hard for the ``nn.MultiheadAttention``
+module to work out of the box with Tensor Parallelism, so we built a custom multihead
+attention module for Tensor Parallelism users. Also, ``parallelize_module`` automatically
+swaps ``nn.MultiheadAttention`` for this custom module when ``PairwiseParallel`` is specified.
+
+.. autoclass:: torch.distributed.tensor.parallel.multihead_attention_tp.TensorParallelMultiheadAttention
+  :members:
+
+We also enable 2D parallelism to integrate with ``FullyShardedDataParallel``.
+Users just need to call the following API explicitly:
+
+.. currentmodule:: torch.distributed.tensor.parallel.fsdp
+.. autofunction:: is_available
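To make the new page concrete, here is a minimal usage sketch of the documented entrypoint and the pairwise style. It is illustrative only and not part of the commit: the toy model, the mesh construction, and the assumption that DeviceMesh is importable from torch.distributed._tensor in this prototype are all editorial additions.

# Hypothetical sketch of the documented entrypoint; not part of this commit.
# Assumes the process group is already initialized (e.g. launched via torchrun)
# and that DeviceMesh is exposed by the DTensor prototype package.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh  # assumed import location
from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module


class ToyMLP(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.net1 = nn.Linear(16, 32)  # pairwise style: column-wise sharded
        self.net2 = nn.Linear(32, 16)  # pairwise style: row-wise sharded

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net2(torch.relu(self.net1(x)))


# One-dimensional mesh over all ranks, used for tensor parallelism.
mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))

model = ToyMLP().cuda()
# PairwiseParallel pairs the two Linear layers (colwise then rowwise); per the
# docs above, an nn.MultiheadAttention submodule would also be swapped for
# TensorParallelMultiheadAttention under this style.
model = parallelize_module(model, mesh, PairwiseParallel())

out = model(torch.rand(8, 16, device="cuda"))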

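Likewise, a sketch of the input/output preparation helpers listed above, converting a plain tensor to a DTensor before a tensor-parallel module runs; the exact signatures (a tensor plus an optional device mesh) are assumptions inferred from the style.py diff further down, not something this page states.

# Sketch only: calling the input/output preparation helpers directly.
# Signatures (tensor, device_mesh) are assumptions inferred from style.py below.
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh  # assumed import location
from torch.distributed.tensor.parallel.style import (
    make_input_replicate_1d,  # plain tensor or DTensor -> replicated DTensor
    make_output_tensor,       # DTensor module output -> plain torch.Tensor
)

mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))

x = torch.rand(8, 16, device="cuda")
# Wrap the input as a DTensor replicated across the 1-D mesh before it is fed
# to a tensor-parallel module ...
x_dt = make_input_replicate_1d(x, mesh)
# ... and, once the module has produced a DTensor output, unwrap it back to a
# plain tensor with make_output_tensor(output_dtensor, mesh).

The 2D (TP + FSDP) integration mentioned at the end of the page only documents is_available here, which presumably reports whether that integration can be used; the sketch above does not exercise it.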
torch/distributed/tensor/parallel/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
     ColwiseParallel,
     make_input_replicate_1d,
     make_input_shard_1d,
-    make_input_shard_1d_dim_last,
+    make_input_shard_1d_last_dim,
     make_output_replicate_1d,
     make_output_shard_1d,
     make_output_tensor,
@@ -25,7 +25,7 @@
     "TensorParallelMultiheadAttention",
     "make_input_replicate_1d",
     "make_input_shard_1d",
-    "make_input_shard_1d_dim_last",
+    "make_input_shard_1d_last_dim",
     "make_output_replicate_1d",
     "make_output_tensor",
     "make_output_shard_1d",

torch/distributed/tensor/parallel/style.py

Lines changed: 3 additions & 3 deletions
@@ -18,7 +18,7 @@
     "PairwiseParallel",
     "make_input_replicate_1d",
     "make_input_shard_1d",
-    "make_input_shard_1d_dim_last",
+    "make_input_shard_1d_last_dim",
     "make_output_replicate_1d",
     "make_output_tensor",
     "make_output_shard_1d",
@@ -62,7 +62,7 @@ class RowwiseParallel(ParallelStyle):
     """

     def __init__(self) -> None:
-        super().__init__(make_input_shard_1d_dim_last, make_output_replicate_1d)
+        super().__init__(make_input_shard_1d_last_dim, make_output_replicate_1d)


 class ColwiseParallel(ParallelStyle):
@@ -112,7 +112,7 @@ def make_input_shard_1d(
     )


-def make_input_shard_1d_dim_last(
+def make_input_shard_1d_last_dim(
     input: Union[torch.Tensor, DTensor],
     device_mesh: Optional[DeviceMesh] = None,
 ) -> DTensor:
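For reference, a small sketch of what the renamed helper is used for, based on the RowwiseParallel constructor change above; the "shard on the last dimension" behavior is an assumption consistent with the new name and signature, not something this diff spells out.

# Sketch: the renamed helper prepares a module input for a row-wise sharded
# layer. The exact placement produced is an assumption based on the name.
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh  # assumed import location
from torch.distributed.tensor.parallel.style import make_input_shard_1d_last_dim

mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))

x = torch.rand(8, 32, device="cuda")
# Returns a DTensor whose last dimension is sharded across the 1-D mesh,
# which is the input layout RowwiseParallel wires in via its __init__ above.
x_dt = make_input_shard_1d_last_dim(x, mesh)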

0 commit comments