16 changes: 12 additions & 4 deletions torch/distributed/fsdp/_runtime_utils.py
@@ -14,7 +14,11 @@
     _State,
     TrainingState,
 )
-from torch.distributed.fsdp._utils import _apply_to_tensors, p_assert
+from torch.distributed.fsdp._utils import (
+    _apply_to_tensors,
+    _no_dispatch_record_stream,
+    p_assert,
+)
 from torch.distributed.fsdp.api import BackwardPrefetch
 from torch.distributed.fsdp.flat_param import (
     _HandlesKey,
@@ -572,12 +576,16 @@ def _post_backward_hook(
 # Since the sharded gradient is produced in the post-backward
 # stream and consumed later in the computation stream, inform
 # the caching allocator
-sharded_grad.data.record_stream(torch.cuda.current_stream())
+_no_dispatch_record_stream(
+    sharded_grad.data, torch.cuda.current_stream()
+)

 # Since the unsharded gradient is produced in the computation
 # stream and consumed in the post-backward stream, inform the
 # caching allocator (before it goes out of scope)
-unsharded_grad_data.record_stream(state._streams["post_backward"])
+_no_dispatch_record_stream(
+    unsharded_grad_data, state._streams["post_backward"]
+)

 if handle._use_orig_params:
     # Since the handle's `FlatParameter` completed its gradient
@@ -630,7 +638,7 @@ def _cast_grad_to_param_dtype(
 # caching allocator; for the sharded strategies, the gradient is
 # produced in the post-backward stream, so this `record_stream()`
 # should be a no-op
-low_prec_grad_data.record_stream(torch.cuda.current_stream())
+_no_dispatch_record_stream(low_prec_grad_data, torch.cuda.current_stream())


 def _check_comm_hook(
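Note on the pattern these hooks rely on (not part of the diff): `Tensor.record_stream` tells the CUDA caching allocator that a tensor is used on the given stream, so its memory block is not handed to a new allocation until the work already queued on that stream has finished. A minimal standalone sketch of the produce-on-one-stream, consume-on-another pattern (illustrative only; assumes a CUDA device is available):

import torch

side_stream = torch.cuda.Stream()  # stands in for FSDP's post-backward stream

with torch.cuda.stream(side_stream):
    # Allocated while `side_stream` is current, so the caching allocator
    # associates this block with `side_stream`.
    grad = torch.ones(1024, device="cuda")

# The tensor is consumed on the default (computation) stream: order the
# compute after the producer, then record the consumer stream so the block
# is not recycled early if `grad` is freed while that work is still in flight.
torch.cuda.current_stream().wait_stream(side_stream)
out = grad * 2
grad.record_stream(torch.cuda.current_stream())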
8 changes: 7 additions & 1 deletion torch/distributed/fsdp/_utils.py
@@ -1,14 +1,15 @@
 import dataclasses
 import traceback
 from collections import OrderedDict
-from typing import Any, Callable, Dict, List, Set, Tuple, Union
+from typing import Any, Callable, cast, Dict, List, Set, Tuple, Union

 import torch
 from torch.nn.modules.batchnorm import _BatchNorm
 from torch.nn.parallel.scatter_gather import (  # type: ignore[attr-defined]
     _is_namedtuple,
 )
 from torch.nn.utils.rnn import PackedSequence
+from torch.utils._mode_utils import no_dispatch


 def _contains_batchnorm(module):
@@ -115,3 +116,8 @@ def p_assert(cond: Any, s: str, raise_assertion_error: bool = True) -> None:
         traceback.print_stack()
         if raise_assertion_error:
             raise AssertionError(s)
+
+
+def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.cuda.Stream) -> None:
+    with no_dispatch():
+        tensor.record_stream(cast(torch._C.Stream, stream))
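For reference, `no_dispatch()` temporarily disables `__torch_dispatch__` handling, so the wrapped `record_stream` call reaches the CUDA caching allocator directly rather than being routed through whatever dispatch mode happens to be active; the `cast(torch._C.Stream, stream)` is there to satisfy the type checker. A minimal usage sketch of the new helper (assumes a CUDA device; the helper is FSDP-internal, not a public API):

import torch
from torch.distributed.fsdp._utils import _no_dispatch_record_stream

post_backward_stream = torch.cuda.Stream()
with torch.cuda.stream(post_backward_stream):
    unsharded_grad = torch.randn(2048, device="cuda")

# Same effect as `unsharded_grad.record_stream(...)`, but the call is not
# intercepted by any active dispatch mode.
_no_dispatch_record_stream(unsharded_grad, torch.cuda.current_stream())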
13 changes: 8 additions & 5 deletions torch/distributed/fsdp/flat_param.py
@@ -5,7 +5,6 @@
 from itertools import accumulate, chain
 from typing import (
     Any,
-    cast,
     Dict,
     Generator,
     Iterator,
@@ -30,7 +29,13 @@
 )

 from ._fsdp_extensions import _ext_post_unflatten_transform, _ext_pre_flatten_transform
-from ._utils import _alloc_storage, _free_storage, _same_storage, p_assert
+from ._utils import (
+    _alloc_storage,
+    _free_storage,
+    _no_dispatch_record_stream,
+    _same_storage,
+    p_assert,
+)

 __all__ = [
     "FlatParameter",
@@ -1200,9 +1205,7 @@ def _free_unsharded_flat_param(self):
         self._check_storage_allocated(unsharded_flat_param)
         self._check_on_compute_device(unsharded_flat_param)
         # Do not free the memory until all ops in the current stream finish
-        unsharded_flat_param.record_stream(
-            cast(torch._C.Stream, torch.cuda.current_stream())
-        )
+        _no_dispatch_record_stream(unsharded_flat_param, torch.cuda.current_stream())
         _free_storage(unsharded_flat_param)

     def _use_sharded_flat_param(self) -> None:
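The `_free_unsharded_flat_param` change keeps the existing ordering: record the current stream on the unsharded flat parameter, then free its storage, so the caching allocator only recycles the block once the kernels already queued on that stream are done. A rough, self-contained approximation of that record-then-free pattern using public APIs (`_free_storage` is FSDP-internal; resizing the storage to zero elements is only a stand-in for it here):

import torch

def free_after_current_stream(tensor: torch.Tensor) -> None:
    # Tell the caching allocator the tensor is in use on the current stream,
    # so its block is not reused until work queued on that stream finishes.
    tensor.record_stream(torch.cuda.current_stream())
    # Drop the backing allocation; the tensor keeps its shape/stride metadata
    # but now points at an empty storage (roughly what `_free_storage` does).
    tensor.storage().resize_(0)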
1 change: 1 addition & 0 deletions torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -95,6 +95,7 @@
 from .flat_param import FlatParameter, FlatParamHandle
 from .wrap import ParamExecOrderWrapPolicy

+
 _TORCH_FX_AVAIL = True
 if not hasattr(torch, "fx"):
     _TORCH_FX_AVAIL = False