[FSDP2] provide public API to share cuda streams across roots #165024
weifengpy wants to merge 5 commits into gh/weifengpy/37/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/165024
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit 0cc16ef with merge base 3a110c9:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@pytorchmergebot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Merge failed. Reason: 1 job has failed: trunk / macos-py3-arm64 / test (mps, 1, 1, macos-m2-15). Details for Dev Infra team: raised by workflow job.
@contextlib.contextmanager
def patch_foreach_all_gather(new_foreach_all_gather: Callable):
Please use ParamSpec to preserve the typing for type checking if possible
Good suggestion! I just tried, but it seems quite involved for mypy - it requires specifying not only the type but also the argument name, e.g. `Arg(Stream, 'reduce_scatter_stream')`. I might need to learn more about ParamSpec to find a better way. Since this is a unit-test util, I'd like to treat it as a follow-up.
Error (MYPY) [assignment]
Incompatible types in assignment (expression has type
"Callable[[list[FSDPParam], ProcessGroup, bool, Stream, Stream,
device, AllGather], AllGatherResult | None]", variable has type
"Callable[[Arg(list[FSDPParam], 'fsdp_params'), Arg(ProcessGroup,
'group'), Arg(bool, 'async_op'), Arg(Stream, 'all_gather_copy_in_stream'),
Arg(Stream, 'all_gather_stream'), Arg(device, 'device'), Arg(AllGather,
'all_gather_comm')], AllGatherResult | None]")
1020 | )
1021 | dist.barrier()
1022 | torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_all_gather = (
>>> 1023 | new_foreach_all_gather
1024 | )
1025 | try:
1026 | yield
Error (MYPY) [assignment]
Incompatible types in assignment (expression has type
"Callable[[list[FSDPParam], list[Tensor], ProcessGroup, Stream,
ReduceScatter, dtype | None, dtype | None, device, float | None,
ProcessGroup | None, Stream, bool, Tensor | None, Callable[[Tensor],
None] | None, bool], tuple[Tensor, Event, Event, Tensor | None, Event |
None, Tensor | None]]", variable has type "Callable[[Arg(list[FSDPParam],
'fsdp_params'), Arg(list[Tensor], 'unsharded_grads'), Arg(ProcessGroup,
'reduce_scatter_group'), Arg(Stream, 'reduce_scatter_stream'),
Arg(ReduceScatter, 'reduce_scatter_comm'), Arg(dtype | None,
'orig_dtype'), Arg(dtype | None, 'reduce_dtype'), Arg(device, 'device'),
Arg(float | None, 'gradient_divide_factor'), Arg(ProcessGroup | None,
'all_reduce_group'), Arg(Stream, 'all_reduce_stream'), Arg(bool,
'all_reduce_grads'), Arg(Tensor | None, 'partial_reduce_output'),
Arg(Callable[[Tensor], None] | None, 'all_reduce_hook'), DefaultArg(bool,
'force_sum_reduction_for_comms')], tuple[Tensor, Event, Event, Tensor |
None, Event | None, Tensor | None]]")
1066 | )
1067 | dist.barrier()
1068 | torch.distributed.fsdp._fully_shard._fsdp_param_group.foreach_reduce = (
>>> 1069 | new_foreach_reduce
1070 | )
1071 | try:
1072 | yield
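For reference, here is a minimal, self-contained sketch of the ParamSpec pattern the review points at (Python 3.10+). `log_calls` and `add` are hypothetical names used only for illustration, not part of the FSDP test util, and this shows the general idea rather than the eventual fix:

```
# Minimal ParamSpec sketch (Python 3.10+). Hypothetical names; not the
# actual FSDP test utility.
from typing import Callable, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")


def log_calls(fn: Callable[P, R]) -> Callable[P, R]:
    # The wrapper advertises the exact parameter list and return type of
    # `fn`, so mypy keeps checking call sites (including keyword-argument
    # names) against the original signature instead of an untyped Callable.
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        print(f"calling {fn.__name__}")
        return fn(*args, **kwargs)

    return wrapper


@log_calls
def add(x: int, y: int) -> int:
    return x + y


assert add(1, y=2) == 3
```

For the assignment errors above, a simpler interim workaround is to type the patch argument loosely (e.g. `Callable[..., Any]`) or add a targeted `# type: ignore[assignment]` at the assignment site.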
…ots" for pipeline parallel, we can have multiple FSDP roots (chunks) ``` model = nn.Sequential([chunk0, chunk1]) fully_shard(model.chunk0) fully_shard(model.chunk1) ``` we can call `share_comm_ctx` to share all-gather, reduce-scatter, all-reduce cuda streams. this avoids inter-stream memory fragmentation ``` from torch.distributed.fsdp import share_comm_ctx share_comm_ctx([model.chunk0, model.chunk1]) ``` unit test: `pytest -s test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_share_comm_context` Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci [ghstack-poisoned]
…ots" for pipeline parallel, we can have multiple FSDP roots (chunks) ``` model = nn.Sequential([chunk0, chunk1]) fully_shard(model.chunk0) fully_shard(model.chunk1) ``` we can call `share_comm_ctx` to share all-gather, reduce-scatter, all-reduce cuda streams. this avoids inter-stream memory fragmentation ``` from torch.distributed.fsdp import share_comm_ctx share_comm_ctx([model.chunk0, model.chunk1]) ``` unit test: `pytest -s test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_share_comm_context` Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci [ghstack-poisoned]
…ots" for pipeline parallel, we can have multiple FSDP roots (chunks) ``` model = nn.Sequential([chunk0, chunk1]) fully_shard(model.chunk0) fully_shard(model.chunk1) ``` we can call `share_comm_ctx` to share all-gather, reduce-scatter, all-reduce cuda streams. this avoids inter-stream memory fragmentation ``` from torch.distributed.fsdp import share_comm_ctx share_comm_ctx([model.chunk0, model.chunk1]) ``` unit test: `pytest -s test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_share_comm_context` Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci [ghstack-poisoned]
…ots" for pipeline parallel, we can have multiple FSDP roots (chunks) ``` model = nn.Sequential([chunk0, chunk1]) fully_shard(model.chunk0) fully_shard(model.chunk1) ``` we can call `share_comm_ctx` to share all-gather, reduce-scatter, all-reduce cuda streams. this avoids inter-stream memory fragmentation ``` from torch.distributed.fsdp import share_comm_ctx share_comm_ctx([model.chunk0, model.chunk1]) ``` unit test: `pytest -s test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_share_comm_context` Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta msaroufim dcci [ghstack-poisoned]
@pytorchmergebot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.

@pytorchbot merge -f "unrelated CI error"

Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
I just want to print CommDebugMode and know if there is communication, so this implements `__repr__` for `print(comm_mode)`:
```
comm_mode = CommDebugMode()
with comm_mode:
    out = torch.mm(inps, weight)
print(comm_mode)
# CommDebugMode(get_total_counts()=0)
```
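For illustration, a minimal sketch of the `__repr__` idea using a stand-in class (`FakeCommDebugMode` is hypothetical, not the real CommDebugMode); the actual implementation in #165006 may differ:

```
# Stand-in class to illustrate the __repr__ behavior shown above; the real
# CommDebugMode tracks collectives via a torch dispatch mode.
class FakeCommDebugMode:
    def __init__(self) -> None:
        self._counts: dict[str, int] = {}

    def get_total_counts(self) -> int:
        return sum(self._counts.values())

    def __repr__(self) -> str:
        # print(comm_mode) now reports whether any communication was recorded.
        return f"CommDebugMode(get_total_counts()={self.get_total_counts()})"


print(FakeCommDebugMode())  # CommDebugMode(get_total_counts()=0)
```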
Tags:
Pull Request resolved: #165006
Approved by: https://github.com/anshul-si
ghstack dependencies: #165024
…h#165024) for pipeline parallel, we can have multiple FSDP roots (chunks) ``` model = nn.Sequential([chunk0, chunk1]) fully_shard(model.chunk0) fully_shard(model.chunk1) ``` we can call `share_comm_ctx` to share all-gather, reduce-scatter, all-reduce cuda streams. this avoids inter-stream memory fragmentation ``` from torch.distributed.fsdp import share_comm_ctx share_comm_ctx([model.chunk0, model.chunk1]) ``` unit test: `pytest -s test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_share_comm_context` Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: Pull Request resolved: pytorch#165024 Approved by: https://github.com/mori360
…#165006) I just want to print CommDebugMode and know if there is communication. implementing `__repr__` for `print(comm_mode)` ``` comm_mode = CommDebugMode() with comm_mode: out = torch.mm(inps, weight) print(comm_mode) # CommDebugMode(get_total_counts()=0) ``` Tags: Pull Request resolved: pytorch#165006 Approved by: https://github.com/anshul-si ghstack dependencies: pytorch#165024
for pipeline parallel, we can have multiple FSDP roots (chunks).

we can call `share_comm_ctx` to share all-gather, reduce-scatter, all-reduce cuda streams. this avoids inter-stream memory fragmentation.

unit test: `pytest -s test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_share_comm_context`

Stack from ghstack (oldest at bottom):
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci
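A slightly fuller usage sketch under stated assumptions: the `fully_shard` and `share_comm_ctx` imports follow the snippet in the commit message above, the chunk modules and their sizes are hypothetical, and a process group is assumed to be initialized already.

```
import torch.nn as nn
from torch.distributed.fsdp import fully_shard, share_comm_ctx

# Assumes torch.distributed is already initialized (e.g. via torchrun +
# init_process_group) and a CUDA device is selected.
chunk0 = nn.Linear(1024, 1024)  # hypothetical pipeline chunk 0
chunk1 = nn.Linear(1024, 1024)  # hypothetical pipeline chunk 1
model = nn.Sequential(chunk0, chunk1)

# Each chunk becomes its own FSDP root.
fully_shard(chunk0)
fully_shard(chunk1)

# Share the all-gather / reduce-scatter / all-reduce CUDA streams across
# both roots to avoid inter-stream memory fragmentation.
share_comm_ctx([chunk0, chunk1])
```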