
Commit c9e62db

[FSDP] Add set_state_dict_type API to setup state_dict_type without using context manager
Pull Request resolved: #86243

FSDP.state_dict_type is a context manager. However, users may want to decide what state_dict type is going to be used during initialization. `set_state_dict_type` allows users to do so.

ghstack-source-id: 170765562
Differential Revision: [D40083670](https://our.internmc.facebook.com/intern/diff/D40083670/)
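For context, a minimal sketch of how the new API compares to the existing context manager. The helper function, the fact that `model` is an already-constructed FSDP-wrapped module, and the config choices are illustrative assumptions; only `FSDP.set_state_dict_type`, `FSDP.state_dict_type`, `StateDictType`, and `ShardedStateDictConfig` come from this diff.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    ShardedStateDictConfig,
    StateDictType,
)

def checkpoint_examples(model: FSDP) -> None:
    # Before this commit: every checkpoint call has to be wrapped in the
    # state_dict_type context manager.
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        checkpoint = model.state_dict()

    # After this commit: configure the state_dict flavor once (e.g. right
    # after wrapping the model) and call state_dict() anywhere later.
    FSDP.set_state_dict_type(
        model,
        StateDictType.SHARDED_STATE_DICT,
        ShardedStateDictConfig(),
    )
    checkpoint = model.state_dict()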
1 parent 8b0cc9c commit c9e62db

File tree

1 file changed: +75 additions, -27 deletions


torch/distributed/fsdp/fully_sharded_data_parallel.py

Lines changed: 75 additions & 27 deletions
@@ -360,13 +360,6 @@ class ShardedStateDictConfig(StateDictConfig):
     pass


-_state_dict_type_to_config = {
-    StateDictType.FULL_STATE_DICT: FullStateDictConfig,
-    StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig,
-    StateDictType.SHARDED_STATE_DICT: ShardedStateDictConfig,
-}
-
-
 class OptimStateKeyType(Enum):
     PARAM_NAME = auto()
     PARAM_ID = auto()
@@ -2113,42 +2106,53 @@ def _get_training_state(
         return next(iter(training_states))

     @staticmethod
-    @contextlib.contextmanager
-    def state_dict_type(
+    def set_state_dict_type(
         module: nn.Module,
         state_dict_type: StateDictType,
         state_dict_config: Optional[StateDictConfig] = None,
-    ) -> Generator:
+    ) -> Tuple[StateDictType, StateDictConfig]:
         """
-        A context manager to set the ``state_dict_type`` of all the descendant
-        FSDP modules of the target module. The target module does not have to
-        be a FSDP module. If the target module is a FSDP module, its
-        ``state_dict_type`` will also be changed.
+        Set the ``state_dict_type`` and the corresponding (optional)
+        configurations of all the descendant FSDP modules of the target module.
+        The target module does not have to be a FSDP module. If the target
+        module is a FSDP module, its ``state_dict_type`` will also be changed.

         .. note:: This API should be called for only the top-level (root)
             module.

         .. note:: This API enables users to transparently use the conventional
             ``state_dict`` API to take model checkpoints in cases where the
             root FSDP module is wrapped by another ``nn.Module``. For example,
-            the following will ensure ``state_dict`` is called on all non-FSDP
-            instances, while dispatching into `local_state_dict` implementation
+            the following will ensure ``state_dict`` is called on all non-FSDP
+            instances, while dispatching into `sharded_state_dict` implementation
             for FSDP:

         Example::

            >>> # xdoctest: +SKIP("undefined variables")
            >>> model = DDP(FSDP(...))
-           >>> with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
-           >>>     checkpoint = model.state_dict()
+           >>> FSDP.set_state_dict_type(
+           >>>     model,
+           >>>     StateDictType.SHARDED_STATE_DICT,
+           >>>     ShardedStateDictConfig(offload_to_cpu=True),
+           >>> )
+           >>> checkpoint = model.state_dict()

         Args:
             module (torch.nn.Module): Root module.
             state_dict_type (StateDictType): the desired ``state_dict_type`` to set.
+            state_dict_config (Optional[StateDictConfig]): the configuration for the
+                target ``state_dict_type``.
         """
+        _state_dict_type_to_config = {
+            StateDictType.FULL_STATE_DICT: FullStateDictConfig,
+            StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig,
+            StateDictType.SHARDED_STATE_DICT: ShardedStateDictConfig,
+        }
+
         prev_state_dict_type = None
         prev_state_dict_config = None
-        # Use default config a state_dict config is not set.
+        # Use the default config if a state_dict config is not set.
         if state_dict_config is None:
             state_dict_config = _state_dict_type_to_config[state_dict_type]()
         for submodule in FullyShardedDataParallel.fsdp_modules(module):
@@ -2166,18 +2170,62 @@ def state_dict_type(
             expected_state_dict_config_type = _state_dict_type_to_config[state_dict_type]
             if expected_state_dict_config_type != type(state_dict_config):
                 raise RuntimeError(
-                    f"Expected state_dict_config of type {expected_state_dict_config_type} but got {type(state_dict_config)}"
+                    f"Expected state_dict_config of type {expected_state_dict_config_type} "
+                    f"but got {type(state_dict_config)}"
                 )
             submodule._state_dict_type = state_dict_type
             submodule._state_dict_config = state_dict_config
+
+        return prev_state_dict_type, prev_state_dict_config
+
+    @staticmethod
+    @contextlib.contextmanager
+    def state_dict_type(
+        module: nn.Module,
+        state_dict_type: StateDictType,
+        state_dict_config: Optional[StateDictConfig] = None,
+    ) -> Generator:
+        """
+        A context manager to set the ``state_dict_type`` of all the descendant
+        FSDP modules of the target module. This context manager has the same
+        functions as :meth:`set_state_dict_type`. Read the document of
+        :meth:`set_state_dict_type` for the detail.
+
+        Example::
+
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> model = DDP(FSDP(...))
+            >>> with FSDP.state_dict_type(
+            >>>     model,
+            >>>     StateDictType.SHARDED_STATE_DICT,
+            >>> ):
+            >>>     checkpoint = model.state_dict()
+
+        Args:
+            module (torch.nn.Module): Root module.
+            state_dict_type (StateDictType): the desired ``state_dict_type`` to set.
+            state_dict_config (Optional[StateDictConfig]): the configuration for the
+                target ``state_dict_type``.
+        """
+        prev_state_dict_type = None
+        prev_state_dict_config = None
         try:
+            prev_state_dict_type, prev_state_dict_config = (
+                FullyShardedDataParallel.set_state_dict_type(
+                    module, state_dict_type, state_dict_config
+                )
+            )
             yield
+        except Exception as e:
+            raise e
+        else:
+            assert prev_state_dict_type is not None
+            assert prev_state_dict_config is not None
         finally:
-            assert prev_state_dict_type is not None  # Avoid mypy warning
-            assert prev_state_dict_config is not None  # Avoid mypy warning
-            for submodule in FullyShardedDataParallel.fsdp_modules(module):
-                submodule._state_dict_type = prev_state_dict_type
-                submodule._state_dict_config = prev_state_dict_config
+            if prev_state_dict_type is not None and prev_state_dict_config is not None:
+                FullyShardedDataParallel.set_state_dict_type(
+                    module, prev_state_dict_type, prev_state_dict_config
+                )

     def _convert_to_wrapped_module_name(self, module_name: str) -> str:
         module_name = module_name.replace(f"{FPW_MODULE}.", "")
@@ -2524,7 +2572,7 @@ def _sharded_state_dict(self, *args: Any, **kwargs: Any) -> Any:
         (e.g., DPP, model parallelism, and single trainer) after a valid
         resharding.
         """
-        with self.set_state_dict_type(StateDictType.SHARDED_STATE_DICT):
+        with self.state_dict_type(StateDictType.SHARDED_STATE_DICT):
             return self.state_dict(self, *args, **kwargs)

     def _full_pre_load_state_dict_hook(
@@ -2759,7 +2807,7 @@ def _load_sharded_state_dict(
         """
         Load states from a unflattened, sharded state dictionary.
         """
-        with self.set_state_dict_type(StateDictType.SHARDED_STATE_DICT):
+        with self.state_dict_type(StateDictType.SHARDED_STATE_DICT):
             return self.load_state_dict(state_dict, strict)

     def forward(self, *args: Any, **kwargs: Any) -> Any:
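One more usage sketch built on the new return type: `set_state_dict_type` now returns the previous (type, config) pair, which is exactly what the reworked `state_dict_type` context manager uses to restore the prior setting. Callers can do the same restore by hand. The helper name, the `model` variable, and the FULL_STATE_DICT config values below are hypothetical; the API calls and the return value are from this diff.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullStateDictConfig,
    StateDictType,
)

def full_state_dict_once(model):
    # Temporarily switch to FULL_STATE_DICT for a single checkpoint.
    prev_type, prev_config = FSDP.set_state_dict_type(
        model,
        StateDictType.FULL_STATE_DICT,
        FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    )
    try:
        checkpoint = model.state_dict()
    finally:
        # Restore whatever was configured before, mirroring the finally block
        # of the state_dict_type context manager in this commit.
        FSDP.set_state_dict_type(model, prev_type, prev_config)
    return checkpoint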

Comments (0)