Commit aa1cb74

[FSDP] Consolidate FSDP state_dict offload_to_cpu settings
Consolidate FSDP state_dict offload_to_cpu settings: all state_dict_types now have an offload_to_cpu option.

Differential Revision: [D40065969](https://our.internmc.facebook.com/intern/diff/D40065969/)
ghstack-source-id: 169306202
Pull Request resolved: #86211
1 parent 0b0ce72 commit aa1cb74
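
What this enables, as a minimal usage sketch (assumes an initialized process group and an FSDP-wrapped model; `model` is a placeholder):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardedStateDictConfig, StateDictType

# offload_to_cpu is now honored for sharded (and local) state dicts,
# not only for FULL_STATE_DICT:
with FSDP.state_dict_type(
    model,
    StateDictType.SHARDED_STATE_DICT,
    ShardedStateDictConfig(offload_to_cpu=True),
):
    sharded_sd = model.state_dict()  # ShardedTensor shards land on CPU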

3 files changed: 29 additions, 9 deletions
test/distributed/fsdp/test_fsdp_state_dict.py

Lines changed: 17 additions & 4 deletions
@@ -13,7 +13,12 @@
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper,
 )
-from torch.distributed.fsdp import CPUOffload, FullStateDictConfig
+from torch.distributed.fsdp import (
+    CPUOffload,
+    FullStateDictConfig,
+    LocalStateDictConfig,
+    ShardedStateDictConfig,
+)
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import (
     LocalStateDictConfig,
@@ -73,8 +78,8 @@
 
 NON_ROOT_FSDP_PREFIX = 'non_fsdp_lin'
 
-_UNFLATTENED_STATE_DICT_IMPLS = ["state_dict", "sharded_state_dict"]
-_FLATTENED_STATE_DICT_IMPLS = ["local_state_dict"]
+_UNFLATTENED_STATE_DICT_IMPLS = ["state_dict"]
+_FLATTENED_STATE_DICT_IMPLS = []
 _SUPPORTED_STATE_DICT_IMPLS = (
     _UNFLATTENED_STATE_DICT_IMPLS + _FLATTENED_STATE_DICT_IMPLS
 )
@@ -180,8 +185,16 @@ def _get_state_dict_mgr(
             rank0_only=state_dict_rank0_and_offload,
             offload_to_cpu=state_dict_rank0_and_offload,
         )
+    elif state_dict_type == "local_state_dict":
+        config = LocalStateDictConfig(
+            offload_to_cpu=state_dict_rank0_and_offload,
+        )
+    elif state_dict_type == "sharded_state_dict":
+        config = ShardedStateDictConfig(
+            offload_to_cpu=state_dict_rank0_and_offload,
+        )
     else:
-        config = None
+        raise ValueError("Unsupported state_dict_type")
     return FSDP.state_dict_type(model, _state_dict_type, config)
 
 def _validate_state_dict_contents(
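
The same manager pattern covers all three implementations; a hedged sketch of what the helper now supports (config classes and flags from this diff, `model` assumed to be FSDP-wrapped):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
    FullStateDictConfig,
    LocalStateDictConfig,
    ShardedStateDictConfig,
    StateDictType,
)

for sd_type, cfg in (
    (StateDictType.FULL_STATE_DICT, FullStateDictConfig(offload_to_cpu=True, rank0_only=True)),
    (StateDictType.LOCAL_STATE_DICT, LocalStateDictConfig(offload_to_cpu=True)),
    (StateDictType.SHARDED_STATE_DICT, ShardedStateDictConfig(offload_to_cpu=True)),
):
    with FSDP.state_dict_type(model, sd_type, cfg):
        sd = model.state_dict()  # offloaded to CPU in every case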

torch/distributed/fsdp/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@
     LocalStateDictConfig,
     MixedPrecision,
     OptimStateKeyType,
+    ShardedStateDictConfig,
     ShardingStrategy,
     StateDictType,
 )
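
With this re-export, the new config is importable from the package root; a quick check (assumes a build containing this commit):

from torch.distributed.fsdp import ShardedStateDictConfig

cfg = ShardedStateDictConfig(offload_to_cpu=True)
assert cfg.offload_to_cpu  # field now inherited from StateDictConfig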

torch/distributed/fsdp/fully_sharded_data_parallel.py

Lines changed: 11 additions & 5 deletions
@@ -315,7 +315,7 @@ class StateDictConfig:
     order to configure settings for the particular type of ``state_dict``
     implementation FSDP will use.
     """
-    pass
+    offload_to_cpu: bool = False
 
 @dataclass
 class FullStateDictConfig(StateDictConfig):
@@ -345,9 +345,9 @@ class FullStateDictConfig(StateDictConfig):
         >>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True)
         >>> # After this point, all ranks have FSDP model with loaded checkpoint.
     """
-    offload_to_cpu: bool = False
     rank0_only: bool = False
 
+
 @dataclass
 class LocalStateDictConfig(StateDictConfig):
     pass
@@ -2251,10 +2251,12 @@ def _local_post_state_dict_hook(
         local_shards = [
             Shard.from_tensor_and_offsets(flat_param, [shard_offset], self.rank)
         ]
-        state_dict[f"{prefix}{FLAT_PARAM}"] = init_from_local_shards(
+        sharded_tensor = init_from_local_shards(
             local_shards, full_numel, process_group=self.process_group
         )  # type: ignore[assignment]
-
+        if self._state_dict_config.offload_to_cpu:
+            sharded_tensor = sharded_tensor.cpu()
+        state_dict[f"{prefix}{FLAT_PARAM}"] = sharded_tensor
         return state_dict
 
     @torch.no_grad()
@@ -2279,13 +2281,17 @@ def _sharded_post_state_dict_hook(
         for fqn, _, _ in self._param_fqns:
             # Create a ShardedTensor for the unflattened, non-sharded parameter.
             param = functools.reduce(getattr, fqn.split("."), self.module)
-            state_dict[f"{prefix}{fqn}"] = _ext_chunk_tensor(
+            sharded_tensor = _ext_chunk_tensor(
                 tensor=param,
                 rank=self.rank,
                 world_size=self.world_size,
                 num_devices_per_node=torch.cuda.device_count(),
                 pg=self.process_group
             )  # type: ignore[assignment]
+            if self._state_dict_config.offload_to_cpu:
+                sharded_tensor = sharded_tensor.cpu()
+            state_dict[f"{prefix}{fqn}"] = sharded_tensor
+
         state_dict.pop(f"{prefix}{FLAT_PARAM}")
         return state_dict