Commit 99f1c1f

[FSDP] Consolidate FSDP state_dict offload_to_cpu settings
Pull Request resolved: #86211

Consolidate FSDP state_dict offload_to_cpu settings. All state_dict_types now have offload_to_cpu options.

ghstack-source-id: 169442813
Differential Revision: [D40065969](https://our.internmc.facebook.com/intern/diff/D40065969/)
1 parent 2355b62 commit 99f1c1f
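
Usage sketch (illustrative, not part of this commit): after this change, offload_to_cpu can be passed to every StateDictConfig subclass, not just FullStateDictConfig. The snippet below assumes a process group already initialized with a CUDA backend (e.g. via torchrun) and uses a placeholder nn.Linear module; everything outside the config and context-manager calls is an assumption made for the sketch.

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
    FullStateDictConfig,
    LocalStateDictConfig,
    ShardedStateDictConfig,
    StateDictType,
)

# Placeholder model; any FSDP-wrapped module works the same way.
model = FSDP(nn.Linear(8, 8).cuda())

# offload_to_cpu on the full state dict existed before this commit.
full_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_cfg):
    full_sd = model.state_dict()

# New here: the local and sharded state dict configs accept it as well,
# so their ShardedTensor entries are moved to CPU as they are produced.
local_cfg = LocalStateDictConfig(offload_to_cpu=True)
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT, local_cfg):
    local_sd = model.state_dict()

sharded_cfg = ShardedStateDictConfig(offload_to_cpu=True)
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT, sharded_cfg):
    sharded_sd = model.state_dict()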

3 files changed (+29, −7 lines)


test/distributed/fsdp/test_fsdp_state_dict.py

Lines changed: 12 additions & 2 deletions
@@ -13,11 +13,13 @@
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper,
 )
-from torch.distributed.fsdp import CPUOffload, FullStateDictConfig
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import (
+    CPUOffload,
+    FullStateDictConfig,
     LocalStateDictConfig,
     MixedPrecision,
+    ShardedStateDictConfig,
     StateDictType,
 )
 from torch.distributed.fsdp._shard_utils import _gather_state_dict
@@ -186,8 +188,16 @@ def _get_state_dict_mgr(
                 rank0_only=state_dict_rank0_and_offload,
                 offload_to_cpu=state_dict_rank0_and_offload,
             )
+        elif state_dict_type == "local_state_dict":
+            config = LocalStateDictConfig(
+                offload_to_cpu=state_dict_rank0_and_offload,
+            )
+        elif state_dict_type == "sharded_state_dict":
+            config = ShardedStateDictConfig(
+                offload_to_cpu=state_dict_rank0_and_offload,
+            )
         else:
-            config = None
+            raise ValueError("Unsupported state_dict_type")
         return FSDP.state_dict_type(model, _state_dict_type, config)

     def _validate_state_dict_contents(

torch/distributed/fsdp/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@
     LocalStateDictConfig,
     MixedPrecision,
     OptimStateKeyType,
+    ShardedStateDictConfig,
     ShardingStrategy,
     StateDictType,
 )

torch/distributed/fsdp/fully_sharded_data_parallel.py

Lines changed: 16 additions & 5 deletions
@@ -307,6 +307,7 @@ class StateDictType(Enum):
     LOCAL_STATE_DICT = auto()
     SHARDED_STATE_DICT = auto()

+
 @dataclass
 class StateDictConfig:
     """
@@ -315,7 +316,8 @@ class StateDictConfig:
     order to configure settings for the particular type of ``state_dict``
     implementation FSDP will use.
     """
-    pass
+    offload_to_cpu: bool = False
+

 @dataclass
 class FullStateDictConfig(StateDictConfig):
@@ -345,23 +347,26 @@ class FullStateDictConfig(StateDictConfig):
         >>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True)
         >>> # After this point, all ranks have FSDP model with loaded checkpoint.
     """
-    offload_to_cpu: bool = False
     rank0_only: bool = False

+
 @dataclass
 class LocalStateDictConfig(StateDictConfig):
     pass

+
 @dataclass
 class ShardedStateDictConfig(StateDictConfig):
     pass

+
 _state_dict_type_to_config = {
     StateDictType.FULL_STATE_DICT: FullStateDictConfig,
     StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig,
     StateDictType.SHARDED_STATE_DICT: ShardedStateDictConfig,
 }

+
 class OptimStateKeyType(Enum):
     PARAM_NAME = auto()
     PARAM_ID = auto()
@@ -2286,10 +2291,12 @@ def _local_post_state_dict_hook(
         local_shards = [
             Shard.from_tensor_and_offsets(flat_param, [shard_offset], self.rank)
         ]
-        state_dict[f"{prefix}{FLAT_PARAM}"] = init_from_local_shards(
+        sharded_tensor = init_from_local_shards(
             local_shards, full_numel, process_group=self.process_group
         ) # type: ignore[assignment]
-
+        if self._state_dict_config.offload_to_cpu:
+            sharded_tensor = sharded_tensor.cpu()
+        state_dict[f"{prefix}{FLAT_PARAM}"] = sharded_tensor
         return state_dict

     @torch.no_grad()
@@ -2314,13 +2321,17 @@ def _sharded_post_state_dict_hook(
         for fqn, _, _ in self._param_fqns:
             # Create a ShardedTensor for the unflattened, non-sharded parameter.
             param = functools.reduce(getattr, fqn.split("."), self.module)
-            state_dict[f"{prefix}{fqn}"] = _ext_chunk_tensor(
+            sharded_tensor = _ext_chunk_tensor(
                 tensor=param,
                 rank=self.rank,
                 world_size=self.world_size,
                 num_devices_per_node=torch.cuda.device_count(),
                 pg=self.process_group
             ) # type: ignore[assignment]
+            if self._state_dict_config.offload_to_cpu:
+                sharded_tensor = sharded_tensor.cpu()
+            state_dict[f"{prefix}{fqn}"] = sharded_tensor
+
         state_dict.pop(f"{prefix}{FLAT_PARAM}")
         return state_dict

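A hedged follow-up check (again illustrative, not from the commit): because the local and sharded post-state_dict hooks above call .cpu() on the ShardedTensor when offload_to_cpu is set, the local shards of every ShardedTensor entry in the resulting state dict should come back on CPU. Continuing from the sketch under the commit message:

from torch.distributed._shard.sharded_tensor import ShardedTensor

with FSDP.state_dict_type(
    model, StateDictType.SHARDED_STATE_DICT, ShardedStateDictConfig(offload_to_cpu=True)
):
    sharded_sd = model.state_dict()

for name, value in sharded_sd.items():
    # Only parameter entries are ShardedTensors; the hooks above do not
    # touch plain-tensor entries such as buffers.
    if isinstance(value, ShardedTensor):
        for shard in value.local_shards():
            assert shard.tensor.device.type == "cpu", name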