🚀 The feature, motivation and pitch
When we default flatten_sharded_tensors to True and run the DCP test
python3 test/distributed/checkpoint/test_fsdp_optim_state.py
we run into the error below: the param_groups entry (the last key) is not loaded properly.
The problematic one (the bytes_io entry is not loaded properly):
{'optim': ......'param_groups': '<bytes_io>'}}
The correct one, when we default flatten_sharded_tensors to False; check param_groups (the last key):
{'optim':......'param_groups': [{'lr': 0.1, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'params': ['weight', 'bias']}]}}
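To make the failing flow concrete, here is a rough sketch of what the test exercises (the model, optimizer, and checkpoint directory below are placeholders, not the actual test code; the real test sets up its own temp directory and process group):

```python
# Sketch of the save -> load -> flatten flow hit by the test; run under torchrun,
# e.g. torchrun --nproc_per_node=4 repro.py. Model/optimizer/paths are placeholders.
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType

CHECKPOINT_DIR = "/tmp/fsdp_optim_ckpt"  # placeholder

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

model = FSDP(torch.nn.Linear(8, 8).cuda())
optim = torch.optim.Adam(model.parameters(), lr=0.1)
model(torch.rand(4, 8, device="cuda")).sum().backward()
optim.step()

# Save the model state and the sharded optimizer state dict through DCP.
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    state_dict = {
        "model": model.state_dict(),
        "optim": FSDP.sharded_optim_state_dict(model, optim),
    }
    dist_cp.save_state_dict(
        state_dict=state_dict,
        storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
    )

# Load the model state, then the optimizer state, then flatten it so it can be
# fed to optim.load_state_dict(). The flatten step is where the test fails when
# param_groups comes back as the '<bytes_io>' placeholder.
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    model_state = {"model": model.state_dict()}
    dist_cp.load_state_dict(
        state_dict=model_state,
        storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
    )
    optim_state = load_sharded_optimizer_state_dict(
        model_state_dict=model_state["model"],
        optimizer_key="optim",
        storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
    )
    flattened_osd = FSDP.flatten_sharded_optim_state_dict(
        optim_state["optim"], model, optim
    )
    optim.load_state_dict(flattened_osd)
```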
Additional info:
Full optimizer state dict returned by load_sharded_optimizer_state_dict():
The problematic one:
{'optim': {'state': {'weight': {'exp_avg': ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 8], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 8], placement=rank:1/cuda:1), ShardMetadata(shard_offsets=[4, 0], shard_sizes=[2, 8], placement=rank:2/cuda:2), ShardMetadata(shard_offsets=[6, 0], shard_sizes=[2, 8], placement=rank:3/cuda:3)], size=torch.Size([8, 8]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))), 'exp_avg_sq': ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 8], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 8], placement=rank:1/cuda:1), ShardMetadata(shard_offsets=[4, 0], shard_sizes=[2, 8], placement=rank:2/cuda:2), ShardMetadata(shard_offsets=[6, 0], shard_sizes=[2, 8], placement=rank:3/cuda:3)], size=torch.Size([8, 8]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))), 'step': tensor(1., device='cuda:1')}, 'bias': {'exp_avg': ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0], shard_sizes=[2], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2], shard_sizes=[2], placement=rank:1/cuda:1), ShardMetadata(shard_offsets=[4], shard_sizes=[2], placement=rank:2/cuda:2), ShardMetadata(shard_offsets=[6], shard_sizes=[2], placement=rank:3/cuda:3)], size=torch.Size([8]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))), 'exp_avg_sq': ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0], shard_sizes=[2], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2], shard_sizes=[2], placement=rank:1/cuda:1), ShardMetadata(shard_offsets=[4], shard_sizes=[2], placement=rank:2/cuda:2), ShardMetadata(shard_offsets=[6], shard_sizes=[2], placement=rank:3/cuda:3)], size=torch.Size([8]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))), 'step': tensor(1., device='cuda:1')}}, 'param_groups': '<bytes_io>'}}
The correct one:
{'optim': {'state': {'weight': {'exp_avg': ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 8], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 8], placement=rank:1/cuda:1), ShardMetadata(shard_offsets=[4, 0], shard_sizes=[2, 8], placement=rank:2/cuda:2), ShardMetadata(shard_offsets=[6, 0], shard_sizes=[2, 8], placement=rank:3/cuda:3)], size=torch.Size([8, 8]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))), 'exp_avg_sq': ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 8], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 8], placement=rank:1/cuda:1), ShardMetadata(shard_offsets=[4, 0], shard_sizes=[2, 8], placement=rank:2/cuda:2), ShardMetadata(shard_offsets=[6, 0], shard_sizes=[2, 8], placement=rank:3/cuda:3)], size=torch.Size([8, 8]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))), 'step': tensor(1., device='cuda:1')}, 'bias': {'exp_avg': ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0], shard_sizes=[2], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2], shard_sizes=[2], placement=rank:1/cuda:1), ShardMetadata(shard_offsets=[4], shard_sizes=[2], placement=rank:2/cuda:2), ShardMetadata(shard_offsets=[6], shard_sizes=[2], placement=rank:3/cuda:3)], size=torch.Size([8]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))), 'exp_avg_sq': ShardedTensor(ShardedTensorMetadata(shards_metadata=[ShardMetadata(shard_offsets=[0], shard_sizes=[2], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2], shard_sizes=[2], placement=rank:1/cuda:1), ShardMetadata(shard_offsets=[4], shard_sizes=[2], placement=rank:2/cuda:2), ShardMetadata(shard_offsets=[6], shard_sizes=[2], placement=rank:3/cuda:3)], size=torch.Size([8]), tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided, requires_grad=False, memory_format=torch.contiguous_format, pin_memory=False))), 'step': tensor(1., device='cuda:1')}}, 'param_groups': [{'lr': 0.1, 'betas': (0.9, 0.999), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'maximize': False, 'foreach': None, 'capturable': False, 'differentiable': False, 'fused': None, 'params': ['weight', 'bias']}]}}
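For context, flatten_sharded_tensors is the keyword exposed on the default DCP planners. A minimal sketch of toggling it explicitly is below; this assumes the DefaultSavePlanner/DefaultLoadPlanner keyword argument and reuses state_dict, model_state, and CHECKPOINT_DIR from the sketch above. Where exactly the default gets flipped for the optimizer load path is not shown in this issue, so treat this as illustrative only.

```python
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.default_planner import (
    DefaultLoadPlanner,
    DefaultSavePlanner,
)

# With flatten_sharded_tensors=False the param_groups (BytesIO) entry loads back
# as a real list of dicts; defaulting it to True is what produces the
# '<bytes_io>' placeholder shown above.
dist_cp.save_state_dict(
    state_dict=state_dict,
    storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
    planner=DefaultSavePlanner(flatten_sharded_tensors=False),
)
dist_cp.load_state_dict(
    state_dict=model_state,
    storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
    planner=DefaultLoadPlanner(flatten_sharded_tensors=False),
)
```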
Error Trace:
(pytorch) irisz@a100-st-p4d24xlarge-53:~/cluster/work/pytorch$ python3 test/distributed/checkpoint/test_fsdp_optim_state.py
Fail to import hypothesis in common_utils, tests are not derandomized
/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_cuda.py:19: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
CUDA11OrLater = torch.version.cuda and LooseVersion(torch.version.cuda) >= "11.0"
/fsx/users/irisz/conda/envs/pytorch/lib/python3.9/site-packages/setuptools/_distutils/version.py:351: DeprecationWarning: distutils Version classes are deprecated. Use packaging.version instead.
other = LooseVersion(other)
INFO:torch.testing._internal.common_distributed:Started process 0 with pid 79821
INFO:torch.testing._internal.common_distributed:Started process 1 with pid 79822
INFO:torch.testing._internal.common_distributed:Started process 2 with pid 79823
INFO:torch.testing._internal.common_distributed:Started process 3 with pid 79824
Fail to import hypothesis in common_utils, tests are not derandomized
Fail to import hypothesis in common_utils, tests are not derandomized
Fail to import hypothesis in common_utils, tests are not derandomized
Fail to import hypothesis in common_utils, tests are not derandomized
INFO:torch.testing._internal.common_distributed:Starting event listener thread for rank 2
INFO:torch.distributed.distributed_c10d:Added key: store_based_barrier_key:1 to store for rank: 2
INFO:torch.testing._internal.common_distributed:Starting event listener thread for rank 0
INFO:torch.distributed.distributed_c10d:Added key: store_based_barrier_key:1 to store for rank: 0
INFO:torch.testing._internal.common_distributed:Starting event listener thread for rank 1
INFO:torch.distributed.distributed_c10d:Added key: store_based_barrier_key:1 to store for rank: 1
INFO:torch.testing._internal.common_distributed:Starting event listener thread for rank 3
INFO:torch.distributed.distributed_c10d:Added key: store_based_barrier_key:1 to store for rank: 3
INFO:torch.distributed.distributed_c10d:Rank 3: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
INFO:torch.distributed.distributed_c10d:Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
INFO:torch.distributed.distributed_c10d:Rank 2: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
INFO:torch.distributed.distributed_c10d:Rank 1: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
Using temp directory: /tmp/tmpmsq7kqeb
/fsx/users/irisz/work/pytorch/torch/distributed/checkpoint/filesystem.py:157: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
if tensor.storage().size() != tensor.numel():
/fsx/users/irisz/work/pytorch/torch/distributed/checkpoint/filesystem.py:157: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
if tensor.storage().size() != tensor.numel():
/fsx/users/irisz/work/pytorch/torch/distributed/checkpoint/filesystem.py:157: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
if tensor.storage().size() != tensor.numel():
/fsx/users/irisz/work/pytorch/torch/distributed/checkpoint/filesystem.py:157: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
if tensor.storage().size() != tensor.numel():
/fsx/users/irisz/work/pytorch/torch/distributed/distributed_c10d.py:2533: UserWarning: torch.distributed._all_gather_base is a private function and will be deprecated. Please use torch.distributed.all_gather_into_tensor instead.
warnings.warn(
/fsx/users/irisz/work/pytorch/torch/distributed/distributed_c10d.py:2533: UserWarning: torch.distributed._all_gather_base is a private function and will be deprecated. Please use torch.distributed.all_gather_into_tensor instead.
warnings.warn(
/fsx/users/irisz/work/pytorch/torch/distributed/distributed_c10d.py:2533: UserWarning: torch.distributed._all_gather_base is a private function and will be deprecated. Please use torch.distributed.all_gather_into_tensor instead.
warnings.warn(
/fsx/users/irisz/work/pytorch/torch/distributed/distributed_c10d.py:2533: UserWarning: torch.distributed._all_gather_base is a private function and will be deprecated. Please use torch.distributed.all_gather_into_tensor instead.
warnings.warn(
ERROR:torch.testing._internal.common_distributed:Caught exception:
Traceback (most recent call last):
File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 663, in run_test
getattr(self, test_name)()
File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 541, in wrapper
fn()
File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/distributed/_tensor/common_dtensor.py", line 172, in wrapper
func(self) # type: ignore[misc]
File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 172, in wrapper
return func(*args, **kwargs)
File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/distributed/checkpoint_utils.py", line 34, in wrapper
func(self)
File "/fsx/users/irisz/work/pytorch/test/distributed/checkpoint/test_fsdp_optim_state.py", line 89, in test_distributed_tensor_planner
flattened_osd = FSDP.flatten_sharded_optim_state_dict(
File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1434, in flatten_sharded_optim_state_dict
return FullyShardedDataParallel._optim_state_dict_to_load_impl(
File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1231, in _optim_state_dict_to_load_impl
return _rekey_sharded_optim_state_dict(
File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/_optim_utils.py", line 972, in _rekey_sharded_optim_state_dict
for unflat_param_name in unflat_param_group["params"]
TypeError: string indices must be integers
exiting process 1 with exit code: 10
ERROR:torch.testing._internal.common_distributed:Caught exception:
Traceback (most recent call last):
File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 663, in run_test
getattr(self, test_name)()
File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 541, in wrapper
fn()
File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/distributed/_tensor/common_dtensor.py", line 172, in wrapper
func(self) # type: ignore[misc]
File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 172, in wrapper
return func(*args, **kwargs)
File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/distributed/checkpoint_utils.py", line 34, in wrapper
func(self)
File "/fsx/users/irisz/work/pytorch/test/distributed/checkpoint/test_fsdp_optim_state.py", line 89, in test_distributed_tensor_planner
flattened_osd = FSDP.flatten_sharded_optim_state_dict(
File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1434, in flatten_sharded_optim_state_dict
return FullyShardedDataParallel._optim_state_dict_to_load_impl(
File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1231, in _optim_state_dict_to_load_impl
return _rekey_sharded_optim_state_dict(
File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/_optim_utils.py", line 972, in _rekey_sharded_optim_state_dict
for unflat_param_name in unflat_param_group["params"]
TypeError: string indices must be integers
exiting process 0 with exit code: 10
ERROR:torch.testing._internal.common_distributed:Caught exception:
Traceback (most recent call last):
File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 663, in run_test
getattr(self, test_name)()
File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 541, in wrapper
fn()
File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/distributed/_tensor/common_dtensor.py", line 172, in wrapper
func(self) # type: ignore[misc]
File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 172, in wrapper
return func(*args, **kwargs)
File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/distributed/checkpoint_utils.py", line 34, in wrapper
func(self)
File "/fsx/users/irisz/work/pytorch/test/distributed/checkpoint/test_fsdp_optim_state.py", line 89, in test_distributed_tensor_planner
flattened_osd = FSDP.flatten_sharded_optim_state_dict(
File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1434, in flatten_sharded_optim_state_dict
return FullyShardedDataParallel._optim_state_dict_to_load_impl(
File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1231, in _optim_state_dict_to_load_impl
return _rekey_sharded_optim_state_dict(
File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/_optim_utils.py", line 972, in _rekey_sharded_optim_state_dict
for unflat_param_name in unflat_param_group["params"]
TypeError: string indices must be integers
exiting process 2 with exit code: 10
ERROR:torch.testing._internal.common_distributed:Caught exception:
Traceback (most recent call last):
File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 663, in run_test
getattr(self, test_name)()
File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 541, in wrapper
fn()
File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/distributed/_tensor/common_dtensor.py", line 172, in wrapper
func(self) # type: ignore[misc]
File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 172, in wrapper
return func(*args, **kwargs)
File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/distributed/checkpoint_utils.py", line 34, in wrapper
func(self)
File "/fsx/users/irisz/work/pytorch/test/distributed/checkpoint/test_fsdp_optim_state.py", line 89, in test_distributed_tensor_planner
flattened_osd = FSDP.flatten_sharded_optim_state_dict(
File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1434, in flatten_sharded_optim_state_dict
return FullyShardedDataParallel._optim_state_dict_to_load_impl(
File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1231, in _optim_state_dict_to_load_impl
return _rekey_sharded_optim_state_dict(
File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/_optim_utils.py", line 972, in _rekey_sharded_optim_state_dict
for unflat_param_name in unflat_param_group["params"]
TypeError: string indices must be integers
exiting process 3 with exit code: 10
Process 3 terminated with exit code 10, terminating remaining processes.
E
======================================================================
ERROR: test_distributed_tensor_planner (__main__.FsdpOptimStateCheckpoint)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 539, in wrapper
self._join_processes(fn)
File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 765, in _join_processes
self._check_return_codes(elapsed_time)
File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 810, in _check_return_codes
raise RuntimeError(error)
RuntimeError: Process 3 exited with error code 10 and exception:
Traceback (most recent call last):
File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 663, in run_test
getattr(self, test_name)()
File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 541, in wrapper
fn()
File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/distributed/_tensor/common_dtensor.py", line 172, in wrapper
func(self) # type: ignore[misc]
File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/common_distributed.py", line 172, in wrapper
return func(*args, **kwargs)
File "/fsx/users/irisz/work/pytorch/torch/testing/_internal/distributed/checkpoint_utils.py", line 34, in wrapper
func(self)
File "/fsx/users/irisz/work/pytorch/test/distributed/checkpoint/test_fsdp_optim_state.py", line 89, in test_distributed_tensor_planner
flattened_osd = FSDP.flatten_sharded_optim_state_dict(
File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1434, in flatten_sharded_optim_state_dict
return FullyShardedDataParallel._optim_state_dict_to_load_impl(
File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1231, in _optim_state_dict_to_load_impl
return _rekey_sharded_optim_state_dict(
File "/fsx/users/irisz/work/pytorch/torch/distributed/fsdp/_optim_utils.py", line 972, in _rekey_sharded_optim_state_dict
for unflat_param_name in unflat_param_group["params"]
TypeError: string indices must be integers
----------------------------------------------------------------------
Ran 1 test in 10.869s
FAILED (errors=1)
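For reference, the TypeError at the end of each rank's trace is consistent with param_groups being left as the '<bytes_io>' placeholder string: _rekey_sharded_optim_state_dict then iterates a str where it expects a list of param-group dicts. A toy illustration of just that failure mode (not the real FSDP code):

```python
# Toy reproduction of the TypeError only; the real code iterates
# optim_state_dict["param_groups"] inside _rekey_sharded_optim_state_dict.
param_groups = "<bytes_io>"              # placeholder string instead of a list of dicts
for unflat_param_group in param_groups:  # iterating a str yields single characters
    unflat_param_group["params"]         # str indexed with a str key
    # TypeError: string indices must be integers
```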
Alternatives
No response
Additional context
No response
cc @kumpera
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501 @awgu