
Commit eadee17

[FSDP] Consolidate FSDP state_dict offload_to_cpu settings
Pull Request resolved: #86211

Consolidate FSDP state_dict offload_to_cpu settings. All state_dict_types now have offload_to_cpu options.

ghstack-source-id: 170308168
Differential Revision: [D40065969](https://our.internmc.facebook.com/intern/diff/D40065969/)
1 parent 894c421 commit eadee17
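
With this change, every StateDictConfig subclass exposes offload_to_cpu, so CPU offloading can be requested uniformly across all three StateDictType values. A minimal usage sketch; the distributed setup and the tiny Linear model are illustrative assumptions, not part of the diff:

import torch
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    LocalStateDictConfig,
    ShardedStateDictConfig,
    StateDictType,
)

# Assumes torch.distributed (NCCL) is already initialized and a CUDA device is
# set, e.g. via torchrun; none of this setup is part of the commit itself.
model = FSDP(torch.nn.Linear(8, 8).cuda())

# FULL_STATE_DICT already supported offload_to_cpu (and rank0_only).
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
    full_sd = model.state_dict()

# SHARDED_STATE_DICT and LOCAL_STATE_DICT gain the same knob with this commit.
with FSDP.state_dict_type(
    model, StateDictType.SHARDED_STATE_DICT, ShardedStateDictConfig(offload_to_cpu=True)
):
    sharded_sd = model.state_dict()  # per-parameter ShardedTensors moved to CPU

with FSDP.state_dict_type(
    model, StateDictType.LOCAL_STATE_DICT, LocalStateDictConfig(offload_to_cpu=True)
):
    local_sd = model.state_dict()  # flat-parameter ShardedTensor moved to CPU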

3 files changed: +29 −8 lines changed

test/distributed/fsdp/test_fsdp_state_dict.py

Lines changed: 12 additions & 2 deletions

@@ -13,11 +13,13 @@
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper,
 )
-from torch.distributed.fsdp import CPUOffload, FullStateDictConfig
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import (
+    CPUOffload,
+    FullStateDictConfig,
     LocalStateDictConfig,
     MixedPrecision,
+    ShardedStateDictConfig,
     StateDictType,
 )
 from torch.distributed.fsdp._shard_utils import _gather_state_dict

@@ -186,8 +188,16 @@ def _get_state_dict_mgr(
                 rank0_only=state_dict_rank0_and_offload,
                 offload_to_cpu=state_dict_rank0_and_offload,
             )
+        elif state_dict_type == "local_state_dict":
+            config = LocalStateDictConfig(
+                offload_to_cpu=state_dict_rank0_and_offload,
+            )
+        elif state_dict_type == "sharded_state_dict":
+            config = ShardedStateDictConfig(
+                offload_to_cpu=state_dict_rank0_and_offload,
+            )
         else:
-            config = None
+            raise ValueError("Unsupported state_dict_type")
         return FSDP.state_dict_type(model, _state_dict_type, config)

     def _validate_state_dict_contents(

torch/distributed/fsdp/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -7,6 +7,7 @@
     LocalStateDictConfig,
     MixedPrecision,
     OptimStateKeyType,
+    ShardedStateDictConfig,
     ShardingStrategy,
     StateDictType,
 )

torch/distributed/fsdp/fully_sharded_data_parallel.py

Lines changed: 16 additions & 6 deletions

@@ -307,6 +307,7 @@ class StateDictType(Enum):
     LOCAL_STATE_DICT = auto()
     SHARDED_STATE_DICT = auto()

+
 @dataclass
 class StateDictConfig:
     """

@@ -315,7 +316,8 @@ class StateDictConfig:
     order to configure settings for the particular type of ``state_dict``
     implementation FSDP will use.
     """
-    pass
+    offload_to_cpu: bool = False
+

 @dataclass
 class FullStateDictConfig(StateDictConfig):

@@ -345,23 +347,26 @@ class FullStateDictConfig(StateDictConfig):
     >>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True)
     >>> # After this point, all ranks have FSDP model with loaded checkpoint.
     """
-    offload_to_cpu: bool = False
     rank0_only: bool = False

+
 @dataclass
 class LocalStateDictConfig(StateDictConfig):
     pass

+
 @dataclass
 class ShardedStateDictConfig(StateDictConfig):
     pass

+
 _state_dict_type_to_config = {
     StateDictType.FULL_STATE_DICT: FullStateDictConfig,
     StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig,
     StateDictType.SHARDED_STATE_DICT: ShardedStateDictConfig,
 }

+
 class OptimStateKeyType(Enum):
     PARAM_NAME = auto()
     PARAM_ID = auto()

@@ -2317,10 +2322,12 @@ def _local_post_state_dict_hook(
         local_shards = [
             Shard.from_tensor_and_offsets(flat_param, [shard_offset], self.rank)
         ]
-        state_dict[f"{prefix}{FLAT_PARAM}"] = init_from_local_shards(
+        sharded_tensor = init_from_local_shards(
             local_shards, full_numel, process_group=self.process_group
         )  # type: ignore[assignment]
-
+        if self._state_dict_config.offload_to_cpu:
+            sharded_tensor = sharded_tensor.cpu()
+        state_dict[f"{prefix}{FLAT_PARAM}"] = sharded_tensor
         return state_dict

     @torch.no_grad()

@@ -2345,13 +2352,16 @@ def _sharded_post_state_dict_hook(
         for fqn, _, _ in self._param_fqns:
             # Create a ShardedTensor for the unflattened, non-sharded parameter.
             param = functools.reduce(getattr, fqn.split("."), self.module)
-            state_dict[f"{prefix}{fqn}"] = _ext_chunk_tensor(
+            sharded_tensor = _ext_chunk_tensor(
                 tensor=param,
                 rank=self.rank,
                 world_size=self.world_size,
                 num_devices_per_node=torch.cuda.device_count(),
                 pg=self.process_group
-            )  # type: ignore[assignment]
+            )
+            if self._state_dict_config.offload_to_cpu:
+                sharded_tensor = sharded_tensor.cpu()
+            state_dict[f"{prefix}{fqn}"] = sharded_tensor
         # For `use_orig_params=True`, the `FlatParameter` is not registered, so
         # there is no entry in the state dict for it to pop.
         if not self._use_orig_params:
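
A quick way to see the effect of the new offload_to_cpu handling in the local/sharded post-state_dict hooks is to check where the local shards of the returned ShardedTensors live. A hedged sketch, assuming the same FSDP-wrapped model and initialized process group as in the earlier example:

import torch
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardedStateDictConfig,
    StateDictType,
)

with FSDP.state_dict_type(
    model, StateDictType.SHARDED_STATE_DICT, ShardedStateDictConfig(offload_to_cpu=True)
):
    sharded_sd = model.state_dict()

for key, value in sharded_sd.items():
    # Parameters come back as ShardedTensors; with offload_to_cpu=True the hook
    # above calls .cpu() on them, so their local shards should be in host memory.
    if isinstance(value, ShardedTensor):
        for shard in value.local_shards():
            assert shard.tensor.device.type == "cpu", key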
