
Commit d5856ca

[Dynamo x FSDP][2/x] Small changes to distributed to make it dynamo friendly
ghstack-source-id: d45babf
Pull Request resolved: #106886

Dist nits
1 parent 1015556 commit d5856ca

8 files changed: +78 additions, -49 deletions


torch/distributed/_composable/contract.py

Lines changed: 7 additions & 2 deletions
@@ -190,5 +190,10 @@ def _get_registry(module: nn.Module) -> Dict[str, RegistryItem]:
     Get an ``OrderedDict`` of composable APIs that have been applied to the
     ``module``, indexed by the API name.
     """
-    default_registry: Dict[str, RegistryItem] = OrderedDict()
-    return module.__dict__.setdefault(REGISTRY_KEY, default_registry)  # type: ignore[call-overload]
+    registry = getattr(module, REGISTRY_KEY, None)
+    if registry is None:
+        default_registry: Dict[str, RegistryItem] = OrderedDict()
+        setattr(module, REGISTRY_KEY, default_registry)
+        return default_registry
+    else:
+        return registry
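
The rewritten registry lookup goes through plain getattr/setattr instead of poking module.__dict__ directly; the behavior is identical, but ordinary attribute access is presumably easier for Dynamo to trace. A minimal standalone sketch of the same pattern, using a hypothetical _DEMO_REGISTRY_KEY in place of the real REGISTRY_KEY:

    from collections import OrderedDict

    import torch.nn as nn

    _DEMO_REGISTRY_KEY = "_demo_registry"  # hypothetical stand-in for REGISTRY_KEY


    def get_registry(module: nn.Module) -> "OrderedDict[str, object]":
        # Same get-or-create semantics as __dict__.setdefault, spelled as
        # ordinary attribute access on the module.
        registry = getattr(module, _DEMO_REGISTRY_KEY, None)
        if registry is None:
            registry = OrderedDict()
            setattr(module, _DEMO_REGISTRY_KEY, registry)
        return registry


    print(get_registry(nn.Linear(2, 2)))  # OrderedDict()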

torch/distributed/_composable_state.py

Lines changed: 4 additions & 1 deletion
@@ -29,4 +29,7 @@ def _get_module_state(module: nn.Module) -> Optional[_State]:
     if isinstance(module, _State):
         return cast(_State, module)
     else:
-        return _module_state_mapping.get(module, None)
+        if module in _module_state_mapping:
+            return _module_state_mapping[module]
+        else:
+            return None

torch/distributed/distributed_c10d.py

Lines changed: 2 additions & 2 deletions
@@ -555,7 +555,7 @@ class group(metaclass=_WorldMeta):
     pass
 
 class GroupMember(metaclass=_WorldMeta):
-    NON_GROUP_MEMBER = object()
+    NON_GROUP_MEMBER = -100
 
 
 # Default process group state
@@ -982,7 +982,7 @@ def get_backend(group: Optional[ProcessGroup] = None) -> str:
     pg = group
     if _rank_not_in_group(pg):
         raise RuntimeError("Invalid process group specified")
-    pg_store = _world.pg_map.get(pg, None)
+    pg_store = _world.pg_map[pg] if pg in _world.pg_map else None
     assert pg_store is not None
     return pg_store[0]
 
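Both lookups above trade dict.get(key, None) for an explicit membership test, and the non-member sentinel becomes a plain int instead of an anonymous object(); behavior is unchanged, and the simpler constructs are presumably easier for Dynamo to handle. A standalone sketch of the two equivalent spellings, using a hypothetical pg_map dict rather than the real _world.pg_map:

    from typing import Dict, Optional, Tuple

    pg_map: Dict[str, Tuple[str, Optional[str]]] = {"default_pg": ("nccl", None)}


    def lookup_get(key: str) -> Optional[Tuple[str, Optional[str]]]:
        # Original spelling: a single .get() call with a default.
        return pg_map.get(key, None)


    def lookup_membership(key: str) -> Optional[Tuple[str, Optional[str]]]:
        # Spelling used by the commit: explicit membership check, then index.
        return pg_map[key] if key in pg_map else None


    assert lookup_get("default_pg") == lookup_membership("default_pg")
    assert lookup_get("missing") is None and lookup_membership("missing") is None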

torch/distributed/fsdp/_common_utils.py

Lines changed: 16 additions & 7 deletions
@@ -392,12 +392,16 @@ def f(module: torch.nn.Module, prefix: str, tree_level: int, *args, **kwargs):
             submodule_name == "_fsdp_wrapped_module"
             or submodule_name == "_dmp_wrapped_module"
         ):
-            warnings.warn(
-                "An unexpected prefix is detected. This case "
-                " should only happen when using DMP with FSDP. "
-                f"prefix = {prefix}, "
-                f"submodule_name = {submodule_name}"
-            )
+            if (
+                not torch.distributed._functional_collectives.is_torchdynamo_compiling()
+            ):
+                # TODO(voz): Don't graph break on this
+                warnings.warn(
+                    "An unexpected prefix is detected. This case "
+                    " should only happen when using DMP with FSDP. "
+                    f"prefix = {prefix}, "
+                    f"submodule_name = {submodule_name}"
+                )
             new_prefix = prefix
         elif submodule_name == "module":
             warnings.warn(
@@ -511,7 +515,12 @@ def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.Stream) -> None:
     # FIXME record_stream doesn't work with non-cuda tensors
     if tensor.device.type not in ["cuda", torch._C._get_privateuse1_backend_name()]:
         return
-    with no_dispatch():
+
+    if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
+        # Don't no dispatch under torch compile like this
+        with no_dispatch():
+            tensor.record_stream(stream)
+    else:
         tensor.record_stream(stream)
 
 
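The recurring FSDP pattern in this and the following files is to guard side effects that Dynamo cannot trace cleanly (warnings, record_stream, CUDA event queues, cross-rank sanity checks) behind torch.distributed._functional_collectives.is_torchdynamo_compiling(), so eager behavior is unchanged while the compiled path simply skips them. A reduced sketch of that guard, assuming a PyTorch build with the distributed package available:

    import warnings

    import torch
    import torch.distributed._functional_collectives as funcol


    def maybe_warn(msg: str) -> None:
        # Emit the warning in eager mode; skip it while Dynamo is tracing so the
        # call does not force a graph break.
        if not funcol.is_torchdynamo_compiling():
            warnings.warn(msg)


    @torch.compile(backend="eager")
    def f(x: torch.Tensor) -> torch.Tensor:
        maybe_warn("eager-only diagnostic")  # becomes a no-op under compilation
        return x + 1


    print(f(torch.ones(2)))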

torch/distributed/fsdp/_exec_order_utils.py

Lines changed: 39 additions & 31 deletions
@@ -217,18 +217,21 @@ def _check_order(self, handle: FlatParamHandle, is_training: bool) -> None:
             # TODO (awgu): Since every module has at most one handle in the
             # current implementation, this should never raise the error.
             assert self.world_size is not None  # mypy
-            for (r1, n1), (r2, n2) in itertools.combinations(
-                (
-                    (rank, world_num_valid_indices[rank])
-                    for rank in range(self.world_size)
-                ),
-                2,
-            ):
-                if n1 != n2:
-                    raise RuntimeError(
-                        f"{msg_prefix} rank {r1} is all-gathering {n1} parameters "
-                        f"while rank {r2} is all-gathering {n2} parameters"
-                    )
+            if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
+                # TODO(voz): Don't graph break on this - dynamo hates the n1 != n2
+                # tensor comparison control flow.
+                for (r1, n1), (r2, n2) in itertools.combinations(
+                    (
+                        (rank, world_num_valid_indices[rank])
+                        for rank in range(self.world_size)
+                    ),
+                    2,
+                ):
+                    if n1 != n2:
+                        raise RuntimeError(
+                            f"{msg_prefix} rank {r1} is all-gathering {n1} parameters "
+                            f"while rank {r2} is all-gathering {n2} parameters"
+                        )
             world_indices = torch.zeros(  # type: ignore[call-overload]
                 self.world_size * num_valid_indices, **tensor_kwargs
             )
@@ -239,26 +242,31 @@ def _check_order(self, handle: FlatParamHandle, is_training: bool) -> None:
             # Copy entire tensor from D2H once to avoid per element D2H copies
             world_indices = world_indices.cpu()
             # Check that all ranks plan to all-gather the same index parameters
-            for (r1, i1), (r2, i2) in itertools.combinations(
-                (
+            if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
+                # TODO(voz): Don't graph break on this - dynamo hates the i1 != i2
+                # tensor comparison control flow.
+                for (r1, i1), (r2, i2) in itertools.combinations(
                     (
-                        rank,
-                        world_indices[
-                            rank * num_valid_indices : (rank + 1) * num_valid_indices
-                        ],
-                    )
-                    for rank in range(self.world_size)
-                ),
-                2,
-            ):
-                if i1 != i2:
-                    r1_param_names = self._get_names_from_handle_indices(i1)
-                    r2_param_names = self._get_names_from_handle_indices(i2)
-                    raise RuntimeError(
-                        f"{msg_prefix} rank {r1} is all-gathering parameters "
-                        f"for {r1_param_names} while rank {r2} is all-gathering "
-                        f"parameters for {r2_param_names}"
-                    )
+                        (
+                            rank,
+                            world_indices[
+                                rank
+                                * num_valid_indices : (rank + 1)
+                                * num_valid_indices
+                            ],
+                        )
+                        for rank in range(self.world_size)
+                    ),
+                    2,
+                ):
+                    if i1 != i2:
+                        r1_param_names = self._get_names_from_handle_indices(i1)
+                        r2_param_names = self._get_names_from_handle_indices(i2)
+                        raise RuntimeError(
+                            f"{msg_prefix} rank {r1} is all-gathering parameters "
+                            f"for {r1_param_names} while rank {r2} is all-gathering "
+                            f"parameters for {r2_param_names}"
+                        )
         elif self._checking_order:
             # Only issue warnings on the first deviating iteration and stop
             # checking thereafter to avoid flooding the console

torch/distributed/fsdp/_init_utils.py

Lines changed: 2 additions & 2 deletions
@@ -78,10 +78,10 @@
     ShardingStrategy.HYBRID_SHARD: HandleShardingStrategy.HYBRID_SHARD,
     ShardingStrategy._HYBRID_SHARD_ZERO2: HandleShardingStrategy._HYBRID_SHARD_ZERO2,
 }
-HYBRID_SHARDING_STRATEGIES = {
+HYBRID_SHARDING_STRATEGIES = [
     ShardingStrategy.HYBRID_SHARD,
     ShardingStrategy._HYBRID_SHARD_ZERO2,
-}
+]
 NO_RESHARD_AFTER_FORWARD_STRATEGIES = (
     ShardingStrategy.SHARD_GRAD_OP,
     ShardingStrategy._HYBRID_SHARD_ZERO2,
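
HYBRID_SHARDING_STRATEGIES changes from a set literal to a list; membership checks such as "strategy in HYBRID_SHARDING_STRATEGIES" behave the same, and a constant list is presumably simpler for Dynamo to specialize on than a set of enum members. A quick standalone check of that equivalence, using a reduced stand-in enum rather than the real ShardingStrategy:

    from enum import Enum, auto


    class DemoStrategy(Enum):  # stand-in for torch.distributed.fsdp.ShardingStrategy
        FULL_SHARD = auto()
        HYBRID_SHARD = auto()
        _HYBRID_SHARD_ZERO2 = auto()


    AS_SET = {DemoStrategy.HYBRID_SHARD, DemoStrategy._HYBRID_SHARD_ZERO2}
    AS_LIST = [DemoStrategy.HYBRID_SHARD, DemoStrategy._HYBRID_SHARD_ZERO2]

    for strategy in DemoStrategy:
        # Only the container type changes; membership results are identical.
        assert (strategy in AS_SET) == (strategy in AS_LIST)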

torch/distributed/fsdp/_runtime_utils.py

Lines changed: 6 additions & 3 deletions
@@ -348,9 +348,12 @@ def _reshard(
     """
     handle.reshard(free_unsharded_flat_param)
     if state.limit_all_gathers and free_unsharded_flat_param:
-        free_event = state._device_handle.Event()
-        free_event.record()
-        state._free_event_queue.enqueue(free_event)
+        if not torch.distributed._functional_collectives.is_torchdynamo_compiling():
+            # We don't run an event queue for freeing under torch compile atm
+            # But maybe we need to? TODO(voz): Look into this
+            free_event = state._device_handle.Event()
+            free_event.record()
+            state._free_event_queue.enqueue(free_event)
     handle.post_reshard()
     # Since we prefetch entire handles keys at a time, conservatively mark
     # the entire key as no longer prefetched once we free at least one

torch/distributed/fsdp/flat_param.py

Lines changed: 2 additions & 1 deletion
@@ -27,7 +27,6 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
-from torch.distributed._tensor import DTensor
 from torch.distributed.fsdp._common_utils import (
     _FSDPDeviceHandle,
     _named_parameters_with_duplicates,
@@ -1797,6 +1796,8 @@ def _use_unsharded_views(self, as_params: bool) -> None:
         flat_param = self.flat_param
         self._check_unsharded(flat_param)
         views = self._get_unflat_views()
+        from torch.distributed._tensor import DTensor
+
         for i, (view, (param_name, module, _)) in enumerate(
             zip(views, flat_param._param_infos)
         ):
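
The module-level DTensor import moves inside _use_unsharded_views, so it only runs when that code path actually executes; the function's behavior is unchanged. A tiny sketch of the deferred-import pattern, assuming a build where torch.distributed._tensor is importable:

    import torch


    def is_dtensor(x: torch.Tensor) -> bool:
        # Deferred import: DTensor is pulled in on first use instead of at module
        # import time, which sidesteps import-order and import-cost concerns.
        from torch.distributed._tensor import DTensor

        return isinstance(x, DTensor)


    print(is_dtensor(torch.ones(2)))  # False for a plain tensor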
