 import torch.distributed as dist
 import torch.distributed.fsdp._traversal_utils as traversal_utils
 import torch.nn as nn
+import torch.nn.functional as F
 from torch.distributed._shard.sharded_tensor import ShardedTensor
 from torch.distributed.fsdp._common_utils import (
     _apply_to_modules,
@@ -1436,6 +1437,113 @@ def return_fn(fqn_to_param_info):
     )


+@dataclass
+class StateInfo:
+    tensors: Dict[str, _PosDimTensorInfo]
+    scalar_tensors: Dict[str, torch.Tensor]
+    non_tensors: Dict[str, Any]
+
+
+@dataclass
+class AllGatherInfo:
+    tensors: List[torch.Tensor]
+    numels: List[int]
+    work: Optional[dist.Work]
+
+
+def _all_gather_optim_state(
+    fsdp_state: _FSDPState, optim_state: Dict[str, Any], param_numel: int
+) -> Dict[str, Any]:
+    """
+    All-gather the optimizer state from all ranks. This API is slow because it
+    uses ``all_gather_object``; however, the optimizer state_dict is not on the
+    critical path. We can fuse the communication across different states if
+    performance becomes a problem.
+    """
+    # All-gather the scalar tensor states, the non-tensor states, and the
+    # metadata of positive-dim tensors; the tensors themselves are gathered below.
+    processed_state = StateInfo({}, {}, {})
+    for state_name, value in sorted_items(optim_state):
+        if torch.is_tensor(value):
+            if value.dim() == 0:
+                # Zero-dim tensors (e.g., ``step``) are gathered by value.
+                processed_state.scalar_tensors[state_name] = value
+            else:
+                processed_state.tensors[state_name] = _PosDimTensorInfo(
+                    value.shape, value.dtype
+                )
+        else:
+            processed_state.non_tensors[state_name] = value
+    object_list: List[StateInfo] = [
+        processed_state for _ in range(fsdp_state.world_size)
+    ]
+    dist.all_gather_object(object_list, processed_state)
+
+    # Convert the gathered, pre-processed per-rank state back to the original
+    # format.
+    gathered_state: Dict[str, Any] = {}
+
+    all_tensor_states = sorted(
+        list(set([n for state in object_list for n in state.tensors.keys()]))
+    )
+    for name in all_tensor_states:
+        numels = []
+        dtype = torch.float
+        max_numel = 0
+        for object_state in object_list:
+            numels.append(0)
+            info = object_state.tensors.get(name, None)
+            if info is not None:
+                numels[-1] = info.shape.numel()
+                dtype = info.dtype
+                max_numel = max(max_numel, numels[-1])
+        local_state = (
+            optim_state[name]
+            if name in optim_state
+            else torch.empty(max_numel, dtype=dtype, device=fsdp_state.compute_device)
+        )
+        # Pad the local shard to the maximum numel across ranks so that
+        # ``all_gather`` can operate on equally sized buffers.
+        if max_numel > local_state.numel():
+            local_state = F.pad(local_state, [0, max_numel - local_state.numel()])
+        tensors = [
+            torch.empty(max_numel, dtype=dtype, device=fsdp_state.compute_device)
+            if rank != fsdp_state.rank
+            else local_state
+            for rank in range(len(object_list))
+        ]
+        work = dist.all_gather(
+            tensors, local_state, group=fsdp_state.process_group, async_op=True
+        )
+        gathered_state[name] = AllGatherInfo(tensors, numels, work)
+
+    for object_state in object_list:
+        for name, non_tensor_value in object_state.non_tensors.items():
+            curr_non_tensor_value = gathered_state.get(name, None)
+            assert (
+                curr_non_tensor_value is None
+                or curr_non_tensor_value == non_tensor_value
+            ), f"Different ranks have different values for {name}."
+            gathered_state[name] = non_tensor_value
+
+        for name, scalar_tensor_value in object_state.scalar_tensors.items():
+            curr_scalar_tensor_value = gathered_state.get(name, None)
+            assert curr_scalar_tensor_value is None or torch.equal(
+                scalar_tensor_value, curr_scalar_tensor_value
+            ), f"Different ranks have different values for {name}."
+            gathered_state[name] = scalar_tensor_value
+
+    # Wait for the async all-gathers to finish, then trim each rank's buffer
+    # back to its true numel and concatenate into the unsharded state.
+    for name, value in list(gathered_state.items()):
+        if not isinstance(value, AllGatherInfo):
+            continue
+        assert value.work is not None
+        value.work.wait()
+        gathered_state[name] = torch.cat(
+            [
+                rank_tensor[:rank_numel]
+                for rank_tensor, rank_numel in zip(value.tensors, value.numels)
+                if rank_numel > 0
+            ]
+        )
+
+    return gathered_state
+
+
 def _gather_orig_param_state(
     fsdp_param_info: FSDPParamInfo,
     fqn: str,
@@ -1458,51 +1566,18 @@ def _gather_orig_param_state(
     ):
         return optim_state

-    # Gathering state from all ranks. This step may be slow. However,
-    # `state_dict()` is not in the critical path. We can fuse the communication
-    # if the performance becomes a problem.
-    state_objects = {
-        state_name: value for state_name, value in sorted_items(optim_state)
-    }
-    object_list: List[Dict[str, Any]] = [{} for _ in range(fsdp_state.world_size)]
-    dist.all_gather_object(object_list, state_objects)
-    orig_state: Dict[str, Any] = {}
-    for idx, state in enumerate(object_list):
-        for state_name, value in state.items():
-            curr_value = orig_state.get(state_name, [])
-            if torch.is_tensor(value):
-                if value.dim() > 0:
-                    curr_value.append(value.to(fsdp_state.compute_device))
-                    orig_state[state_name] = curr_value
-                else:  # zero dim tensor, e.g., step.
-                    if torch.is_tensor(curr_value):
-                        assert torch.equal(curr_value, value)
-                    else:
-                        orig_state[state_name] = value
-            else:
-                assert curr_value == [] or curr_value == value
-                orig_state[state_name] = value
+    gathered_state = _all_gather_optim_state(
+        fsdp_state, optim_state, flat_param._numels[param_idx]
+    )

     # Unflatten state values.
-    for state_name in orig_state.keys():
-        value = orig_state[state_name]
-        if not isinstance(value, list) or not torch.is_tensor(value[0]):
+    for state_name, value in list(gathered_state.items()):
+        if not torch.is_tensor(value) or value.dim() == 0:
             continue
-        try:
-            value = torch.concat(value)[: flat_param._numels[param_idx]].reshape(
-                flat_param._shapes[param_idx]
-            )
-        except Exception as e:
-            raise Exception(
-                (
-                    flat_param._numels[param_idx],
-                    flat_param._shapes[param_idx],
-                    len(value),
-                    value[0].shape,
-                    state_name,
-                    fqn,
-                )
-            )
+
+        value = value[: flat_param._numels[param_idx]].reshape(
+            flat_param._shapes[param_idx]
+        )
         if shard_state:
             assert fsdp_state.process_group is not None
             value = _ext_chunk_tensor(
@@ -1513,8 +1588,8 @@ def _gather_orig_param_state(
                 fsdp_state.process_group,
             )
         value = value.cpu()
-        orig_state[state_name] = value
-    return orig_state
+        gathered_state[state_name] = value
+    return gathered_state


 def _shard_orig_param_state(
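
The heart of `_all_gather_optim_state` above is an uneven all-gather: per-rank shard lengths are exchanged with `all_gather_object`, every local shard is padded to the maximum numel so `dist.all_gather` can use equally sized buffers, and the padding is sliced off again before concatenation. Below is a minimal standalone sketch of just that pattern, assuming an already-initialized process group; `gather_uneven_1d`, its `device` argument, and the shard contents are illustrative, not code from this PR.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F


def gather_uneven_1d(local_shard: torch.Tensor, device: torch.device) -> torch.Tensor:
    """Gather 1-D shards of differing lengths from all ranks onto every rank."""
    world_size = dist.get_world_size()

    # Exchange per-rank numels so each rank knows how much padding to strip later.
    numels = [0] * world_size
    dist.all_gather_object(numels, local_shard.numel())

    # Pad the local shard to the maximum numel so all buffers have equal size.
    max_numel = max(numels)
    padded = F.pad(local_shard, [0, max_numel - local_shard.numel()])

    # All-gather the equally sized, padded shards.
    buffers = [
        torch.empty(max_numel, dtype=local_shard.dtype, device=device)
        for _ in range(world_size)
    ]
    dist.all_gather(buffers, padded)

    # Slice each buffer back to its true length and concatenate.
    return torch.cat([t[:n] for t, n in zip(buffers, numels) if n > 0])
```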
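For completeness, this gather path is not invoked directly: it sits behind FSDP's public optimizer-state APIs when the model is wrapped with `use_orig_params=True`. A hedged end-to-end sketch (assuming a multi-process launch such as `torchrun` with CUDA available; the tiny model and hyperparameters are placeholders):

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# use_orig_params=True is the configuration whose per-parameter optimizer
# state flows through _gather_orig_param_state / _all_gather_optim_state.
model = FSDP(nn.Linear(16, 16).cuda(), use_orig_params=True)
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# One step so the optimizer has state to gather (exp_avg, exp_avg_sq, step).
model(torch.randn(4, 16, device="cuda")).sum().backward()
optim.step()

# Consolidates the full, unflattened optimizer state dict on rank 0.
full_osd = FSDP.full_optim_state_dict(model, optim)
if dist.get_rank() == 0:
    print(sorted(full_osd["state"].keys()))

dist.destroy_process_group()
```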