3 changes: 0 additions & 3 deletions test/distributed/_composable/test_fully_shard.py
@@ -1,6 +1,5 @@
# Owner(s): ["oncall: distributed"]

import unittest
import contextlib
import copy
import functools
@@ -743,7 +742,6 @@ def _test_optim_state_save_load(self, model1, optim1, model2, optim2) -> None:
for key, value in group1.items():
self.assertEqual(value, group2[key])

@unittest.skip("The test currently fails on CI.")
@skip_if_lt_x_gpu(2)
def test_optim_state_dict_save_load(self):
orig_model = CompositeParamModel(device=torch.device("cuda"))
@@ -755,7 +753,6 @@ def test_optim_state_dict_save_load(self):

self._test_optim_state_save_load(orig_model, orig_optim, composable_model, composable_optim)

@unittest.skip("The test currently fails on CI.")
@skip_if_lt_x_gpu(2)
def test_optim_state_dict_submodule_fully_shard(self):
orig_model = CompositeParamModel(device=torch.device("cuda"))
3 changes: 0 additions & 3 deletions test/distributed/fsdp/test_fsdp_optim_state.py
@@ -2,7 +2,6 @@

import bisect
import sys
import unittest
from enum import auto, Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

@@ -783,7 +782,6 @@ def test_flatten_sharded_optim_state_dict_transformer(self) -> None:
num_iters=3,
)

@unittest.skip("The test currently fails on CI.")
@skip_if_lt_x_gpu(2)
def test_use_orig_params(self) -> None:
"""Tests :meth:`optim_state_dict` for an FSDP-root nested model."""
@@ -1442,7 +1440,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
loss.backward()
optim.step()

@unittest.skip("The test currently fails on CI.")
@skip_if_lt_x_gpu(2)
def test_compatible_with_named_optimizer(self):
class TestDummyModel(torch.nn.Module):
163 changes: 119 additions & 44 deletions torch/distributed/fsdp/_optim_utils.py
@@ -19,6 +19,7 @@
import torch.distributed as dist
import torch.distributed.fsdp._traversal_utils as traversal_utils
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed.fsdp._common_utils import (
_apply_to_modules,
@@ -1436,6 +1437,113 @@ def return_fn(fqn_to_param_info):
)


@dataclass
class StateInfo:
tensors: Dict[str, _PosDimTensorInfo]
scalar_tensors: Dict[str, torch.Tensor]
non_tensors: Dict[str, Any]


@dataclass
class AllGatherInfo:
tensors: List[torch.Tensor]
numels: List[int]
work: Optional[dist.Work]


def _all_gather_optim_state(
fsdp_state: _FSDPState, optim_state: Dict[str, Any], param_numel: int
) -> Dict[str, Any]:
"""
All-gathering state from all the ranks. This API is slow as it uses
``all_gather_object``. However, optim state_dict is not in the critical path.
We can fuse the communication across differnt state if the performance
becomes a problem.
"""
# All-gather the scalar tensor states, non-tensor states, and tensor metadata.
processed_state = StateInfo({}, {}, {})
for state_name, value in sorted_items(optim_state):
if torch.is_tensor(value):
if value.dim() == 0:
processed_state.scalar_tensors[state_name] = value
else:
processed_state.tensors[state_name] = _PosDimTensorInfo(
value.shape, value.dtype
)
else:
processed_state.non_tensors[state_name] = value
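# Placeholder entries only; ``all_gather_object`` overwrites each slot with
# the corresponding rank's ``StateInfo``.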
object_list: List[StateInfo] = [
processed_state for _ in range(fsdp_state.world_size)
]
dist.all_gather_object(object_list, processed_state)

# Convert the gathered, pre-processed state of each rank back to the original form.
gathered_state: Dict[str, Any] = {}

all_tensor_states = sorted(
list(set([n for state in object_list for n in state.tensors.keys()]))
)
for name in all_tensor_states:
numels = []
dtype = torch.float
max_numel = 0
for object_state in object_list:
numels.append(0)
info = object_state.tensors.get(name, None)
if info is not None:
numels[-1] = info.shape.numel()
dtype = info.dtype
max_numel = max(max_numel, numels[-1])
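# If this rank does not have this state, use an empty buffer of the right
# dtype so it still participates in the collective.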
local_state = (
optim_state[name]
if name in optim_state
else torch.empty(max_numel, dtype=dtype, device=fsdp_state.compute_device)
)
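# ``all_gather`` requires equally sized tensors on every rank, so pad the
# local state up to the maximum numel; the padding is trimmed after the gather.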
if max_numel > local_state.numel():
local_state = F.pad(local_state, [0, max_numel - local_state.numel()])
tensors = [
torch.empty(max_numel, dtype=dtype, device=fsdp_state.compute_device)
if rank != fsdp_state.rank
else local_state
for rank in range(len(object_list))
]
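# Launch the all-gather asynchronously so the collectives for different
# state names can overlap; the returned work handles are awaited below.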
work = dist.all_gather(
tensors, local_state, group=fsdp_state.process_group, async_op=True
)
gathered_state[name] = AllGatherInfo(tensors, numels, work)

for object_state in object_list:
for name, non_tensor_value in object_state.non_tensors.items():
curr_non_tensor_value = gathered_state.get(name, None)
assert (
curr_non_tensor_value is None
or curr_non_tensor_value == non_tensor_value
), f"Different ranks have different values for {name}."
gathered_state[name] = non_tensor_value

for name, scalar_tensor_value in object_state.scalar_tensors.items():
curr_scalar_tensor_value = gathered_state.get(name, None)
assert curr_scalar_tensor_value is None or torch.equal(
scalar_tensor_value, curr_scalar_tensor_value
), f"Different ranks have different values for {name}."
gathered_state[name] = scalar_tensor_value

for name, value in list(gathered_state.items()):
if not isinstance(value, AllGatherInfo):
continue
assert value.work is not None
value.work.wait()
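# Drop each rank's padding by trimming to its true numel before concatenating.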
gathered_state[name] = torch.cat(
[
rank_tensor[:rank_numel]
for rank_tensor, rank_numel in zip(value.tensors, value.numels)
if rank_numel > 0
]
)

return gathered_state


def _gather_orig_param_state(
fsdp_param_info: FSDPParamInfo,
fqn: str,
@@ -1458,51 +1566,18 @@ def _gather_orig_param_state(
):
return optim_state

# Gathering state from all ranks. This step may be slow. However,
# `state_dict()` is not in the critical path. We can fuse the communication
# if the performance becomes a problem.
state_objects = {
state_name: value for state_name, value in sorted_items(optim_state)
}
object_list: List[Dict[str, Any]] = [{} for _ in range(fsdp_state.world_size)]
dist.all_gather_object(object_list, state_objects)
orig_state: Dict[str, Any] = {}
for idx, state in enumerate(object_list):
for state_name, value in state.items():
curr_value = orig_state.get(state_name, [])
if torch.is_tensor(value):
if value.dim() > 0:
curr_value.append(value.to(fsdp_state.compute_device))
orig_state[state_name] = curr_value
else: # zero dim tensor, e.g., step.
if torch.is_tensor(curr_value):
assert torch.equal(curr_value, value)
else:
orig_state[state_name] = value
else:
assert curr_value == [] or curr_value == value
orig_state[state_name] = value
gathered_state = _all_gather_optim_state(
fsdp_state, optim_state, flat_param._numels[param_idx]
)

# Unflatten state values.
for state_name in orig_state.keys():
value = orig_state[state_name]
if not isinstance(value, list) or not torch.is_tensor(value[0]):
for state_name, value in list(gathered_state.items()):
if not torch.is_tensor(value) or value.dim() == 0:
continue
try:
value = torch.concat(value)[: flat_param._numels[param_idx]].reshape(
flat_param._shapes[param_idx]
)
except Exception as e:
raise Exception(
(
flat_param._numels[param_idx],
flat_param._shapes[param_idx],
len(value),
value[0].shape,
state_name,
fqn,
)
)

value = value[: flat_param._numels[param_idx]].reshape(
flat_param._shapes[param_idx]
)
if shard_state:
assert fsdp_state.process_group is not None
value = _ext_chunk_tensor(
@@ -1513,8 +1588,8 @@
fsdp_state.process_group,
)
value = value.cpu()
orig_state[state_name] = value
return orig_state
gathered_state[state_name] = value
return gathered_state


def _shard_orig_param_state(
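For reviewers unfamiliar with the pad/gather/trim idiom that ``_all_gather_optim_state`` uses for unevenly sized shards, the following minimal sketch shows the same three steps in isolation. The ``gather_uneven`` helper is illustrative only (not part of this PR) and assumes an already-initialized CPU/gloo process group and a 1-D local shard:

import torch
import torch.distributed as dist
import torch.nn.functional as F

def gather_uneven(local_shard: torch.Tensor) -> torch.Tensor:
    """Gather 1-D shards of different lengths from every rank (sketch)."""
    world_size = dist.get_world_size()
    # Step 1: exchange each rank's true numel so the padding can be trimmed later.
    numels = [0 for _ in range(world_size)]
    dist.all_gather_object(numels, local_shard.numel())
    max_numel = max(numels)
    # Step 2: pad to a common size, since all_gather needs equal-sized tensors.
    padded = F.pad(local_shard, [0, max_numel - local_shard.numel()])
    buffers = [
        torch.empty(max_numel, dtype=local_shard.dtype) for _ in range(world_size)
    ]
    dist.all_gather(buffers, padded)
    # Step 3: trim each rank's contribution back to its true numel and concatenate.
    return torch.cat([t[:n] for t, n in zip(buffers, numels)])

``_all_gather_optim_state`` applies this pattern once per state name, with the metadata exchange additionally carrying dtypes, scalar tensor states, and non-tensor states.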