14 changes: 12 additions & 2 deletions test/distributed/fsdp/test_fsdp_state_dict.py
@@ -13,11 +13,13 @@
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
)
from torch.distributed.fsdp import CPUOffload, FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import (
CPUOffload,
FullStateDictConfig,
LocalStateDictConfig,
MixedPrecision,
ShardedStateDictConfig,
StateDictType,
)
from torch.distributed.fsdp._shard_utils import _gather_state_dict
@@ -186,8 +188,16 @@ def _get_state_dict_mgr(
rank0_only=state_dict_rank0_and_offload,
offload_to_cpu=state_dict_rank0_and_offload,
)
elif state_dict_type == "local_state_dict":
config = LocalStateDictConfig(
offload_to_cpu=state_dict_rank0_and_offload,
)
elif state_dict_type == "sharded_state_dict":
config = ShardedStateDictConfig(
offload_to_cpu=state_dict_rank0_and_offload,
)
else:
config = None
raise ValueError("Unsupported state_dict_type")
return FSDP.state_dict_type(model, _state_dict_type, config)

def _validate_state_dict_contents(
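The helper above (`_get_state_dict_mgr`) wraps one of the three config classes in the `FSDP.state_dict_type` context manager. A minimal sketch of the same pattern outside the test suite, assuming `model` is already FSDP-wrapped on an initialized process group (both assumed here, not shown in this diff):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import LocalStateDictConfig, StateDictType

# Request the local (flat-parameter) state_dict and offload each rank's shard to CPU.
config = LocalStateDictConfig(offload_to_cpu=True)
with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT, config):
    local_sd = model.state_dict()  # holds a sharded flat_param entry per FSDP unit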
1 change: 1 addition & 0 deletions torch/distributed/fsdp/__init__.py
@@ -7,6 +7,7 @@
LocalStateDictConfig,
MixedPrecision,
OptimStateKeyType,
ShardedStateDictConfig,
ShardingStrategy,
StateDictType,
)
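With `ShardedStateDictConfig` added to the package exports, the config can be imported from `torch.distributed.fsdp` directly rather than from the `fully_sharded_data_parallel` submodule. A trivial sketch:

# The new export makes this top-level import valid.
from torch.distributed.fsdp import ShardedStateDictConfig

cfg = ShardedStateDictConfig(offload_to_cpu=True)  # field inherited from StateDictConfig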
22 changes: 16 additions & 6 deletions torch/distributed/fsdp/fully_sharded_data_parallel.py
@@ -307,6 +307,7 @@ class StateDictType(Enum):
LOCAL_STATE_DICT = auto()
SHARDED_STATE_DICT = auto()


@dataclass
class StateDictConfig:
"""
@@ -315,7 +316,8 @@ class StateDictConfig:
order to configure settings for the particular type of ``state_dict``
implementation FSDP will use.
"""
pass
offload_to_cpu: bool = False


@dataclass
class FullStateDictConfig(StateDictConfig):
@@ -345,23 +347,26 @@ class FullStateDictConfig(StateDictConfig):
>>> fsdp = FSDP(model, device_id=torch.cuda.current_device(), auto_wrap_policy=..., sync_module_states=True)
>>> # After this point, all ranks have FSDP model with loaded checkpoint.
"""
offload_to_cpu: bool = False
rank0_only: bool = False


@dataclass
class LocalStateDictConfig(StateDictConfig):
pass


@dataclass
class ShardedStateDictConfig(StateDictConfig):
pass


_state_dict_type_to_config = {
StateDictType.FULL_STATE_DICT: FullStateDictConfig,
StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig,
StateDictType.SHARDED_STATE_DICT: ShardedStateDictConfig,
}


class OptimStateKeyType(Enum):
PARAM_NAME = auto()
PARAM_ID = auto()
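Stepping back to the config classes in the hunk above: since `offload_to_cpu` now sits on the base `StateDictConfig` dataclass, every subclass accepts it (defaulting to False), and `_state_dict_type_to_config` can construct a default config for any `StateDictType`. A short illustrative sketch using only the publicly exported classes:

from torch.distributed.fsdp import (
    FullStateDictConfig,
    LocalStateDictConfig,
    ShardedStateDictConfig,
)

# All three inherit offload_to_cpu from StateDictConfig.
full_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
local_cfg = LocalStateDictConfig(offload_to_cpu=True)
sharded_cfg = ShardedStateDictConfig()
assert sharded_cfg.offload_to_cpu is False  # default preserved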
@@ -2317,10 +2322,12 @@ def _local_post_state_dict_hook(
local_shards = [
Shard.from_tensor_and_offsets(flat_param, [shard_offset], self.rank)
]
state_dict[f"{prefix}{FLAT_PARAM}"] = init_from_local_shards(
sharded_tensor = init_from_local_shards(
local_shards, full_numel, process_group=self.process_group
) # type: ignore[assignment]

if self._state_dict_config.offload_to_cpu:
sharded_tensor = sharded_tensor.cpu()
state_dict[f"{prefix}{FLAT_PARAM}"] = sharded_tensor
return state_dict

@torch.no_grad()
@@ -2345,13 +2352,16 @@ def _sharded_post_state_dict_hook(
for fqn, _, _ in self._param_fqns:
# Create a ShardedTensor for the unflattened, non-sharded parameter.
param = functools.reduce(getattr, fqn.split("."), self.module)
state_dict[f"{prefix}{fqn}"] = _ext_chunk_tensor(
sharded_tensor = _ext_chunk_tensor(
tensor=param,
rank=self.rank,
world_size=self.world_size,
num_devices_per_node=torch.cuda.device_count(),
pg=self.process_group
) # type: ignore[assignment]
)
if self._state_dict_config.offload_to_cpu:
sharded_tensor = sharded_tensor.cpu()
state_dict[f"{prefix}{fqn}"] = sharded_tensor
# For `use_orig_params=True`, the `FlatParameter` is not registered, so
# there is no entry in the state dict for it to pop.
if not self._use_orig_params:
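Taken together, the two hooks mean that with `offload_to_cpu=True` the returned state dict holds ShardedTensors whose local shards live on CPU: a single FLAT_PARAM entry under LOCAL_STATE_DICT, and one entry per original parameter under SHARDED_STATE_DICT. A hedged validation sketch, assuming an initialized process group and an FSDP-wrapped `model`, and using the public `local_shards()` accessor on ShardedTensor:

import torch
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardedStateDictConfig, StateDictType

cfg = ShardedStateDictConfig(offload_to_cpu=True)
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT, cfg):
    sd = model.state_dict()

# Every local shard should have been moved off the GPU by the post-state_dict hook.
for name, value in sd.items():
    if isinstance(value, ShardedTensor):
        for shard in value.local_shards():
            assert shard.tensor.device == torch.device("cpu"), name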