Commit dabe788
Author: Andrew Gu

[FSDP] Add initial summon_full_params(with_grads=True)

ghstack-source-id: 9c80c33
Pull Request resolved: #85738

1 parent 4a8b722 commit dabe788
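
For context, a minimal usage sketch of the new `with_grads=True` flag, inferred from the tests added in this commit. The helper name, its arguments, and the scaling factor are illustrative assumptions, not part of the change:

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def scale_full_grads(fsdp_model: FSDP, inp: torch.Tensor, factor: float = 2.0) -> None:
    # Hypothetical helper (not part of this commit): run one backward pass,
    # then edit the *full* (unsharded) gradients in place on every rank.
    fsdp_model(inp).sum().backward()
    # `with_grads=True` gathers gradients alongside parameters; per the tests
    # added below, it currently requires `use_orig_params=True` and
    # `offload_to_cpu=False`.
    with FSDP.summon_full_params(fsdp_model, writeback=True, with_grads=True):
        for param in fsdp_model.parameters():
            if param.grad is not None:
                # With `writeback=True` (or `ShardingStrategy.NO_SHARD`), these
                # in-place edits persist to the sharded gradients on exit.
                param.grad.mul_(factor)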

File tree: 3 files changed (+277, -24 lines)

test/distributed/fsdp/test_fsdp_summon_full_params.py

Lines changed: 123 additions & 2 deletions
@@ -1,25 +1,28 @@
 # Owner(s): ["oncall: distributed"]
+import contextlib
 import itertools
 import math
 import sys
 from copy import deepcopy
-from typing import Optional
+from typing import List, Optional
 
 import torch
 import torch.nn as nn
 from torch import distributed as dist
 from torch.distributed.fsdp import CPUOffload
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
-from torch.distributed.fsdp import MixedPrecision
+from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
 from torch.distributed.fsdp.flat_param import FlatParamHandle
 from torch.distributed.fsdp.wrap import enable_wrap, wrap
+from torch.nn.parallel.distributed import DistributedDataParallel as DDP
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
     CUDAInitMode,
     DeterministicModel,
     FSDPInitMode,
     FSDPTest,
     NestedWrappedModule,
+    TransformerWithSharedParams,
 )
 from torch.testing._internal.common_utils import (
     TEST_WITH_DEV_DBG_ASAN,
@@ -578,6 +581,124 @@ def test_named_parameters_buffers(self, prefix: str, recurse: bool):
             self.assertEqual(n1, n2)
             self.assertEqual(p1, p2)
 
+    @skip_if_lt_x_gpu(2)
+    def test_with_grads(self):
+        self.run_subtests(
+            {
+                "writeback": [False, True],
+                "offload_to_cpu": [False, True],
+                "sharding_strategy": [
+                    ShardingStrategy.FULL_SHARD,
+                    ShardingStrategy.SHARD_GRAD_OP,
+                    ShardingStrategy.NO_SHARD,
+                ],
+                "use_orig_params": [True],
+            },
+            self._test_with_grads,
+        )
+
+    def _test_with_grads(
+        self,
+        writeback: bool,
+        offload_to_cpu: bool,
+        sharding_strategy: ShardingStrategy,
+        use_orig_params: bool,
+    ):
+        def _check_grads(
+            ddp_model: DDP,
+            fsdp_model: FSDP,
+            old_fsdp_grads: Optional[List[torch.Tensor]],
+        ):
+            WRITEBACK_FACTOR = 2
+            with FSDP.summon_full_params(
+                fsdp_model,
+                writeback=writeback,
+                offload_to_cpu=offload_to_cpu,
+                with_grads=True,
+            ):
+                for (n1, p1), (n2, p2) in zip(
+                    ddp_model.module.named_parameters(),
+                    fsdp_model.named_parameters(),
+                ):
+                    # Parameter names are only expected to match because
+                    # `fsdp_model` has top-level FSDP, so its
+                    # `named_parameters()` cleans *all* of the names
+                    self.assertEqual(n1, n2)
+                    assert p1.grad is not None
+                    torch.testing.assert_close(p1.grad, p2.grad)
+                    # Ensure that the tensor is not all zeros, which would
+                    # mean that the multiplication is vacuous
+                    assert torch.count_nonzero(p2.grad) > 0
+                    p2.grad *= WRITEBACK_FACTOR
+            new_fsdp_grads = [
+                param.grad for param in fsdp_model.parameters()
+                if param.grad is not None
+            ]
+            writeback_persists = writeback or sharding_strategy == ShardingStrategy.NO_SHARD
+            for old_grad, new_grad in zip(old_fsdp_grads, new_fsdp_grads):
+                if writeback_persists:
+                    torch.testing.assert_close(old_grad * WRITEBACK_FACTOR, new_grad)
+                else:
+                    torch.testing.assert_close(old_grad, new_grad)
+            if writeback_persists:
+                # Modify the DDP gradients for parity
+                for param in ddp_model.parameters():
+                    param.grad *= WRITEBACK_FACTOR
+
+        def _get_error_context(is_supported: bool):
+            return (
+                contextlib.suppress() if is_supported
+                else self.assertRaises(NotImplementedError)
+            )  # some configs not implemented yet
+
+        def _get_fsdp_grads(fsdp_model: FSDP, is_supported: bool):
+            if is_supported:
+                return [
+                    param.grad.clone() for param in fsdp_model.parameters()
+                    if param.grad is not None
+                ]
+            return None  # unused
+
+        is_supported = use_orig_params and not offload_to_cpu
+        model = TransformerWithSharedParams.init(
+            self.process_group,
+            FSDPInitMode.NO_FSDP,
+            CUDAInitMode.CUDA_BEFORE,
+            deterministic=True,
+        )
+        ddp_model = DDP(model, device_ids=[self.rank])
+        fsdp_model = TransformerWithSharedParams.init(
+            self.process_group,
+            FSDPInitMode.RECURSIVE,
+            CUDAInitMode.CUDA_BEFORE,
+            deterministic=True,
+            fsdp_kwargs={
+                "use_orig_params": use_orig_params,
+                "sharding_strategy": sharding_strategy,
+            },
+        )
+        with FSDP.summon_full_params(fsdp_model):
+            for p1, p2 in zip(ddp_model.module.parameters(), fsdp_model.parameters()):
+                assert torch.all(torch.isclose(p1, p2))
+
+        # Check `summon_full_params()` after backward
+        inp = fsdp_model.get_input(torch.device("cuda"))
+        ddp_out = ddp_model(*inp)
+        fsdp_out = fsdp_model(*inp)
+        ddp_out.sum().backward()
+        fsdp_out.sum().backward()
+        old_fsdp_grads = _get_fsdp_grads(fsdp_model, is_supported)
+        with _get_error_context(is_supported):
+            _check_grads(ddp_model, fsdp_model, old_fsdp_grads)
+
+        # Check `summon_full_params()` between forward and backward
+        inp = fsdp_model.get_input(torch.device("cuda"))
+        ddp_out = ddp_model(*inp)
+        fsdp_out = fsdp_model(*inp)
+        old_fsdp_grads = _get_fsdp_grads(fsdp_model, is_supported)
+        with _get_error_context(is_supported):
+            _check_grads(ddp_model, fsdp_model, old_fsdp_grads)
+
 
 instantiate_parametrized_tests(TestSummonFullParams)
 instantiate_parametrized_tests(TestSummonFullParamsNoShard)
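
As `_get_error_context` above indicates, configurations outside `use_orig_params=True` with `offload_to_cpu=False` are expected to raise `NotImplementedError`. A caller could probe support with a sketch like the following (the helper name is hypothetical, not part of this commit):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def supports_with_grads(fsdp_model: FSDP) -> bool:
    # Hypothetical probe: report whether `with_grads=True` works for this
    # model's configuration.
    try:
        with FSDP.summon_full_params(fsdp_model, with_grads=True):
            pass
    except NotImplementedError:
        # Per the test above, e.g. `use_orig_params=False` or
        # `offload_to_cpu=True` are not implemented yet.
        return False
    return True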

torch/distributed/fsdp/flat_param.py

Lines changed: 93 additions & 16 deletions
@@ -846,6 +846,37 @@ def _free_low_precision_sharded_param(self):
         self._check_low_precision_shard()
         _free_storage(self.flat_param._mp_shard)  # type: ignore[attr-defined]
 
+    @torch.no_grad()
+    def unshard_grad(self):
+        if not self.uses_sharded_strategy:
+            self._use_unsharded_grad_views()
+            return
+        flat_param = self.flat_param
+        self._check_unsharded(flat_param)
+        padded_unsharded_grad = torch.empty(
+            flat_param._padded_unsharded_size,  # type: ignore[attr-defined]
+            device=self.device,
+        )
+        if flat_param.grad is None:
+            flat_param._saved_grad_shard = None  # type: ignore[attr-defined]
+            sharded_grad = torch.zeros_like(flat_param)  # type: ignore[attr-defined]
+        else:
+            self._check_sharded(flat_param.grad)
+            flat_param._saved_grad_shard = flat_param.grad  # type: ignore[attr-defined]
+            sharded_grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
+        dist._all_gather_base(padded_unsharded_grad, sharded_grad, self.process_group)
+        unsharded_size = self.flat_param._unpadded_unsharded_size
+        flat_param.grad = padded_unsharded_grad[:unsharded_size.numel()].view(unsharded_size)
+        self._use_unsharded_grad_views()
+
+    def reshard_grad(self):
+        if self._use_orig_params:
+            self._use_sharded_grad_views()
+        if not self.uses_sharded_strategy:
+            return
+        self.flat_param.grad = self.flat_param._saved_grad_shard  # type: ignore[attr-defined]
+        delattr(self.flat_param, "_saved_grad_shard")
+
     def prepare_gradient_for_backward(self):
         """
         Prepares the gradient for the backward computation by saving and
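
For reference, a standalone sketch of the gradient all-gather that the new `unshard_grad()` performs, assuming every rank holds an equally sized, padded gradient shard (`gather_full_grad` and its signature are illustrative, not FSDP internals):

import torch
import torch.distributed as dist


def gather_full_grad(
    sharded_grad: torch.Tensor,
    unsharded_size: torch.Size,
    process_group: dist.ProcessGroup,
) -> torch.Tensor:
    # Each rank contributes an equally sized (padded) shard; the concatenation
    # of all shards is the padded unsharded gradient.
    world_size = dist.get_world_size(process_group)
    padded_full = torch.empty(
        sharded_grad.numel() * world_size,
        dtype=sharded_grad.dtype,
        device=sharded_grad.device,
    )
    dist._all_gather_base(padded_full, sharded_grad, process_group)
    # Drop the tail padding and restore the original (unsharded) shape,
    # mirroring the slice-and-view in `unshard_grad()` above.
    return padded_full[: unsharded_size.numel()].view(unsharded_size)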
@@ -1093,7 +1124,7 @@ def _use_unsharded_views(self, as_params: bool) -> None:
         be used during forward/backward computation and when hiding the
         original parameters from :meth:`nn.Module.named_parameters`.
         """
-        self._check_unsharded()
+        self._check_unsharded(self.flat_param)
         views = self._get_unflat_views(self.flat_param)
         for i, (view, (param_name, module, _)) in enumerate(
             zip(views, self.flat_param._param_infos)
@@ -1139,6 +1170,41 @@ def _use_unsharded_views(self, as_params: bool) -> None:
             else:
                 setattr(module, param_name, prim_param)
 
+    def _use_unsharded_grad_views(self) -> None:
+        """
+        Unflattens the unsharded flattened parameter's gradient by setting the
+        original module parameter variables' gradients to be views into it.
+        """
+        # Expects the gradient to be in `flat_param.grad`
+        if self.flat_param.grad is None:
+            return
+        self._check_unsharded(self.flat_param.grad)
+        views = self._get_unflat_views(self.flat_param, self.flat_param.grad)
+        for i, (view, (param_name, module, _)) in enumerate(
+            zip(views, self.flat_param._param_infos)
+        ):
+            p_assert(
+                hasattr(module, param_name),
+                f"{self.flat_param._prefixed_param_names[i]} is missing",
+            )
+            param = getattr(module, param_name)
+            param.grad = view
+        for i, (
+            param_name,
+            module,
+            module_name,
+            prim_param_name,
+            prim_module,
+            _,
+        ) in enumerate(self.flat_param._shared_param_infos):
+            p_assert(
+                hasattr(module, param_name),
+                f"{module_name + '.' + param_name if module_name else param_name} is missing",
+            )  # did not save prefixed name
+            param = getattr(module, param_name)
+            prim_param = getattr(prim_module, prim_param_name)
+            param.grad = prim_param.grad
+
     @contextlib.contextmanager
     def unflatten_as_params(self) -> Generator:
         """
@@ -1223,16 +1289,7 @@ def _use_sharded_grad_views(self) -> None:
         """
         flat_param = self.flat_param
         self._check_sharded(flat_param)
-        # Priority: `_cpu_grad` > `_saved_grad_shard` > `grad`
-        # - CPU offloading: `_cpu_grad`
-        # - No CPU offloading + sharded strategies: `_saved_grad_shard`
-        # - No CPU offloading + `NO_SHARD`: `grad`
-        if hasattr(flat_param, "_cpu_grad"):
-            grad = flat_param._cpu_grad  # type: ignore[attr-defined]
-        elif hasattr(flat_param, "_saved_grad_shard"):
-            grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
-        else:
-            grad = flat_param.grad
+        grad = self.sharded_grad
         if grad is None:
             return  # no-op
         self._check_sharded(grad)
@@ -1474,6 +1531,26 @@ def parameter_module_names(self) -> Iterator[Tuple[str, str]]:
         ):
             yield (param_name, module_name)
 
+    @property
+    def sharded_grad(self) -> Optional[Tensor]:
+        """Returns the handle's sharded gradient."""
+        flat_param = self.flat_param
+        # Priority for non-`None`: `_cpu_grad` > `_saved_grad_shard` > `grad`
+        # - CPU offloading: `_cpu_grad`
+        # - No CPU offloading + sharded strategies: `_saved_grad_shard`
+        # - No CPU offloading + `NO_SHARD`: `grad`
+        if hasattr(flat_param, "_cpu_grad"):
+            grad = flat_param._cpu_grad  # type: ignore[attr-defined]
+        elif hasattr(flat_param, "_saved_grad_shard"):
+            grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
+        else:
+            p_assert(
+                flat_param.grad is None or not self.uses_sharded_strategy,
+                "Sharded strategies should use `_cpu_grad` or `_saved_grad_shard`",
+            )
+            grad = flat_param.grad
+        return grad
+
     #######################
     # CHECKS & INVARIANTS #
     #######################
@@ -1520,13 +1597,13 @@ def _check_low_precision_shard(self):
             f"Expects the low precision shard to be on {self.device} but got {device}",
         )
 
-    def _check_unsharded(self):
-        msg_prefix = "Expects the flattened parameter to be unsharded "
-        p_assert(self.flat_param is not None, msg_prefix + "but got `None`")
+    def _check_unsharded(self, tensor: Tensor):
+        msg_prefix = "Expects tensor to be unsharded "
+        p_assert(tensor is not None, msg_prefix + "but got `None`")
         unsharded_size = self.flat_param._unpadded_unsharded_size
         p_assert(
-            self.flat_param.size() == unsharded_size,
-            msg_prefix + f"with size {unsharded_size} but got {self.flat_param.size()}",
+            tensor.size() == unsharded_size,
+            msg_prefix + f"with size {unsharded_size} but got {tensor.size()}",
         )
 
     def _check_sharded(self, tensor: Tensor):
