
Commit 2dca88f

[FSDP][optim_state_dict][9/N] Specially treat zero-dim tensors to ensure all_gather_object works correctly on older GPUs
ghstack-source-id: 9bd47e7 Pull Request resolved: #91343
1 parent ddf3976 commit 2dca88f

3 files changed: 74 additions, 50 deletions
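For context, the fix below boils down to a scalar-plus-dtype round trip around the object collective: zero-dim tensors (e.g. Adam's "step") are shipped as Python scalars with their dtype recorded under a marker key, and rebuilt after the gather. A minimal single-process sketch of that round trip (the encode/decode helpers are illustrative stand-ins, not the committed API, which wraps dist.all_gather_object):

    # A minimal sketch (no process group needed) of the round trip this commit
    # applies around ``all_gather_object``: zero-dim tensors are shipped as
    # Python scalars plus their dtype, then rebuilt after the gather.
    import torch

    IS_ZERO_DIM_TENSOR_KEY = "__is_zero_dim_tensor"  # marker key used by the new helper

    def encode(state):  # illustrative helper, not part of the PR
        out = {}
        for name, value in state.items():
            if torch.is_tensor(value) and value.dim() == 0:
                out[name] = value.item()
                out[f"{name}{IS_ZERO_DIM_TENSOR_KEY}"] = value.dtype
            else:
                out[name] = value
        return out

    def decode(obj):  # illustrative helper, not part of the PR
        out = {}
        for name, value in obj.items():
            if IS_ZERO_DIM_TENSOR_KEY in name:
                continue
            dtype = obj.get(f"{name}{IS_ZERO_DIM_TENSOR_KEY}")
            out[name] = torch.tensor(value, dtype=dtype) if dtype is not None else value
        return out

    state = {"step": torch.tensor(3.0), "lr": 1e-3}
    assert torch.equal(decode(encode(state))["step"], state["step"])

Sending plain Python scalars instead of zero-dim device tensors is what, per the commit title, avoids the all_gather_object misbehavior seen on older GPUs.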

test/distributed/_composable/test_fully_shard.py
Lines changed: 0 additions & 3 deletions

@@ -1,6 +1,5 @@
 # Owner(s): ["oncall: distributed"]
 
-import unittest
 import contextlib
 import copy
 import functools
@@ -743,7 +742,6 @@ def _test_optim_state_save_load(self, model1, optim1, model2, optim2) -> None:
         for key, value in group1.items():
             self.assertEqual(value, group2[key])
 
-    @unittest.skip("The test currently fails on CI.")
     @skip_if_lt_x_gpu(2)
     def test_optim_state_dict_save_load(self):
         orig_model = CompositeParamModel(device=torch.device("cuda"))
@@ -755,7 +753,6 @@ def test_optim_state_dict_save_load(self):
 
         self._test_optim_state_save_load(orig_model, orig_optim, composable_model, composable_optim)
 
-    @unittest.skip("The test currently fails on CI.")
     @skip_if_lt_x_gpu(2)
     def test_optim_state_dict_submodule_fully_shard(self):
         orig_model = CompositeParamModel(device=torch.device("cuda"))

test/distributed/fsdp/test_fsdp_optim_state.py
Lines changed: 0 additions & 3 deletions

@@ -2,7 +2,6 @@
 
 import bisect
 import sys
-import unittest
 from enum import auto, Enum
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type
 
@@ -783,7 +782,6 @@ def test_flatten_sharded_optim_state_dict_transformer(self) -> None:
             num_iters=3,
         )
 
-    @unittest.skip("The test currently fails on CI.")
     @skip_if_lt_x_gpu(2)
     def test_use_orig_params(self) -> None:
         """Tests :meth:`optim_state_dict` for an FSDP-root nested model."""
@@ -1442,7 +1440,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             loss.backward()
             optim.step()
 
-    @unittest.skip("The test currently fails on CI.")
     @skip_if_lt_x_gpu(2)
     def test_compatible_with_named_optimizer(self):
         class TestDummyModel(torch.nn.Module):

torch/distributed/fsdp/_optim_utils.py
Lines changed: 74 additions & 44 deletions

@@ -1436,6 +1436,71 @@ def return_fn(fqn_to_param_info):
     )
 
 
+def _all_gather_optim_state(
+    fsdp_state: _FSDPState,
+    optim_state: Dict[str, Any],
+) -> Dict[str, Any]:
+    """
+    All-gather state from all the ranks. This API is slow as it uses
+    ``all_gather_object``. However, optim state_dict is not in the critical path.
+    We can fuse the communication across different states if the performance
+    becomes a problem.
+    """
+
+    # Pre-process the state to prepare for the all_gather_object call.
+    IS_ZERO_DIM_TENSOR_KEY = "__is_zero_dim_tensor"
+    processed_state: Dict[str, Any] = {}
+    for state_name, value in sorted_items(optim_state):
+        if torch.is_tensor(value):
+            if value.dim() == 0:
+                processed_state[state_name] = value.item()
+                processed_state[f"{state_name}{IS_ZERO_DIM_TENSOR_KEY}"] = value.dtype
+            else:
+                processed_state[state_name] = value.to(fsdp_state.compute_device)
+        else:
+            processed_state[state_name] = value
+
+    # All-gather the state.
+    object_list: List[Dict[str, Any]] = [{} for _ in range(fsdp_state.world_size)]
+    dist.all_gather_object(object_list, processed_state)
+
+    # Convert the gathered, pre-processed state of each rank back to the original one.
+    gathered_state: Dict[str, Any] = {}
+    for object_state in object_list:
+        for name, object_value in object_state.items():
+            if IS_ZERO_DIM_TENSOR_KEY in name:
+                continue
+            curr_object_value = gathered_state.get(name, None)
+            dtype = object_state.get(f"{name}{IS_ZERO_DIM_TENSOR_KEY}", None)
+            if dtype is not None:
+                zero_dim_tensor = torch.tensor(object_value, dtype=dtype)
+                if curr_object_value is not None:
+                    assert torch.equal(
+                        zero_dim_tensor, curr_object_value
+                    ), f"Different ranks have different values for {name}."
+                else:
+                    gathered_state[name] = zero_dim_tensor
+            elif torch.is_tensor(object_value):
+                if curr_object_value is not None:
+                    curr_object_value.append(object_value.to(fsdp_state.compute_device))
+                else:
+                    gathered_state[name] = [object_value.to(fsdp_state.compute_device)]
+            else:
+                if curr_object_value is not None:
+                    assert (
+                        curr_object_value == object_value
+                    ), f"Different ranks have different values for {name}."
+                else:
+                    gathered_state[name] = object_value
+
+    for name, value in list(gathered_state.items()):
+        if not isinstance(value, list) or not torch.is_tensor(value[0]):
+            continue
+        gathered_state[name] = torch.cat(value)
+
+    return gathered_state
+
+
 def _gather_orig_param_state(
     fsdp_param_info: FSDPParamInfo,
     fqn: str,
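The second half of the new helper merges the per-rank dictionaries: positive-dim tensor shards are accumulated into per-name lists and concatenated, while scalar-like entries are only checked for cross-rank agreement. A rough single-process sketch of that merge, with fake_object_list standing in for the output of dist.all_gather_object on a hypothetical 2-rank job:

    # Illustrative sketch of how the gathered per-rank dicts are merged:
    # positive-dim tensors are collected into per-name lists and concatenated,
    # while values that must match across ranks are only sanity-checked.
    import torch
    from typing import Any, Dict, List

    fake_object_list: List[Dict[str, Any]] = [
        {"exp_avg": torch.zeros(4), "step": 10},  # rank 0's shard
        {"exp_avg": torch.ones(4), "step": 10},   # rank 1's shard
    ]

    gathered: Dict[str, Any] = {}
    for rank_state in fake_object_list:
        for name, value in rank_state.items():
            if torch.is_tensor(value) and value.dim() > 0:
                gathered.setdefault(name, []).append(value)
            else:
                assert gathered.setdefault(name, value) == value, (
                    f"Different ranks have different values for {name}."
                )

    # Concatenate the per-rank shards into one flat tensor per state name.
    for name, value in list(gathered.items()):
        if isinstance(value, list) and torch.is_tensor(value[0]):
            gathered[name] = torch.cat(value)

    assert gathered["exp_avg"].shape == (8,) and gathered["step"] == 10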
@@ -1458,51 +1523,16 @@ def _gather_orig_param_state(
     ):
         return optim_state
 
-    # Gathering state from all ranks. This step may be slow. However,
-    # `state_dict()` is not in the critical path. We can fuse the communication
-    # if the performance becomes a problem.
-    state_objects = {
-        state_name: value for state_name, value in sorted_items(optim_state)
-    }
-    object_list: List[Dict[str, Any]] = [{} for _ in range(fsdp_state.world_size)]
-    dist.all_gather_object(object_list, state_objects)
-    orig_state: Dict[str, Any] = {}
-    for idx, state in enumerate(object_list):
-        for state_name, value in state.items():
-            curr_value = orig_state.get(state_name, [])
-            if torch.is_tensor(value):
-                if value.dim() > 0:
-                    curr_value.append(value.to(fsdp_state.compute_device))
-                    orig_state[state_name] = curr_value
-                else:  # zero dim tensor, e.g., step.
-                    if torch.is_tensor(curr_value):
-                        assert torch.equal(curr_value, value)
-                    else:
-                        orig_state[state_name] = value
-            else:
-                assert curr_value == [] or curr_value == value
-                orig_state[state_name] = value
+    gathered_state = _all_gather_optim_state(fsdp_state, optim_state)
 
     # Unflatten state values.
-    for state_name in orig_state.keys():
-        value = orig_state[state_name]
-        if not isinstance(value, list) or not torch.is_tensor(value[0]):
+    for state_name, value in list(gathered_state.items()):
+        if not torch.is_tensor(value) or value.dim() == 0:
             continue
-        try:
-            value = torch.concat(value)[: flat_param._numels[param_idx]].reshape(
-                flat_param._shapes[param_idx]
-            )
-        except Exception as e:
-            raise Exception(
-                (
-                    flat_param._numels[param_idx],
-                    flat_param._shapes[param_idx],
-                    len(value),
-                    value[0].shape,
-                    state_name,
-                    fqn,
-                )
-            )
+
+        value = value[: flat_param._numels[param_idx]].reshape(
+            flat_param._shapes[param_idx]
+        )
         if shard_state:
             assert fsdp_state.process_group is not None
             value = _ext_chunk_tensor(
@@ -1513,8 +1543,8 @@ def _gather_orig_param_state(
                 fsdp_state.process_group,
             )
             value = value.cpu()
-        orig_state[state_name] = value
-    return orig_state
+        gathered_state[state_name] = value
+    return gathered_state
 
 
 def _shard_orig_param_state(
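The unflatten step that remains in _gather_orig_param_state trims the concatenated state to flat_param._numels[param_idx] and reshapes it to flat_param._shapes[param_idx]; the trim is needed because FSDP pads the flat parameter so that it splits evenly across ranks. A toy illustration with made-up sizes (a hypothetical 3x3 parameter gathered from 2 ranks):

    # Toy illustration of the unflatten step after the gather: the concatenated
    # state is first trimmed to the parameter's true numel and then reshaped to
    # its original shape. The shard layout below is an assumption for the example.
    import torch

    orig_shape = (3, 3)  # stands in for flat_param._shapes[param_idx]
    numel = 9            # stands in for flat_param._numels[param_idx]

    # Each of the 2 ranks holds a padded shard of 5 elements (9 real + 1 pad).
    shards = [
        torch.arange(0, 5, dtype=torch.float32),
        torch.cat([torch.arange(5, 9, dtype=torch.float32), torch.zeros(1)]),
    ]

    value = torch.cat(shards)                  # gathered flat state, length 10
    value = value[:numel].reshape(orig_shape)  # drop padding, restore 3x3
    assert value.shape == orig_shape and value[2, 2] == 8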
