59 changes: 56 additions & 3 deletions docs/source/distributed.tensor.parallel.rst
@@ -1,7 +1,60 @@
.. role:: hidden
:class: hidden-section

Tensor Parallelism
========================
.. py:module:: torch.distributed.tensor.parallel
Tensor Parallelism - torch.distributed.tensor.parallel
======================================================

We built Tensor Parallelism (TP) on top of DistributedTensor (DTensor) and
provide several parallelism styles: Rowwise, Colwise, and Pairwise Parallelism.

.. warning::
    Tensor Parallelism is experimental and subject to change.

The entrypoint to parallelize your module using Tensor Parallelism is:

.. automodule:: torch.distributed.tensor.parallel

.. currentmodule:: torch.distributed.tensor.parallel

.. autofunction:: parallelize_module
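
A minimal usage sketch follows; the toy model, the world size of 4, and the
device-mesh setup are illustrative assumptions rather than part of this API
reference::

    import torch
    import torch.nn as nn
    from torch.distributed._tensor import DeviceMesh
    from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.net1 = nn.Linear(8, 16)
            self.net2 = nn.Linear(16, 8)

        def forward(self, x):
            return self.net2(torch.relu(self.net1(x)))

    # One mesh dimension spanning all 4 ranks is used for tensor parallelism.
    mesh = DeviceMesh("cuda", torch.arange(4))
    model = ToyModel().cuda()

    # PairwiseParallel shards net1 colwise and net2 rowwise as a pair,
    # so a collective is only needed at the pair boundary.
    model = parallelize_module(model, mesh, PairwiseParallel())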

Tensor Parallelism supports the following parallel styles:

.. autoclass:: torch.distributed.tensor.parallel.style.RowwiseParallel
:members:

.. autoclass:: torch.distributed.tensor.parallel.style.ColwiseParallel
:members:

.. autoclass:: torch.distributed.tensor.parallel.style.PairwiseParallel
:members:
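
As a sketch of how the individual styles might be combined per submodule
(this assumes ``parallelize_module`` also accepts a plan keyed by submodule
name, and reuses the ``model`` and ``mesh`` from the sketch above)::

    from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

    # Shard net1 along its output features and net2 along its input features,
    # so intermediate activations stay sharded and only the final output is
    # reduced across the tensor-parallel ranks.
    model = parallelize_module(
        model,
        mesh,
        {"net1": ColwiseParallel(), "net2": RowwiseParallel()},
    )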

Since Tensor Parallelism is built on top of DTensor, we need to specify the
DTensor input and output placements of a module so that it interacts as
expected with the modules that come before and after it. The following
functions are used for input/output preparation (a short sketch of how they
compose into a parallel style follows this list):


.. currentmodule:: torch.distributed.tensor.parallel.style

.. autofunction:: make_input_replicate_1d
Collaborator:
Actually I'm wondering if we should make those APIs something like mark_input_replicate_1d instead of make? What we are really doing is marking the inputs/outputs rather than converting them (i.e. we are using from_local instead of distribute_tensor).

Contributor Author:
I understand where you are coming from... Although internally we don't have drastic changes here, from the user's perspective we still change a Tensor to a DTensor. So I really don't like the word "mark" (definition: https://www.merriam-webster.com/dictionary/mark); it does not carry any meaning related to change. I would prefer verbs like "change", "convert", "transform", etc., or maybe even "construct".

.. autofunction:: make_input_shard_1d
.. autofunction:: make_input_shard_1d_last_dim
.. autofunction:: make_output_replicate_1d
.. autofunction:: make_output_tensor
.. autofunction:: make_output_shard_1d
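
A sketch of how these preparation functions compose into a parallel style,
mirroring how ``RowwiseParallel`` is defined in the ``style.py`` diff further
down; that ``ParallelStyle`` lives in this module and takes an input- and an
output-preparation callable is inferred from that diff, so treat the import
path and base-class signature as assumptions::

    from torch.distributed.tensor.parallel.style import (
        ParallelStyle,
        make_input_replicate_1d,
        make_output_replicate_1d,
    )

    class ReplicateInReplicateOut(ParallelStyle):
        # Replicate the module's input and keep its output replicated.
        def __init__(self) -> None:
            super().__init__(make_input_replicate_1d, make_output_replicate_1d)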

Currently, there are some constraints which make it hard for the ``nn.MultiheadAttention``
module to work out of the box with Tensor Parallelism, so we built this custom
multihead attention module for Tensor Parallelism users. Also, in ``parallelize_module``,
we automatically swap ``nn.MultiheadAttention`` for this custom module when
``PairwiseParallel`` is specified.

.. autoclass:: torch.distributed.tensor.parallel.multihead_attention_tp.TensorParallelMultiheadAttention
:members:

We also enabled 2D parallelism to integrate with ``FullyShardedDataParallel``.
Users just need to call the following API explicitly (a small 2-D sketch
follows the API reference below):
Collaborator:
I remember we have an FSDP extension. Does TP automatically register the extension now?

Also, I wonder if we should give a small code snippet showing what the 2-D parallelism looks like.

Contributor Author (@fduwjj, Feb 8, 2023):
The registration is in is_available. Let me send a follow-up PR for this one.



.. currentmodule:: torch.distributed.tensor.parallel.fsdp
.. autofunction:: is_available
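
A small sketch of how the 2-D (TP + FSDP) combination might look; the mesh
layout, the world size of 8, the ``tp_mesh_dim`` argument, and the use of
``DeviceMesh.get_dim_groups`` are assumptions about the experimental API, and
``ToyModel`` is reused from the first sketch above::

    import torch
    from torch.distributed._tensor import DeviceMesh
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module
    from torch.distributed.tensor.parallel.fsdp import is_available

    # Registers the DTensor-aware FSDP extension (see the discussion below).
    assert is_available()

    # 8 ranks arranged as 2 data-parallel groups x 4 tensor-parallel ranks.
    mesh = DeviceMesh("cuda", torch.arange(8).reshape(2, 4))

    model = ToyModel().cuda()
    # Apply tensor parallelism along the second mesh dimension ...
    model = parallelize_module(model, mesh, PairwiseParallel(), tp_mesh_dim=1)

    # ... then shard what remains with FSDP along the first dimension
    # (get_dim_groups() returning one process group per mesh dim is assumed).
    dp_pg = mesh.get_dim_groups()[0]
    model = FSDP(model, process_group=dp_pg)
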
Collaborator:
Do we really need to add this API to the doc? I remember is_available was introduced when we were in tau, but now that it's in PyTorch I think FSDP should always be available?

Contributor Author:
Yes, because of 2D hook registration.

Contributor Author:
Will send a follow-up PR to address the naming of this one.

4 changes: 2 additions & 2 deletions torch/distributed/tensor/parallel/__init__.py
Expand Up @@ -8,7 +8,7 @@
ColwiseParallel,
make_input_replicate_1d,
make_input_shard_1d,
make_input_shard_1d_dim_last,
make_input_shard_1d_last_dim,
make_output_replicate_1d,
make_output_shard_1d,
make_output_tensor,
Expand All @@ -25,7 +25,7 @@
"TensorParallelMultiheadAttention",
"make_input_replicate_1d",
"make_input_shard_1d",
"make_input_shard_1d_dim_last",
"make_input_shard_1d_last_dim",
"make_output_replicate_1d",
"make_output_tensor",
"make_output_shard_1d",
Expand Down
6 changes: 3 additions & 3 deletions torch/distributed/tensor/parallel/style.py
Expand Up @@ -18,7 +18,7 @@
"PairwiseParallel",
"make_input_replicate_1d",
"make_input_shard_1d",
"make_input_shard_1d_dim_last",
"make_input_shard_1d_last_dim",
"make_output_replicate_1d",
"make_output_tensor",
"make_output_shard_1d",
Expand Down Expand Up @@ -62,7 +62,7 @@ class RowwiseParallel(ParallelStyle):
"""

def __init__(self) -> None:
super().__init__(make_input_shard_1d_dim_last, make_output_replicate_1d)
super().__init__(make_input_shard_1d_last_dim, make_output_replicate_1d)


class ColwiseParallel(ParallelStyle):
Expand Down Expand Up @@ -112,7 +112,7 @@ def make_input_shard_1d(
)


def make_input_shard_1d_dim_last(
def make_input_shard_1d_last_dim(
input: Union[torch.Tensor, DTensor],
device_mesh: Optional[DeviceMesh] = None,
) -> DTensor:
Expand Down