Commit 1439cb0

fegin authored and pytorchmergebot committed
[FSDP][optim_state_dict][9/N] Rewrite the all-gather flow of optimizer state to support older GPUs (#91343)
Pull Request resolved: #91343
Approved by: https://github.com/rohan-varma
1 parent 46a81c8 commit 1439cb0
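
In outline, the rewritten flow no longer pushes optimizer state tensors through `dist.all_gather_object`; it all-gathers only lightweight metadata (shapes, dtypes, scalar and non-tensor state) that way, and moves the tensor payloads through ordinary `dist.all_gather` calls on equal-sized, zero-padded buffers. Below is a minimal, standalone sketch of that padded all-gather pattern; the helper name and setup are illustrative and not part of FSDP, and an already-initialized process group is assumed.

```python
# Minimal sketch of the padded all-gather pattern adopted by this commit.
# `gather_uneven` is a hypothetical helper, not an FSDP API; every rank must
# call it collectively under an initialized process group (e.g. via torchrun).
from typing import List

import torch
import torch.distributed as dist
import torch.nn.functional as F


def gather_uneven(local_state: torch.Tensor, device: torch.device) -> List[torch.Tensor]:
    """Gather 1-D tensors whose lengths may differ across ranks."""
    world_size = dist.get_world_size()

    # 1. Exchange only metadata (each rank's length) with all_gather_object.
    numels = [None] * world_size
    dist.all_gather_object(numels, local_state.numel())

    # 2. Zero-pad to the common maximum so a plain all_gather sees equal sizes.
    max_numel = max(numels)
    padded = F.pad(local_state, [0, max_numel - local_state.numel()])
    buffers = [
        torch.empty(max_numel, dtype=local_state.dtype, device=device)
        for _ in range(world_size)
    ]
    dist.all_gather(buffers, padded)  # local_state must live on `device`

    # 3. Trim each gathered buffer back to its true length.
    return [buf[:n] for buf, n in zip(buffers, numels)]
```

Each rank pads its shard to the common maximum length so the collective operates on identically sized buffers, and the true lengths exchanged in step 1 are used to trim the gathered results.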

File tree

3 files changed (+119, -50 lines):

  test/distributed/_composable/test_fully_shard.py
  test/distributed/fsdp/test_fsdp_optim_state.py
  torch/distributed/fsdp/_optim_utils.py

3 files changed

+119
-50
lines changed

test/distributed/_composable/test_fully_shard.py

Lines changed: 0 additions & 3 deletions

```diff
@@ -1,6 +1,5 @@
 # Owner(s): ["oncall: distributed"]
 
-import unittest
 import contextlib
 import copy
 import functools
@@ -743,7 +742,6 @@ def _test_optim_state_save_load(self, model1, optim1, model2, optim2) -> None:
             for key, value in group1.items():
                 self.assertEqual(value, group2[key])
 
-    @unittest.skip("The test currently fails on CI.")
    @skip_if_lt_x_gpu(2)
    def test_optim_state_dict_save_load(self):
        orig_model = CompositeParamModel(device=torch.device("cuda"))
@@ -755,7 +753,6 @@ def test_optim_state_dict_save_load(self):
 
         self._test_optim_state_save_load(orig_model, orig_optim, composable_model, composable_optim)
 
-    @unittest.skip("The test currently fails on CI.")
    @skip_if_lt_x_gpu(2)
    def test_optim_state_dict_submodule_fully_shard(self):
        orig_model = CompositeParamModel(device=torch.device("cuda"))
```

test/distributed/fsdp/test_fsdp_optim_state.py

Lines changed: 0 additions & 3 deletions

```diff
@@ -2,7 +2,6 @@
 
 import bisect
 import sys
-import unittest
 from enum import auto, Enum
 from typing import Any, Callable, Dict, List, Optional, Tuple, Type
 
@@ -783,7 +782,6 @@ def test_flatten_sharded_optim_state_dict_transformer(self) -> None:
             num_iters=3,
         )
 
-    @unittest.skip("The test currently fails on CI.")
    @skip_if_lt_x_gpu(2)
    def test_use_orig_params(self) -> None:
        """Tests :meth:`optim_state_dict` for an FSDP-root nested model."""
@@ -1442,7 +1440,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             loss.backward()
             optim.step()
 
-    @unittest.skip("The test currently fails on CI.")
    @skip_if_lt_x_gpu(2)
    def test_compatible_with_named_optimizer(self):
        class TestDummyModel(torch.nn.Module):
```
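
The re-enabled tests above cover the optimizer state dict save/load round trip. A rough, hypothetical driver for that flow, assuming the `FSDP.optim_state_dict` API these tests call (signatures in this area were still evolving at the time of this commit):

```python
# Hypothetical driver for the optimizer state dict path exercised by the
# re-enabled tests; run with torchrun on >= 2 GPUs (single-node assumption).
# The model and optimizer choices are placeholders.
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = FSDP(
    torch.nn.Linear(8, 8, device="cuda"),
    use_orig_params=True,  # the rewritten all-gather path handles per-parameter state
)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)

# One step so the optimizer actually has sharded state to gather.
model(torch.randn(4, 8, device="cuda")).sum().backward()
optim.step()

# Gathers each parameter's sharded optimizer state across ranks, using the
# metadata exchange plus padded all_gather introduced in this commit.
osd = FSDP.optim_state_dict(model, optim)
```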

torch/distributed/fsdp/_optim_utils.py

Lines changed: 119 additions & 44 deletions

```diff
@@ -19,6 +19,7 @@
 import torch.distributed as dist
 import torch.distributed.fsdp._traversal_utils as traversal_utils
 import torch.nn as nn
+import torch.nn.functional as F
 from torch.distributed._shard.sharded_tensor import ShardedTensor
 from torch.distributed.fsdp._common_utils import (
     _apply_to_modules,
```
```diff
@@ -1436,6 +1437,113 @@ def return_fn(fqn_to_param_info):
     )
 
 
+@dataclass
+class StateInfo:
+    tensors: Dict[str, _PosDimTensorInfo]
+    scalar_tensors: Dict[str, torch.Tensor]
+    non_tensors: Dict[str, Any]
+
+
+@dataclass
+class AllGatherInfo:
+    tensors: List[torch.Tensor]
+    numels: List[int]
+    work: Optional[dist.Work]
+
+
+def _all_gather_optim_state(
+    fsdp_state: _FSDPState, optim_state: Dict[str, Any], param_numel: int
+) -> Dict[str, Any]:
+    """
+    All-gather the optimizer state from all ranks. This API is slow because it
+    uses ``all_gather_object``. However, optim state_dict is not in the
+    critical path. We can fuse the communication across different states if
+    the performance becomes a problem.
+    """
+    # All-gather the scalar tensor state, non-tensor states, and tensor metadata.
+    processed_state = StateInfo({}, {}, {})
+    for state_name, value in sorted_items(optim_state):
+        if torch.is_tensor(value):
+            if value.dim() == 0:
+                processed_state.scalar_tensors[state_name] = value
+            else:
+                processed_state.tensors[state_name] = _PosDimTensorInfo(
+                    value.shape, value.dtype
+                )
+        else:
+            processed_state.non_tensors = value
+    object_list: List[StateInfo] = [
+        processed_state for _ in range(fsdp_state.world_size)
+    ]
+    dist.all_gather_object(object_list, processed_state)
+
+    # Convert the gathered, pre-processed state of each rank back to the original format.
+    gathered_state: Dict[str, Any] = {}
+
+    all_tensor_states = sorted(
+        list(set([n for state in object_list for n in state.tensors.keys()]))
+    )
+    for name in all_tensor_states:
+        numels = []
+        dtype = torch.float
+        max_numel = 0
+        for object_state in object_list:
+            numels.append(0)
+            info = object_state.tensors.get(name, None)
+            if info is not None:
+                numels[-1] = info.shape.numel()
+                dtype = info.dtype
+                max_numel = max(max_numel, numels[-1])
+        local_state = (
+            optim_state[name]
+            if name in optim_state
+            else torch.empty(max_numel, dtype=dtype, device=fsdp_state.compute_device)
+        )
+        if max_numel > local_state.numel():
+            local_state = F.pad(local_state, [0, max_numel - local_state.numel()])
+        tensors = [
+            torch.empty(max_numel, dtype=dtype, device=fsdp_state.compute_device)
+            if rank != fsdp_state.rank
+            else local_state
+            for rank in range(len(object_list))
+        ]
+        work = dist.all_gather(
+            tensors, local_state, group=fsdp_state.process_group, async_op=True
+        )
+        gathered_state[name] = AllGatherInfo(tensors, numels, work)
+
+    for object_state in object_list:
+        for name, non_tensor_value in object_state.non_tensors.items():
+            curr_non_tensor_value = gathered_state.get(name, None)
+            assert (
+                curr_non_tensor_value is None
+                or curr_non_tensor_value == non_tensor_value
+            ), f"Different ranks have different values for {name}."
+            gathered_state[name] = non_tensor_value
+
+        for name, scalar_tensor_value in object_state.scalar_tensors.items():
+            curr_scalar_tensor_value = gathered_state.get(name, None)
+            assert curr_scalar_tensor_value is None or torch.equal(
+                scalar_tensor_value, curr_scalar_tensor_value
+            ), f"Different ranks have different values for {name}."
+            gathered_state[name] = scalar_tensor_value
+
+    for name, value in list(gathered_state.items()):
+        if not isinstance(value, AllGatherInfo):
+            continue
+        assert value.work is not None
+        value.work.wait()
+        gathered_state[name] = torch.cat(
+            [
+                rank_tensor[:rank_numel]
+                for rank_tensor, rank_numel in zip(value.tensors, value.numels)
+                if rank_numel > 0
+            ]
+        )
+
+    return gathered_state
+
+
 def _gather_orig_param_state(
     fsdp_param_info: FSDPParamInfo,
     fqn: str,
```
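
One detail of `_all_gather_optim_state` above: a separate all-gather is launched with `async_op=True` for each tensor state name (for Adam, typically `exp_avg` and `exp_avg_sq`), and the returned `Work` handles are only waited on afterwards, so the collectives can overlap. A small sketch of that issue-then-wait pattern, with hypothetical state names and an already-initialized process group assumed:

```python
# Sketch of overlapping several async all-gathers and consuming them afterwards.
# Assumes dist.init_process_group() has run and `states` holds equal-length
# CUDA tensors on every rank (padding, as above, is what makes lengths equal).
import torch
import torch.distributed as dist

states = {
    "exp_avg": torch.randn(16, device="cuda"),
    "exp_avg_sq": torch.randn(16, device="cuda"),
}

pending = {}
for name, shard in states.items():
    bufs = [torch.empty_like(shard) for _ in range(dist.get_world_size())]
    work = dist.all_gather(bufs, shard, async_op=True)  # returns immediately
    pending[name] = (bufs, work)

gathered = {}
for name, (bufs, work) in pending.items():
    work.wait()                       # block until this collective finishes
    gathered[name] = torch.cat(bufs)  # flat state concatenated across ranks
```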
```diff
@@ -1458,51 +1566,18 @@ def _gather_orig_param_state(
     ):
         return optim_state
 
-    # Gathering state from all ranks. This step may be slow. However,
-    # `state_dict()` is not in the critical path. We can fuse the communication
-    # if the performance becomes a problem.
-    state_objects = {
-        state_name: value for state_name, value in sorted_items(optim_state)
-    }
-    object_list: List[Dict[str, Any]] = [{} for _ in range(fsdp_state.world_size)]
-    dist.all_gather_object(object_list, state_objects)
-    orig_state: Dict[str, Any] = {}
-    for idx, state in enumerate(object_list):
-        for state_name, value in state.items():
-            curr_value = orig_state.get(state_name, [])
-            if torch.is_tensor(value):
-                if value.dim() > 0:
-                    curr_value.append(value.to(fsdp_state.compute_device))
-                    orig_state[state_name] = curr_value
-                else:  # zero dim tensor, e.g., step.
-                    if torch.is_tensor(curr_value):
-                        assert torch.equal(curr_value, value)
-                    else:
-                        orig_state[state_name] = value
-            else:
-                assert curr_value == [] or curr_value == value
-                orig_state[state_name] = value
+    gathered_state = _all_gather_optim_state(
+        fsdp_state, optim_state, flat_param._numels[param_idx]
+    )
 
     # Unflatten state values.
-    for state_name in orig_state.keys():
-        value = orig_state[state_name]
-        if not isinstance(value, list) or not torch.is_tensor(value[0]):
+    for state_name, value in list(gathered_state.items()):
+        if not torch.is_tensor(value) or value.dim() == 0:
             continue
-        try:
-            value = torch.concat(value)[: flat_param._numels[param_idx]].reshape(
-                flat_param._shapes[param_idx]
-            )
-        except Exception as e:
-            raise Exception(
-                (
-                    flat_param._numels[param_idx],
-                    flat_param._shapes[param_idx],
-                    len(value),
-                    value[0].shape,
-                    state_name,
-                    fqn,
-                )
-            )
+
+        value = value[: flat_param._numels[param_idx]].reshape(
+            flat_param._shapes[param_idx]
+        )
         if shard_state:
             assert fsdp_state.process_group is not None
             value = _ext_chunk_tensor(
```
```diff
@@ -1513,8 +1588,8 @@ def _gather_orig_param_state(
                 fsdp_state.process_group,
             )
             value = value.cpu()
-        orig_state[state_name] = value
-    return orig_state
+        gathered_state[state_name] = value
+    return gathered_state
 
 
 def _shard_orig_param_state(
```
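
Back in `_gather_orig_param_state`, the gathered flat tensor for each state is sliced to `flat_param._numels[param_idx]` (dropping any trailing elements, e.g. alignment padding of the flat parameter) and reshaped to `flat_param._shapes[param_idx]`; with `shard_state` it is then re-chunked via `_ext_chunk_tensor` and moved to CPU. A toy illustration of the slice-and-reshape arithmetic with made-up sizes:

```python
# Toy illustration of the unflatten step with made-up sizes: a parameter of
# shape (3, 2) whose gathered optimizer state arrives as a flat tensor that
# may carry a couple of trailing padding elements.
import torch

numel, shape = 6, (3, 2)          # would come from flat_param._numels / _shapes
gathered_flat = torch.arange(8.)  # 6 real elements plus 2 elements of padding

exp_avg = gathered_flat[:numel].reshape(shape)
assert exp_avg.shape == (3, 2)
```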
