
Commit a9c726c

[FSDP] Allow to use TorchDispatch with FSDP

ghstack-source-id: 6786e65
Pull Request resolved: #88014
1 parent ff94494 commit a9c726c
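
This commit routes FSDP's internal `Tensor.record_stream` calls through a small `no_dispatch()` helper so they bypass any active `TorchDispatchMode`. The sketch below is illustrative only; the toy module, mode, and setup are assumptions, not part of the PR. It shows the kind of usage the change is meant to support: running an FSDP-wrapped model under a dispatch mode without FSDP's stream-recording calls tripping it up.

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils._python_dispatch import TorchDispatchMode

class OpLoggingMode(TorchDispatchMode):
    """Toy mode that logs every dispatched op and then runs it normally."""
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(f"dispatched: {func}")
        return func(*args, **kwargs)

# Assumes torch.distributed.init_process_group() has already been called and
# a CUDA device has been selected for this rank.
model = FSDP(nn.Linear(8, 8).cuda())
with OpLoggingMode():
    loss = model(torch.randn(4, 8, device="cuda")).sum()
    loss.backward()  # record_stream calls inside FSDP now run under no_dispatch()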

File tree

3 files changed, 21 insertions(+), 9 deletions(-)

torch/distributed/fsdp/_utils.py
torch/distributed/fsdp/flat_param.py
torch/distributed/fsdp/fully_sharded_data_parallel.py


torch/distributed/fsdp/_utils.py

7 additions, 1 deletion

@@ -1,14 +1,15 @@
 import dataclasses
 import traceback
 from collections import OrderedDict
-from typing import Any, Callable, Dict, List, Set, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Set, Tuple, Union

 import torch
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn.parallel.scatter_gather import (  # type: ignore[attr-defined]
     _is_namedtuple,
 )
 from torch.nn.utils.rnn import PackedSequence
+from torch.utils._mode_utils import no_dispatch


 def _contains_batchnorm(module):
@@ -107,6 +108,11 @@ def _same_storage(x: torch.Tensor, y: torch.Tensor) -> bool:
     return x.storage().data_ptr() == y.storage().data_ptr()


+def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.cuda.Stream) -> None:
+    with no_dispatch():
+        tensor.record_stream(cast(torch._C.Stream, stream))
+
+
 def p_assert(cond: Any, s: Any, raise_assertion_error: bool = True) -> None:
     """This is used as an alternate to ``assert`` when in the backward context
     to print the error message ``s`` since otherwise, it is swallowed."""
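
The new `_no_dispatch_record_stream` helper temporarily disables any active TorchDispatch mode before calling `record_stream`, so the call goes straight to the regular CUDA implementation rather than through `__torch_dispatch__`, where a `torch._C.Stream` argument is not something a mode can generally handle. A minimal usage sketch, assuming a CUDA device; the tensor here is just a stand-in for FSDP's internal buffers:

import torch
from torch.distributed.fsdp._utils import _no_dispatch_record_stream

t = torch.empty(1024, device="cuda")
# Same effect as t.record_stream(torch.cuda.current_stream()), except that any
# active TorchDispatchMode is bypassed while the stream is recorded.
_no_dispatch_record_stream(t, torch.cuda.current_stream())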

torch/distributed/fsdp/flat_param.py

8 additions, 5 deletions

@@ -5,7 +5,6 @@
 from itertools import accumulate, chain
 from typing import (
     Any,
-    cast,
     Dict,
     Generator,
     Iterator,
@@ -29,7 +28,13 @@
 )

 from ._fsdp_extensions import _ext_post_unflatten_transform, _ext_pre_flatten_transform
-from ._utils import _alloc_storage, _free_storage, _same_storage, p_assert
+from ._utils import (
+    _alloc_storage,
+    _free_storage,
+    _no_dispatch_record_stream,
+    _same_storage,
+    p_assert,
+)

 __all__ = [
     "FlatParameter",
@@ -1121,9 +1126,7 @@ def _free_unsharded_flat_param(self):
         self._check_storage_allocated(unsharded_flat_param)
         self._check_on_compute_device(unsharded_flat_param)
         # Do not free the memory until all ops in the current stream finish
-        unsharded_flat_param.record_stream(
-            cast(torch._C.Stream, torch.cuda.current_stream())
-        )
+        _no_dispatch_record_stream(unsharded_flat_param, torch.cuda.current_stream())
         _free_storage(unsharded_flat_param)

     def _use_sharded_flat_param(self) -> None:
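
The call being replaced here follows the usual caching-allocator safety pattern: record the unsharded flat parameter on the current stream so its memory is not handed back out while kernels on that stream may still be reading it, and only then free the storage. A self-contained sketch of the hazard `record_stream` guards against; the side stream and sizes are illustrative assumptions:

import torch

side = torch.cuda.Stream()
t = torch.randn(1 << 20, device="cuda")
with torch.cuda.stream(side):
    side.wait_stream(torch.cuda.current_stream())  # order after t's producer
    out = t * 2        # kernel on `side` reads `t` asynchronously
t.record_stream(side)  # tell the allocator `t` is still in use on `side`
del t                  # safe: t's memory is not reused until `side` catches up
torch.cuda.synchronize()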

torch/distributed/fsdp/fully_sharded_data_parallel.py

6 additions, 3 deletions

@@ -90,6 +90,7 @@
     _apply_to_tensors,
     _contains_batchnorm,
     _free_storage,
+    _no_dispatch_record_stream,
     _override_batchnorm_mixed_precision,
     p_assert,
 )
@@ -2686,14 +2687,16 @@ def _post_backward_hook(
                         grad.detach(), non_blocking=True
                     )
                     # Don't let this memory get reused until after the transfer.
-                    grad.data.record_stream(torch.cuda.current_stream())
+                    _no_dispatch_record_stream(grad.data, torch.cuda.current_stream())

                 # After _post_backward_hook returns, orig_grad_data will eventually
                 # go out of scope, at which point it could otherwise be freed for
                 # further reuse by the main stream while the div/reduce_scatter/copy
                 # are underway in the post_backward stream. See:
                 # github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py
-                orig_grad_data.record_stream(self._streams["post_backward"])
+                _no_dispatch_record_stream(
+                    orig_grad_data, self._streams["post_backward"]
+                )

             if handle._use_orig_params:
                 # Since the handle's `FlatParameter` completed its gradient
@@ -2727,7 +2730,7 @@ def _cast_grad_to_param_dtype(
         grad.data = grad.data.to(dtype=param.dtype)
         # Do not let the low precision gradient memory get reused until
         # the cast to full parameter precision completes
-        low_prec_grad_data.record_stream(torch.cuda.current_stream())
+        _no_dispatch_record_stream(low_prec_grad_data, torch.cuda.current_stream())

     def _should_free_unsharded_flat_param(self, handle: FlatParamHandle):
         return (
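
In `_cast_grad_to_param_dtype`, the same helper now protects the mixed-precision path: the low-precision gradient buffer must stay reserved until its asynchronous cast to full parameter precision finishes on the current stream. A hedged sketch of that pattern in isolation; `grad` and `param` are placeholders, and the helper is the one added in torch/distributed/fsdp/_utils.py above:

import torch
from torch.distributed.fsdp._utils import _no_dispatch_record_stream

def cast_grad_to_param_dtype(grad: torch.Tensor, param: torch.nn.Parameter) -> None:
    low_prec_grad_data = grad.data
    grad.data = grad.data.to(dtype=param.dtype)  # async cast on the current stream
    # Keep the low-precision buffer from being reused until the cast completes.
    _no_dispatch_record_stream(low_prec_grad_data, torch.cuda.current_stream())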
