
[FSDP2] provide public API to share cuda streams across roots #165024

Closed
weifengpy wants to merge 5 commits into gh/weifengpy/37/base from gh/weifengpy/37/head

Conversation

@weifengpy
Contributor

@weifengpy weifengpy commented Oct 9, 2025

For pipeline parallelism, we can have multiple FSDP roots (chunks):

model = nn.Sequential(chunk0, chunk1)
fully_shard(model[0])
fully_shard(model[1])

We can call `share_comm_ctx` to share the all-gather, reduce-scatter, and all-reduce CUDA streams across these roots, which avoids inter-stream memory fragmentation:

from torch.distributed.fsdp import share_comm_ctx
share_comm_ctx([model[0], model[1]])

Unit test: `pytest -s test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_share_comm_context`
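
For a fuller picture, the following is a minimal usage sketch under the setup described above: two pipeline chunks, each wrapped as its own FSDP2 root, then sharing one set of communication streams. The `Chunk` module, its sizes, and the process-group setup are illustrative placeholders rather than part of this PR; it assumes the script runs under torchrun with one CUDA device per rank.

```
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import fully_shard, share_comm_ctx


class Chunk(nn.Module):
    """Placeholder pipeline-stage module (illustrative only)."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(128, 128)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


# Assumes this runs under torchrun with a CUDA device available per rank.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

chunk0, chunk1 = Chunk().cuda(), Chunk().cuda()

# Each pipeline chunk becomes its own FSDP root.
fully_shard(chunk0)
fully_shard(chunk1)

# Share the all-gather / reduce-scatter / all-reduce CUDA streams across both
# roots instead of letting each root allocate its own set of streams.
share_comm_ctx([chunk0, chunk1])
```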

Stack from ghstack (oldest at bottom):

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci

@pytorch-bot

pytorch-bot bot commented Oct 9, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/165024

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 0cc16ef with merge base 3a110c9:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@weifengpy
Contributor Author

@pytorchmergebot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 9, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 job has failed: trunk / macos-py3-arm64 / test (mps, 1, 1, macos-m2-15)

Details for Dev Infra team: raised by workflow job.



@contextlib.contextmanager
def patch_foreach_all_gather(new_foreach_all_gather: Callable):
Collaborator

Please use ParamSpec to preserve the typing for type checking if possible

Contributor Author

Good suggestion! I just tried it, but it seems quite involved for mypy: it requires specifying not only the type but also the argument name, e.g. `Arg(Stream, 'reduce_scatter_stream')`. I might need to learn more about ParamSpec to find a better way. Since this is a unit-test utility, I'd like to treat it as a follow-up.

 Error (MYPY) [assignment]
    Incompatible types in assignment (expression has type
    "Callable[[list[FSDPParam], ProcessGroup, bool, Stream, Stream,
    device, AllGather], AllGatherResult | None]", variable has type
    "Callable[[Arg(list[FSDPParam], 'fsdp_params'), Arg(ProcessGroup,
    'group'), Arg(bool, 'async_op'), Arg(Stream, 'all_gather_copy_in_stream'),
    Arg(Stream, 'all_gather_stream'), Arg(device, 'device'), Arg(AllGather,
    'all_gather_comm')], AllGatherResult | None]")

        1020  |    )
        1021  |    dist.barrier()
        1022  |    torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_all_gather = (
    >>> 1023  |        new_foreach_all_gather
        1024  |    )
        1025  |    try:
        1026  |        yield

  Error (MYPY) [assignment]
    Incompatible types in assignment (expression has type
    "Callable[[list[FSDPParam], list[Tensor], ProcessGroup, Stream,
    ReduceScatter, dtype | None, dtype | None, device, float | None,
    ProcessGroup | None, Stream, bool, Tensor | None, Callable[[Tensor],
    None] | None, bool], tuple[Tensor, Event, Event, Tensor | None, Event |
    None, Tensor | None]]", variable has type "Callable[[Arg(list[FSDPParam],
    'fsdp_params'), Arg(list[Tensor], 'unsharded_grads'), Arg(ProcessGroup,
    'reduce_scatter_group'), Arg(Stream, 'reduce_scatter_stream'),
    Arg(ReduceScatter, 'reduce_scatter_comm'), Arg(dtype | None,
    'orig_dtype'), Arg(dtype | None, 'reduce_dtype'), Arg(device, 'device'),
    Arg(float | None, 'gradient_divide_factor'), Arg(ProcessGroup | None,
    'all_reduce_group'), Arg(Stream, 'all_reduce_stream'), Arg(bool,
    'all_reduce_grads'), Arg(Tensor | None, 'partial_reduce_output'),
    Arg(Callable[[Tensor], None] | None, 'all_reduce_hook'), DefaultArg(bool,
    'force_sum_reduction_for_comms')], tuple[Tensor, Event, Event, Tensor |
    None, Event | None, Tensor | None]]")

        1066  |    )
        1067  |    dist.barrier()
        1068  |    torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_reduce = (
    >>> 1069  |        new_foreach_reduce
        1070  |    )
        1071  |    try:
        1072  |        yield
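
For reference, here is a minimal sketch (not the code in this PR) of how `ParamSpec` can tie a replacement function's signature to the original's in a generic patching helper; `patch_fn`, `module`, and `attr` are hypothetical names. Note that mypy's callable compatibility still requires the replacement's parameter names to match the original's (the `Arg(...)` names in the errors above), which is the part that makes this involved here.

```
import contextlib
from typing import Callable, Iterator, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


@contextlib.contextmanager
def patch_fn(
    module: object,
    attr: str,
    original: Callable[P, R],
    replacement: Callable[P, R],
) -> Iterator[None]:
    # Because both callables share the same ParamSpec, mypy checks that
    # `replacement` accepts the same parameters and returns the same type
    # as `original` at the call site.
    setattr(module, attr, replacement)
    try:
        yield
    finally:
        setattr(module, attr, original)
```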

weifengpy added a commit that referenced this pull request Oct 13, 2025
ghstack-source-id: cfa17a5
Pull Request resolved: #165024
@weifengpy
Contributor Author

@pytorchmergebot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot
Collaborator

The merge job was canceled or timed out. This most often happens when two merge requests are issued for the same PR, or when the merge job has been waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see the pytorch-bot wiki.

@weifengpy
Contributor Author

@pytorchbot merge -f "unrelated CI error"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Oct 14, 2025
I just want to print CommDebugMode and know whether there is any communication, so this implements `__repr__` for `print(comm_mode)`:

```
comm_mode = CommDebugMode()
with comm_mode:
    out = torch.mm(inps, weight)
print(comm_mode)
# CommDebugMode(get_total_counts()=0)
```
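
As a side note, here is a self-contained toy sketch of the pattern (not the actual CommDebugMode code from #165006) that reproduces the output shown above.

```
# Toy illustration only: a debug-mode-like object whose __repr__ reports its
# total collective count, mirroring the CommDebugMode output shown above.
class ToyCommDebugMode:
    def __init__(self) -> None:
        self._counts: dict[str, int] = {}

    def get_total_counts(self) -> int:
        return sum(self._counts.values())

    def __repr__(self) -> str:
        return f"CommDebugMode(get_total_counts()={self.get_total_counts()})"


print(ToyCommDebugMode())  # prints: CommDebugMode(get_total_counts()=0)
```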


Pull Request resolved: #165006
Approved by: https://github.com/anshul-si
ghstack dependencies: #165024
zhudada0120 pushed a commit to zhudada0120/pytorch that referenced this pull request Oct 15, 2025
zhudada0120 pushed a commit to zhudada0120/pytorch that referenced this pull request Oct 15, 2025
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
@github-actions github-actions bot deleted the gh/weifengpy/37/head branch November 14, 2025 02:16