@@ -1436,6 +1436,71 @@ def return_fn(fqn_to_param_info):
     )


+def _all_gather_optim_state(
+    fsdp_state: _FSDPState,
+    optim_state: Dict[str, Any],
+) -> Dict[str, Any]:
+    """
+    All-gather the optimizer state from all ranks. This API is slow because it
+    uses ``all_gather_object``; however, the optimizer state_dict is not in the
+    critical path. We can fuse the communication across different states if
+    performance becomes a problem.
+    """
+
+    # Pre-process the state to prepare for the all_gather_object call.
+    IS_ZERO_DIM_TENSOR_KEY = "__is_zero_dim_tensor"
+    processed_state: Dict[str, Any] = {}
+    for state_name, value in sorted_items(optim_state):
+        if torch.is_tensor(value):
+            if value.dim() == 0:
+                processed_state[state_name] = value.item()
+                processed_state[f"{state_name}{IS_ZERO_DIM_TENSOR_KEY}"] = value.dtype
+            else:
+                processed_state[state_name] = value.to(fsdp_state.compute_device)
+        else:
+            processed_state[state_name] = value
+
+    # All-gather the state.
+    object_list: List[Dict[str, Any]] = [{} for _ in range(fsdp_state.world_size)]
+    dist.all_gather_object(object_list, processed_state)
+
+    # Convert the gathered, pre-processed state of each rank back to the original one.
+    gathered_state: Dict[str, Any] = {}
+    for object_state in object_list:
+        for name, object_value in object_state.items():
+            if IS_ZERO_DIM_TENSOR_KEY in name:
+                continue
+            curr_object_value = gathered_state.get(name, None)
+            dtype = object_state.get(f"{name}{IS_ZERO_DIM_TENSOR_KEY}", None)
+            if dtype is not None:
+                zero_dim_tensor = torch.tensor(object_value, dtype=dtype)
+                if curr_object_value is not None:
+                    assert torch.equal(
+                        zero_dim_tensor, curr_object_value
+                    ), f"Different ranks have different values for {name}."
+                else:
+                    gathered_state[name] = zero_dim_tensor
+            elif torch.is_tensor(object_value):
+                if curr_object_value is not None:
+                    curr_object_value.append(object_value.to(fsdp_state.compute_device))
+                else:
+                    gathered_state[name] = [object_value.to(fsdp_state.compute_device)]
+            else:
+                if curr_object_value is not None:
+                    assert (
+                        curr_object_value == object_value
+                    ), f"Different ranks have different values for {name}."
+                else:
+                    gathered_state[name] = object_value
+
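+    # Concatenate the per-rank shards of each tensor state into one flat tensor.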
+    for name, value in list(gathered_state.items()):
+        if not isinstance(value, list) or not torch.is_tensor(value[0]):
+            continue
+        gathered_state[name] = torch.cat(value)
+
+    return gathered_state
+
+
 def _gather_orig_param_state(
     fsdp_param_info: FSDPParamInfo,
     fqn: str,
@@ -1458,51 +1523,16 @@ def _gather_orig_param_state(
     ):
         return optim_state

-    # Gathering state from all ranks. This step may be slow. However,
-    # `state_dict()` is not in the critical path. We can fuse the communication
-    # if the performance becomes a problem.
-    state_objects = {
-        state_name: value for state_name, value in sorted_items(optim_state)
-    }
-    object_list: List[Dict[str, Any]] = [{} for _ in range(fsdp_state.world_size)]
-    dist.all_gather_object(object_list, state_objects)
-    orig_state: Dict[str, Any] = {}
-    for idx, state in enumerate(object_list):
-        for state_name, value in state.items():
-            curr_value = orig_state.get(state_name, [])
-            if torch.is_tensor(value):
-                if value.dim() > 0:
-                    curr_value.append(value.to(fsdp_state.compute_device))
-                    orig_state[state_name] = curr_value
-                else:  # zero dim tensor, e.g., step.
-                    if torch.is_tensor(curr_value):
-                        assert torch.equal(curr_value, value)
-                    else:
-                        orig_state[state_name] = value
-            else:
-                assert curr_value == [] or curr_value == value
-                orig_state[state_name] = value
+    gathered_state = _all_gather_optim_state(fsdp_state, optim_state)

     # Unflatten state values.
-    for state_name in orig_state.keys():
-        value = orig_state[state_name]
-        if not isinstance(value, list) or not torch.is_tensor(value[0]):
+    for state_name, value in list(gathered_state.items()):
+        if not torch.is_tensor(value) or value.dim() == 0:
             continue
-        try:
-            value = torch.concat(value)[: flat_param._numels[param_idx]].reshape(
-                flat_param._shapes[param_idx]
-            )
-        except Exception as e:
-            raise Exception(
-                (
-                    flat_param._numels[param_idx],
-                    flat_param._shapes[param_idx],
-                    len(value),
-                    value[0].shape,
-                    state_name,
-                    fqn,
-                )
-            )
+
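+        # Trim to the original numel and restore the parameter's original shape.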
+        value = value[: flat_param._numels[param_idx]].reshape(
+            flat_param._shapes[param_idx]
+        )
         if shard_state:
             assert fsdp_state.process_group is not None
             value = _ext_chunk_tensor(
@@ -1513,8 +1543,8 @@ def _gather_orig_param_state(
                 fsdp_state.process_group,
             )
             value = value.cpu()
-        orig_state[state_name] = value
-    return orig_state
+        gathered_state[state_name] = value
+    return gathered_state


 def _shard_orig_param_state(
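For reference, below is a minimal, hypothetical sketch of the ``all_gather_object`` round trip that `_all_gather_optim_state` performs: zero-dim tensors (e.g. Adam's `step`) are shipped as Python scalars with a dtype marker and rebuilt after the collective, while sharded tensor state is collected per rank and concatenated. The helper name `gather_state_sketch` and the tuple tags are illustrative only, and the snippet assumes a process group has already been initialized via `torch.distributed.init_process_group`.

import torch
import torch.distributed as dist


def gather_state_sketch(local_state):
    # Zero-dim tensors cannot be sharded, so send them as plain scalars and
    # remember the dtype so they can be rebuilt after the gather.
    prepared = {}
    for name, value in local_state.items():
        if torch.is_tensor(value) and value.dim() == 0:
            prepared[name] = ("scalar", value.item(), value.dtype)
        elif torch.is_tensor(value):
            prepared[name] = ("tensor", value)
        else:
            prepared[name] = ("object", value)

    # all_gather_object requires a pre-sized list with one slot per rank.
    gathered = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(gathered, prepared)

    # Rebuild: keep one copy of scalar/object state (the real helper asserts
    # all ranks agree), and concatenate the tensor shards from every rank.
    result = {}
    for rank_state in gathered:
        for name, (kind, *payload) in rank_state.items():
            if kind == "scalar":
                result[name] = torch.tensor(payload[0], dtype=payload[1])
            elif kind == "tensor":
                result.setdefault(name, []).append(payload[0])
            else:
                result[name] = payload[0]
    return {
        name: torch.cat(v) if isinstance(v, list) else v
        for name, v in result.items()
    }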