Commit 95d23a8

[FSDP][state_dict] Return tensors instead of FlatParameters to avoid pickling errors
Pull Request resolved: #94637

After #88913, user-defined parameter states will be pickled. For a FlatParameter, this means `_local_shard` will also be pickled. Since state_dict and load_state_dict only require the tensor, returning the full FlatParameter does not give us any extra benefit. This PR changes the behavior to simply return a view of the FlatParameter.

ghstack-source-id: 179983735
Differential Revision: [D43205127](https://our.internmc.facebook.com/intern/diff/D43205127/)
1 parent: 4c6a7fa
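For context, here is a minimal usage sketch of the save/load pattern this change targets. It is not code from this commit: it assumes an already-initialized process group, an FSDP-wrapped `model`, and a hypothetical per-rank `checkpoint_path`.

```python
# Sketch only: save and reload a sharded (LOCAL_STATE_DICT) checkpoint.
# Assumes torch.distributed is initialized and `model` is FSDP-wrapped;
# `checkpoint_path` is a hypothetical per-rank path.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType


def save_local_checkpoint(model: FSDP, checkpoint_path: str) -> None:
    with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
        # With this change, each rank's state_dict holds plain tensors /
        # ShardedTensors, so it can be torch.save'd and later torch.load'ed.
        torch.save(model.state_dict(), checkpoint_path)


def load_local_checkpoint(model: FSDP, checkpoint_path: str) -> None:
    with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
        model.load_state_dict(torch.load(checkpoint_path))
```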

File tree

test/distributed/fsdp/test_fsdp_state_dict.py
torch/distributed/fsdp/_state_dict_utils.py

2 files changed: +24 -2 lines changed

test/distributed/fsdp/test_fsdp_state_dict.py

Lines changed: 19 additions & 0 deletions
@@ -1,5 +1,6 @@
 # Owner(s): ["oncall: distributed"]
 
+import io
 import itertools
 import sys
 from contextlib import suppress
@@ -10,6 +11,7 @@
 import torch
 import torch.nn as nn
 from torch import distributed as dist
+from torch.distributed._shard.sharded_tensor import ShardedTensor
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     apply_activation_checkpointing,
     checkpoint_wrapper,
@@ -1067,6 +1069,23 @@ def forward(self, x):
         with FSDP.summon_full_params(model):
             self.assertEqual(model.my_parameter.item(), 3.1415926)
 
+    @skip_if_lt_x_gpu(2)
+    def test_torch_save_load(self):
+        model = Model(wrap_fsdp=True).cuda()
+        with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
+            state_dict = model.state_dict()
+            checkpoint = io.BytesIO()
+            torch.save(state_dict, checkpoint)
+            checkpoint.seek(0)
+            state_dict_saved = torch.load(checkpoint)
+            for k, v in state_dict_saved.items():
+                if isinstance(v, ShardedTensor):
+                    self.assertEqual(
+                        v._local_shards[0].tensor, state_dict[k]._local_shards[0].tensor
+                    )
+                else:
+                    self.assertEqual(v, state_dict[k])
+
 
 instantiate_parametrized_tests(TestFSDPStateDict)
 
torch/distributed/fsdp/_state_dict_utils.py

Lines changed: 5 additions & 2 deletions
@@ -393,8 +393,11 @@ def _local_post_state_dict_hook(
     shard_offset = flat_param.numel() * fsdp_state.rank
     valid_data_size = flat_param.numel() - flat_param._shard_numel_padded
     if valid_data_size > 0:
-        if flat_param._shard_numel_padded > 0:
-            flat_param = flat_param.narrow(0, 0, valid_data_size)
+        # If the FlatParameter is returned, FlatParameter._local_shard causes a
+        # pickling issue (it can be saved with torch.save but not loaded with
+        # torch.load). Since there is no benefit to returning the actual
+        # FlatParameter class, return a view (a plain tensor) of the FlatParameter.
+        flat_param = flat_param[:valid_data_size].view(valid_data_size)
     local_shards = [
         Shard.from_tensor_and_offsets(flat_param, [shard_offset], fsdp_state.rank)
     ]
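As additional context (a sketch, not code from this commit, using `nn.Parameter` as a stand-in for the `FlatParameter` subclass and made-up sizes): the old path only took a view when the shard was padded, so an unpadded shard kept its parameter type all the way into the `ShardedTensor`; the new path always hands a plain-tensor view to `Shard.from_tensor_and_offsets`.

```python
# Sketch only: nn.Parameter stands in for FSDP's FlatParameter subclass, and
# the sizes below are hypothetical. It contrasts the old conditional narrow()
# with the new unconditional slice-and-view.
import torch
import torch.nn as nn

flat_param = nn.Parameter(torch.ones(6))      # pretend: an unpadded local shard
shard_numel_padded = 0
valid_data_size = flat_param.numel() - shard_numel_padded

# Old behavior: without padding, the parameter object itself flowed onward.
old = (
    flat_param.narrow(0, 0, valid_data_size)
    if shard_numel_padded > 0
    else flat_param
)
print(type(old))   # <class 'torch.nn.parameter.Parameter'>

# New behavior: always a plain torch.Tensor view of the valid data.
new = flat_param[:valid_data_size].view(valid_data_size)
print(type(new))   # <class 'torch.Tensor'>
```

Because `nn.Parameter` disables `__torch_function__`, slicing and `.view()` hand back plain `torch.Tensor` views rather than new parameter objects, which is exactly what the local state_dict needs to pickle cleanly.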
