@@ -360,13 +360,6 @@ class ShardedStateDictConfig(StateDictConfig):
     pass
 
 
-_state_dict_type_to_config = {
-    StateDictType.FULL_STATE_DICT: FullStateDictConfig,
-    StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig,
-    StateDictType.SHARDED_STATE_DICT: ShardedStateDictConfig,
-}
-
-
 class OptimStateKeyType(Enum):
     PARAM_NAME = auto()
     PARAM_ID = auto()
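
The mapping deleted above is not gone: it moves inside `set_state_dict_type` in the next hunk, so the lookup lives next to its only caller. As a minimal sketch (assuming the config classes are importable from this module, where the file defines them), this is how the mapping supplies a default config when the caller passes none:

```python
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    FullStateDictConfig,
    LocalStateDictConfig,
    ShardedStateDictConfig,
    StateDictType,
)

# Same shape as the mapping in the diff: each state-dict type maps to its
# config class, and calling the class with no arguments yields the defaults.
_state_dict_type_to_config = {
    StateDictType.FULL_STATE_DICT: FullStateDictConfig,
    StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig,
    StateDictType.SHARDED_STATE_DICT: ShardedStateDictConfig,
}

default_config = _state_dict_type_to_config[StateDictType.SHARDED_STATE_DICT]()
assert isinstance(default_config, ShardedStateDictConfig)
```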
@@ -2113,42 +2106,53 @@ def _get_training_state(
         return next(iter(training_states))
 
     @staticmethod
-    @contextlib.contextmanager
-    def state_dict_type(
+    def set_state_dict_type(
         module: nn.Module,
         state_dict_type: StateDictType,
         state_dict_config: Optional[StateDictConfig] = None,
-    ) -> Generator:
+    ) -> Tuple[StateDictType, StateDictConfig]:
         """
-        A context manager to set the ``state_dict_type`` of all the descendant
-        FSDP modules of the target module. The target module does not have to
-        be a FSDP module. If the target module is a FSDP module, its
-        ``state_dict_type`` will also be changed.
+        Set the ``state_dict_type`` and the corresponding (optional)
+        configurations of all the descendant FSDP modules of the target module.
+        The target module does not have to be a FSDP module. If the target
+        module is a FSDP module, its ``state_dict_type`` will also be changed.
 
         .. note:: This API should be called for only the top-level (root)
             module.
 
         .. note:: This API enables users to transparently use the conventional
             ``state_dict`` API to take model checkpoints in cases where the
             root FSDP module is wrapped by another ``nn.Module``. For example,
-            the following will ensure ``state_dict`` is called on all non-FSDP
-            instances, while dispatching into `local_state_dict` implementation
+            the following will ensure ``state_dict`` is called on all non-FSDP
+            instances, while dispatching into `sharded_state_dict` implementation
             for FSDP:
 
         Example::
 
             >>> # xdoctest: +SKIP("undefined variables")
             >>> model = DDP(FSDP(...))
-            >>> with FSDP.state_dict_type(model, StateDictType.LOCAL_STATE_DICT):
-            >>>     checkpoint = model.state_dict()
+            >>> FSDP.set_state_dict_type(
+            >>>     model,
+            >>>     StateDictType.SHARDED_STATE_DICT,
+            >>>     ShardedStateDictConfig(offload_to_cpu=True),
+            >>> )
+            >>> checkpoint = model.state_dict()
 
         Args:
             module (torch.nn.Module): Root module.
             state_dict_type (StateDictType): the desired ``state_dict_type`` to set.
+            state_dict_config (Optional[StateDictConfig]): the configuration for the
+                target ``state_dict_type``.
         """
+        _state_dict_type_to_config = {
+            StateDictType.FULL_STATE_DICT: FullStateDictConfig,
+            StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig,
+            StateDictType.SHARDED_STATE_DICT: ShardedStateDictConfig,
+        }
+
         prev_state_dict_type = None
         prev_state_dict_config = None
-        # Use default config a state_dict config is not set.
+        # Use the default config if a state_dict config is not set.
         if state_dict_config is None:
             state_dict_config = _state_dict_type_to_config[state_dict_type]()
         for submodule in FullyShardedDataParallel.fsdp_modules(module):
@@ -2166,18 +2170,62 @@ def state_dict_type(
             expected_state_dict_config_type = _state_dict_type_to_config[state_dict_type]
             if expected_state_dict_config_type != type(state_dict_config):
                 raise RuntimeError(
-                    f"Expected state_dict_config of type {expected_state_dict_config_type} but got {type(state_dict_config)}"
+                    f"Expected state_dict_config of type {expected_state_dict_config_type} "
+                    f"but got {type(state_dict_config)}"
                 )
             submodule._state_dict_type = state_dict_type
             submodule._state_dict_config = state_dict_config
+
+        return prev_state_dict_type, prev_state_dict_config
+
+    @staticmethod
+    @contextlib.contextmanager
+    def state_dict_type(
+        module: nn.Module,
+        state_dict_type: StateDictType,
+        state_dict_config: Optional[StateDictConfig] = None,
+    ) -> Generator:
+        """
+        A context manager to set the ``state_dict_type`` of all the descendant
+        FSDP modules of the target module. This context manager provides the
+        same functionality as :meth:`set_state_dict_type`; see the
+        documentation of :meth:`set_state_dict_type` for details.
+
+        Example::
+
+            >>> # xdoctest: +SKIP("undefined variables")
+            >>> model = DDP(FSDP(...))
+            >>> with FSDP.state_dict_type(
+            >>>     model,
+            >>>     StateDictType.SHARDED_STATE_DICT,
+            >>> ):
+            >>>     checkpoint = model.state_dict()
+
+        Args:
+            module (torch.nn.Module): Root module.
+            state_dict_type (StateDictType): the desired ``state_dict_type`` to set.
+            state_dict_config (Optional[StateDictConfig]): the configuration for the
+                target ``state_dict_type``.
+        """
+        prev_state_dict_type = None
+        prev_state_dict_config = None
         try:
+            prev_state_dict_type, prev_state_dict_config = (
+                FullyShardedDataParallel.set_state_dict_type(
+                    module, state_dict_type, state_dict_config
+                )
+            )
             yield
+        except Exception as e:
+            raise e
+        else:
+            assert prev_state_dict_type is not None
+            assert prev_state_dict_config is not None
         finally:
-            assert prev_state_dict_type is not None  # Avoid mypy warning
-            assert prev_state_dict_config is not None  # Avoid mypy warning
-            for submodule in FullyShardedDataParallel.fsdp_modules(module):
-                submodule._state_dict_type = prev_state_dict_type
-                submodule._state_dict_config = prev_state_dict_config
+            if prev_state_dict_type is not None and prev_state_dict_config is not None:
+                FullyShardedDataParallel.set_state_dict_type(
+                    module, prev_state_dict_type, prev_state_dict_config
+                )
 
     def _convert_to_wrapped_module_name(self, module_name: str) -> str:
         module_name = module_name.replace(f"{FPW_MODULE}.", "")
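
Taken together, the two methods above split the old context manager into a reusable setter plus a thin wrapper: `set_state_dict_type` applies the type and config to every FSDP submodule and returns the previous pair, and `state_dict_type` restores that pair on exit. A hedged sketch of using the non-context-manager form directly (the helper name `save_sharded_checkpoint` and the model wrapping are illustrative, not from this commit):

```python
from typing import Any, Dict

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType


def save_sharded_checkpoint(model: nn.Module) -> Dict[str, Any]:
    # Persistent form: remember the previous settings that the setter
    # returns, then put them back by hand, which is exactly what the
    # context manager automates in its finally block.
    prev_type, prev_config = FSDP.set_state_dict_type(
        model, StateDictType.SHARDED_STATE_DICT
    )
    try:
        return model.state_dict()
    finally:
        FSDP.set_state_dict_type(model, prev_type, prev_config)
```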
@@ -2524,7 +2572,7 @@ def _sharded_state_dict(self, *args: Any, **kwargs: Any) -> Any:
         (e.g., DDP, model parallelism, and single trainer) after a valid
         resharding.
         """
-        with self.set_state_dict_type(StateDictType.SHARDED_STATE_DICT):
+        with self.state_dict_type(StateDictType.SHARDED_STATE_DICT):
             return self.state_dict(self, *args, **kwargs)
 
     def _full_pre_load_state_dict_hook(
@@ -2759,7 +2807,7 @@ def _load_sharded_state_dict(
27592807 """
27602808 Load states from a unflattened, sharded state dictionary.
27612809 """
2762- with self .set_state_dict_type (StateDictType .SHARDED_STATE_DICT ):
2810+ with self .state_dict_type (StateDictType .SHARDED_STATE_DICT ):
27632811 return self .load_state_dict (state_dict , strict )
27642812
27652813 def forward (self , * args : Any , ** kwargs : Any ) -> Any :
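
Both private helpers above now route through the public context manager. A minimal end-to-end sketch of the same save/load round trip (`build_model()` is a hypothetical factory, and an already-initialized process group is assumed):

```python
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType

model = FSDP(build_model())  # build_model() is a placeholder, not from this commit

# Save: each FSDP submodule contributes its local shards.
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    checkpoint = model.state_dict()

# Load: the same context dispatches load_state_dict to the sharded path.
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    model.load_state_dict(checkpoint)
```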