@@ -846,6 +846,37 @@ def _free_low_precision_sharded_param(self):
         self._check_low_precision_shard()
         _free_storage(self.flat_param._mp_shard)  # type: ignore[attr-defined]

+    @torch.no_grad()
+    def unshard_grad(self):
+        if not self.uses_sharded_strategy:
+            self._use_unsharded_grad_views()
+            return
+        flat_param = self.flat_param
+        self._check_unsharded(flat_param)
+        padded_unsharded_grad = torch.empty(
+            flat_param._padded_unsharded_size,  # type: ignore[attr-defined]
+            device=self.device,
+        )
+        if flat_param.grad is None:
+            flat_param._saved_grad_shard = None  # type: ignore[attr-defined]
+            sharded_grad = torch.zeros_like(flat_param)  # type: ignore[attr-defined]
+        else:
+            self._check_sharded(flat_param.grad)
+            flat_param._saved_grad_shard = flat_param.grad  # type: ignore[attr-defined]
+            sharded_grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
+        dist._all_gather_base(padded_unsharded_grad, sharded_grad, self.process_group)
+        unsharded_size = self.flat_param._unpadded_unsharded_size
+        flat_param.grad = padded_unsharded_grad[:unsharded_size.numel()].view(unsharded_size)
+        self._use_unsharded_grad_views()
+
+    def reshard_grad(self):
+        if self._use_orig_params:
+            self._use_sharded_grad_views()
+        if not self.uses_sharded_strategy:
+            return
+        self.flat_param.grad = self.flat_param._saved_grad_shard  # type: ignore[attr-defined]
+        delattr(self.flat_param, "_saved_grad_shard")
+
     def prepare_gradient_for_backward(self):
         """
         Prepares the gradient for the backward computation by saving and
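The `unshard_grad()` / `reshard_grad()` pair added above all-gathers each rank's sharded gradient into `flat_param.grad` (saving the shard in `_saved_grad_shard`), installs per-parameter gradient views, and later restores the saved shard. A minimal caller sketch, assuming the flat parameter is already unsharded when `unshard_grad()` runs and that `handles` is an iterable of `FlatParamHandle` objects; the helper itself is hypothetical and not part of this diff:

    import contextlib

    @contextlib.contextmanager
    def _unsharded_grads(handles):
        # Expose full (unsharded) gradients on the original parameters.
        for handle in handles:
            handle.unshard_grad()
        try:
            yield
        finally:
            # Restore `flat_param.grad` to the sharded gradient saved in
            # `_saved_grad_shard` (and sharded grad views for orig params).
            for handle in handles:
                handle.reshard_grad()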
@@ -1093,7 +1124,7 @@ def _use_unsharded_views(self, as_params: bool) -> None:
         be used during forward/backward computation and when hiding the
         original parameters from :meth:`nn.Module.named_parameters`.
         """
-        self._check_unsharded()
+        self._check_unsharded(self.flat_param)
         views = self._get_unflat_views(self.flat_param)
         for i, (view, (param_name, module, _)) in enumerate(
             zip(views, self.flat_param._param_infos)
@@ -1139,6 +1170,41 @@ def _use_unsharded_views(self, as_params: bool) -> None:
             else:
                 setattr(module, param_name, prim_param)

+    def _use_unsharded_grad_views(self) -> None:
+        """
+        Unflattens the unsharded flattened parameter's gradient by setting the
+        original module parameter variables' gradients to be views into it.
+        """
+        # Expects the gradient to be in `flat_param.grad`
+        if self.flat_param.grad is None:
+            return
+        self._check_unsharded(self.flat_param.grad)
+        views = self._get_unflat_views(self.flat_param, self.flat_param.grad)
+        for i, (view, (param_name, module, _)) in enumerate(
+            zip(views, self.flat_param._param_infos)
+        ):
+            p_assert(
+                hasattr(module, param_name),
+                f"{self.flat_param._prefixed_param_names[i]} is missing",
+            )
+            param = getattr(module, param_name)
+            param.grad = view
+        for i, (
+            param_name,
+            module,
+            module_name,
+            prim_param_name,
+            prim_module,
+            _,
+        ) in enumerate(self.flat_param._shared_param_infos):
+            p_assert(
+                hasattr(module, param_name),
+                f"{module_name + '.' + param_name if module_name else param_name} is missing",
+            )  # did not save prefixed name
+            param = getattr(module, param_name)
+            prim_param = getattr(prim_module, prim_param_name)
+            param.grad = prim_param.grad
+
     @contextlib.contextmanager
     def unflatten_as_params(self) -> Generator:
         """
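`_use_unsharded_grad_views()` mirrors `_use_unsharded_views()` but assigns views of `flat_param.grad` to each original parameter's `.grad`, reusing `_get_unflat_views` with the gradient passed as the tensor to unflatten. The underlying splitting is the usual flat-buffer-to-shapes mapping; a standalone illustration, where `shapes`/`numels` stand in for the `FlatParameter` metadata and padding and shared parameters are ignored:

    import torch

    flat_grad = torch.arange(12.0)          # stand-in for flat_param.grad
    shapes = [(2, 3), (6,)]                 # per-parameter shapes
    numels = [6, 6]                         # per-parameter numels
    views = [
        chunk.view(shape)
        for chunk, shape in zip(torch.split(flat_grad, numels), shapes)
    ]
    assert views[0].shape == (2, 3) and views[1].shape == (6,)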
@@ -1223,16 +1289,7 @@ def _use_sharded_grad_views(self) -> None:
         """
         flat_param = self.flat_param
         self._check_sharded(flat_param)
-        # Priority: `_cpu_grad` > `_saved_grad_shard` > `grad`
-        # - CPU offloading: `_cpu_grad`
-        # - No CPU offloading + sharded strategies: `_saved_grad_shard`
-        # - No CPU offloading + `NO_SHARD`: `grad`
-        if hasattr(flat_param, "_cpu_grad"):
-            grad = flat_param._cpu_grad  # type: ignore[attr-defined]
-        elif hasattr(flat_param, "_saved_grad_shard"):
-            grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
-        else:
-            grad = flat_param.grad
+        grad = self.sharded_grad
         if grad is None:
             return  # no-op
         self._check_sharded(grad)
@@ -1474,6 +1531,26 @@ def parameter_module_names(self) -> Iterator[Tuple[str, str]]:
         ):
             yield (param_name, module_name)

+    @property
+    def sharded_grad(self) -> Optional[Tensor]:
+        """Returns the handle's sharded gradient."""
+        flat_param = self.flat_param
+        # Priority for non-`None`: `_cpu_grad` > `_saved_grad_shard` > `grad`
+        # - CPU offloading: `_cpu_grad`
+        # - No CPU offloading + sharded strategies: `_saved_grad_shard`
+        # - No CPU offloading + `NO_SHARD`: `grad`
+        if hasattr(flat_param, "_cpu_grad"):
+            grad = flat_param._cpu_grad  # type: ignore[attr-defined]
+        elif hasattr(flat_param, "_saved_grad_shard"):
+            grad = flat_param._saved_grad_shard  # type: ignore[attr-defined]
+        else:
+            p_assert(
+                flat_param.grad is None or not self.uses_sharded_strategy,
+                "Sharded strategies should use `_cpu_grad` or `_saved_grad_shard`",
+            )
+            grad = flat_param.grad
+        return grad
+
     #######################
     # CHECKS & INVARIANTS #
     #######################
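The new `sharded_grad` property centralizes the `_cpu_grad` > `_saved_grad_shard` > `grad` priority that `_use_sharded_grad_views()` previously inlined, and asserts that sharded strategies never leave an unsharded gradient in `.grad`. A sketch of a consumer, assuming `handles` is an iterable of `FlatParamHandle` objects; the helper is illustrative only:

    import torch

    def _local_sharded_grad_sq_norm(handles) -> torch.Tensor:
        # Sum of squared gradient entries over this rank's shards only;
        # a full gradient norm would follow this with an all-reduce.
        total = torch.zeros(())
        for handle in handles:
            grad = handle.sharded_grad  # may be None if no gradient exists
            if grad is not None:
                total += grad.detach().float().pow(2).sum().cpu()
        return total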
@@ -1520,13 +1597,13 @@ def _check_low_precision_shard(self):
             f"Expects the low precision shard to be on {self.device} but got {device}",
         )

-    def _check_unsharded(self):
-        msg_prefix = "Expects the flattened parameter to be unsharded "
-        p_assert(self.flat_param is not None, msg_prefix + "but got `None`")
+    def _check_unsharded(self, tensor: Tensor):
+        msg_prefix = "Expects tensor to be unsharded "
+        p_assert(tensor is not None, msg_prefix + "but got `None`")
         unsharded_size = self.flat_param._unpadded_unsharded_size
         p_assert(
-            self.flat_param.size() == unsharded_size,
-            msg_prefix + f"with size {unsharded_size} but got {self.flat_param.size()}",
+            tensor.size() == unsharded_size,
+            msg_prefix + f"with size {unsharded_size} but got {tensor.size()}",
         )

     def _check_sharded(self, tensor: Tensor):