Skip to content

[PT-D][Checkpoint] Fix load_sharded_optimizer_state_dict() when flatten_sharded_tensors defaults to True #92823

@wz337

Description

@wz337

🚀 The feature, motivation and pitch

When we change the default of `flatten_sharded_tensors` to `True` and run the DCP test, we run into the error below.

python3 test/distributed/checkpoint/test_fsdp_optim_state.py 

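`param_groups` (the last key) is not loaded properly when `flatten_sharded_tensors` defaults to `True`. Because `param_groups` is left as the placeholder string `'<bytes_io>'`, the later call to `FSDP.flatten_sharded_optim_state_dict()` fails with `TypeError: string indices must be integers` when it evaluates `unflat_param_group["params"]` (see the error trace below).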
The problematic output (`param_groups` still holds the `'<bytes_io>'` placeholder instead of the loaded value):

{'optim': ......'param_groups': '<bytes_io>'}}

The correct output, when `flatten_sharded_tensors` defaults to `False` (check `param_groups`, the last key):

{'optim':......'param_groups': [{'lr': 0.1, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'params': ['weight', 'bias']}]}}

**Additional info:**

Full optimizer state dict returned by `load_sharded_optimizer_state_dict()`:
The problematic output:

{'optim': {'state': {'weight': {'exp_avg': ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 8], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 8], placement=rank:1/cuda:1), ShardMetadata(shard_offsets=[4, 0], shard_sizes=[2, 8], placement=rank:2/cuda:2), ShardMetadata(shard_offsets=[6, 0], shard_sizes=[2, 8], placement=rank:3/cuda:3)], size=torch.Size([8, 8]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))), 'exp_avg_sq': ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 8], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 8], placement=rank:1/cuda:1), ShardMetadata(shard_offsets=[4, 0], shard_sizes=[2, 8], placement=rank:2/cuda:2), ShardMetadata(shard_offsets=[6, 0], shard_sizes=[2, 8], placement=rank:3/cuda:3)], size=torch.Size([8, 8]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))), 'step': tensor(1., device='cuda:1')}, 'bias': {'exp_avg': ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0], shard_sizes=[2], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2], shard_sizes=[2], placement=rank:1/cuda:1), ShardMetadata(shard_offsets=[4], shard_sizes=[2], placement=rank:2/cuda:2), ShardMetadata(shard_offsets=[6], shard_sizes=[2], placement=rank:3/cuda:3)], size=torch.Size([8]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))), 'exp_avg_sq': ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0], shard_sizes=[2], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2], shard_sizes=[2], 
placement=rank:1/cuda:1), ShardMetadata(shard_offsets=[4], shard_sizes=[2], placement=rank:2/cuda:2), ShardMetadata(shard_offsets=[6], shard_sizes=[2], placement=rank:3/cuda:3)], size=torch.Size([8]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))), 'step': tensor(1., device='cuda:1')}}, 'param_groups': '<bytes_io>'}}

The correct output:

{'optim': {'state': {'weight': {'exp_avg': ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 8], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 8], placement=rank:1/cuda:1), ShardMetadata(shard_offsets=[4, 0], shard_sizes=[2, 8], placement=rank:2/cuda:2), ShardMetadata(shard_offsets=[6, 0], shard_sizes=[2, 8], placement=rank:3/cuda:3)], size=torch.Size([8, 8]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))), 'exp_avg_sq': ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 8], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 8], placement=rank:1/cuda:1), ShardMetadata(shard_offsets=[4, 0], shard_sizes=[2, 8], placement=rank:2/cuda:2), ShardMetadata(shard_offsets=[6, 0], shard_sizes=[2, 8], placement=rank:3/cuda:3)], size=torch.Size([8, 8]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))), 'step': tensor(1., device='cuda:1')}, 'bias': {'exp_avg': ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0], shard_sizes=[2], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2], shard_sizes=[2], placement=rank:1/cuda:1), ShardMetadata(shard_offsets=[4], shard_sizes=[2], placement=rank:2/cuda:2), ShardMetadata(shard_offsets=[6], shard_sizes=[2], placement=rank:3/cuda:3)], size=torch.Size([8]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))), 'exp_avg_sq': ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0], shard_sizes=[2], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2], shard_sizes=[2], 
placement=rank:1/cuda:1), ShardMetadata(shard_offsets=[4], shard_sizes=[2], placement=rank:2/cuda:2), ShardMetadata(shard_offsets=[6], shard_sizes=[2], placement=rank:3/cuda:3)], size=torch.Size([8]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))), 'step': tensor(1., device='cuda:1')}}, 'param_groups': [{'lr': 0.1, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'params': ['weight', 'bias']}]}}

Error Trace:

(pytorch) irisz@a100-st-p4d24xlarge-53:~/cluster/work/pytorch$ python3 test/distributed/checkpoint/test_fsdp_optim_state.py 
Fail to import hypothesis in common_utils, tests are not derandomized
/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_cuda.py:19: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
  CUDA11OrLater = torch.version.cuda and LooseVersion(torch.version.cuda) >= "11.0"
/fsx/users/irisz/conda/envs/pytorch/lib/python3.9/site-packages/setuptools/_distutils/version.py:351: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
  other = LooseVersion(other)
INFO:torch.testing._internal.common_distributed:Started process 0 with pid 79821
INFO:torch.testing._internal.common_distributed:Started process 1 with pid 79822
INFO:torch.testing._internal.common_distributed:Started process 2 with pid 79823
INFO:torch.testing._internal.common_distributed:Started process 3 with pid 79824
Fail to import hypothesis in common_utils, tests are not derandomized
Fail to import hypothesis in common_utils, tests are not derandomized
Fail to import hypothesis in common_utils, tests are not derandomized
Fail to import hypothesis in common_utils, tests are not derandomized
INFO:torch.testing._internal.common_distributed:Starting event listener thread for rank 2
INFO:torch.distributed.distributed_c10d:Added key: store_based_barrier_key:1 to store for rank: 2
INFO:torch.testing._internal.common_distributed:Starting event listener thread for rank 0
INFO:torch.distributed.distributed_c10d:Added key: store_based_barrier_key:1 to store for rank: 0
INFO:torch.testing._internal.common_distributed:Starting event listener thread for rank 1
INFO:torch.distributed.distributed_c10d:Added key: store_based_barrier_key:1 to store for rank: 1
INFO:torch.testing._internal.common_distributed:Starting event listener thread for rank 3
INFO:torch.distributed.distributed_c10d:Added key: store_based_barrier_key:1 to store for rank: 3
INFO:torch.distributed.distributed_c10d:Rank 3: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
INFO:torch.distributed.distributed_c10d:Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
INFO:torch.distributed.distributed_c10d:Rank 2: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
INFO:torch.distributed.distributed_c10d:Rank 1: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
Using temp directory: /tmp/tmpmsq7kqeb
/fsx/users/irisz/work/pytorch/torch/distributed/checkpoint/filesystem.py:157: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  if tensor.storage().size() != tensor.numel():
/fsx/users/irisz/work/pytorch/torch/distributed/checkpoint/filesystem.py:157: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  if tensor.storage().size() != tensor.numel():
/fsx/users/irisz/work/pytorch/torch/distributed/checkpoint/filesystem.py:157: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  if tensor.storage().size() != tensor.numel():
/fsx/users/irisz/work/pytorch/torch/distributed/checkpoint/filesystem.py:157: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  if tensor.storage().size() != tensor.numel():
/fsx/users/irisz/work/pytorch/torch/distributed/distributed_c10d.py:2533: UserWarning: torch.distributed._all_gather_base is a private function and will be deprecated. Please use torch.distributed.all_gather_into_tensor instead.
  warnings.warn(
/fsx/users/irisz/work/pytorch/torch/distributed/distributed_c10d.py:2533: UserWarning: torch.distributed._all_gather_base is a private function and will be deprecated. Please use torch.distributed.all_gather_into_tensor instead.
  warnings.warn(
/fsx/users/irisz/work/pytorch/torch/distributed/distributed_c10d.py:2533: UserWarning: torch.distributed._all_gather_base is a private function and will be deprecated. Please use torch.distributed.all_gather_into_tensor instead.
  warnings.warn(
/fsx/users/irisz/work/pytorch/torch/distributed/distributed_c10d.py:2533: UserWarning: torch.distributed._all_gather_base is a private function and will be deprecated. Please use torch.distributed.all_gather_into_tensor instead.
  warnings.warn(
ERROR:torch.testing._internal.common_distributed:Caught exception: 
Traceback (most recent call last):
  File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 663, in run_test
    getattr(self, test_name)()
  File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 541, in wrapper
    fn()
  File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/distributed/_tensor/common_dtensor.py", line 172, in wrapper
    func(self)  # type: ignore[misc]
  File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 172, in wrapper
    return func(*args, **kwargs)
  File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/distributed/checkpoint_utils.py", line 34, in wrapper
    func(self)
  File "/fsx/users/irisz/work/pytorch/test/distributed/checkpoint/test_fsdp_optim_state.py", line 89, in test_distributed_tensor_planner
    flattened_osd = FSDP.flatten_sharded_optim_state_dict(
  File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1434, in flatten_sharded_optim_state_dict
    return FullyShardedDataParallel._optim_state_dict_to_load_impl(
  File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1231, in _optim_state_dict_to_load_impl
    return _rekey_sharded_optim_state_dict(
  File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/_optim_utils.py", line 972, in _rekey_sharded_optim_state_dict
    for unflat_param_name in unflat_param_group["params"]
TypeError: string indices must be integers
 exiting process 1 with exit code: 10
ERROR:torch.testing._internal.common_distributed:Caught exception: 
Traceback (most recent call last):
  File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 663, in run_test
    getattr(self, test_name)()
  File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 541, in wrapper
    fn()
  File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/distributed/_tensor/common_dtensor.py", line 172, in wrapper
    func(self)  # type: ignore[misc]
  File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 172, in wrapper
    return func(*args, **kwargs)
  File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/distributed/checkpoint_utils.py", line 34, in wrapper
    func(self)
  File "/fsx/users/irisz/work/pytorch/test/distributed/checkpoint/test_fsdp_optim_state.py", line 89, in test_distributed_tensor_planner
    flattened_osd = FSDP.flatten_sharded_optim_state_dict(
  File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1434, in flatten_sharded_optim_state_dict
    return FullyShardedDataParallel._optim_state_dict_to_load_impl(
  File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1231, in _optim_state_dict_to_load_impl
    return _rekey_sharded_optim_state_dict(
  File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/_optim_utils.py", line 972, in _rekey_sharded_optim_state_dict
    for unflat_param_name in unflat_param_group["params"]
TypeError: string indices must be integers
 exiting process 0 with exit code: 10
ERROR:torch.testing._internal.common_distributed:Caught exception: 
Traceback (most recent call last):
  File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 663, in run_test
    getattr(self, test_name)()
  File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 541, in wrapper
    fn()
  File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/distributed/_tensor/common_dtensor.py", line 172, in wrapper
    func(self)  # type: ignore[misc]
  File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 172, in wrapper
    return func(*args, **kwargs)
  File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/distributed/checkpoint_utils.py", line 34, in wrapper
    func(self)
  File "/fsx/users/irisz/work/pytorch/test/distributed/checkpoint/test_fsdp_optim_state.py", line 89, in test_distributed_tensor_planner
    flattened_osd = FSDP.flatten_sharded_optim_state_dict(
  File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1434, in flatten_sharded_optim_state_dict
    return FullyShardedDataParallel._optim_state_dict_to_load_impl(
  File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1231, in _optim_state_dict_to_load_impl
    return _rekey_sharded_optim_state_dict(
  File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/_optim_utils.py", line 972, in _rekey_sharded_optim_state_dict
    for unflat_param_name in unflat_param_group["params"]
TypeError: string indices must be integers
 exiting process 2 with exit code: 10
ERROR:torch.testing._internal.common_distributed:Caught exception: 
Traceback (most recent call last):
  File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 663, in run_test
    getattr(self, test_name)()
  File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 541, in wrapper
    fn()
  File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/distributed/_tensor/common_dtensor.py", line 172, in wrapper
    func(self)  # type: ignore[misc]
  File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 172, in wrapper
    return func(*args, **kwargs)
  File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/distributed/checkpoint_utils.py", line 34, in wrapper
    func(self)
  File "/fsx/users/irisz/work/pytorch/test/distributed/checkpoint/test_fsdp_optim_state.py", line 89, in test_distributed_tensor_planner
    flattened_osd = FSDP.flatten_sharded_optim_state_dict(
  File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1434, in flatten_sharded_optim_state_dict
    return FullyShardedDataParallel._optim_state_dict_to_load_impl(
  File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1231, in _optim_state_dict_to_load_impl
    return _rekey_sharded_optim_state_dict(
  File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/_optim_utils.py", line 972, in _rekey_sharded_optim_state_dict
    for unflat_param_name in unflat_param_group["params"]
TypeError: string indices must be integers
 exiting process 3 with exit code: 10
Process 3 terminated with exit code 10, terminating remaining processes.
E
======================================================================
ERROR: test_distributed_tensor_planner (__main__.FsdpOptimStateCheckpoint)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 539, in wrapper
    self._join_processes(fn)
  File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 765, in _join_processes
    self._check_return_codes(elapsed_time)
  File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 810, in _check_return_codes
    raise RuntimeError(error)
RuntimeError: Process 3 exited with error code 10 and exception:
Traceback (most recent call last):
  File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 663, in run_test
    getattr(self, test_name)()
  File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 541, in wrapper
    fn()
  File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/distributed/_tensor/common_dtensor.py", line 172, in wrapper
    func(self)  # type: ignore[misc]
  File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 172, in wrapper
    return func(*args, **kwargs)
  File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/distributed/checkpoint_utils.py", line 34, in wrapper
    func(self)
  File "/fsx/users/irisz/work/pytorch/test/distributed/checkpoint/test_fsdp_optim_state.py", line 89, in test_distributed_tensor_planner
    flattened_osd = FSDP.flatten_sharded_optim_state_dict(
  File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1434, in flatten_sharded_optim_state_dict
    return FullyShardedDataParallel._optim_state_dict_to_load_impl(
  File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1231, in _optim_state_dict_to_load_impl
    return _rekey_sharded_optim_state_dict(
  File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/_optim_utils.py", line 972, in _rekey_sharded_optim_state_dict
    for unflat_param_name in unflat_param_group["params"]
TypeError: string indices must be integers



----------------------------------------------------------------------
Ran 1 test in 10.869s

FAILED (errors=1)

Alternatives

No response

Additional context

No response

cc @kumpera

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu

Metadata

Metadata

Assignees

No one assigned

    Labels

    oncall: distributed — Add this issue/PR to the distributed oncall triage queue

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions