
Commit ba23ecb

[FSDP][optim_state_dict][10/N] Make optim_state_dict and optim_state_dict_to_load public
Make optim_state_dict and optim_state_dict_to_load public APIs and consolidate them with state_dict by using the same state_dict_type to decide how to perform the optimizer state_dict save and load.

Differential Revision: [D42488022](https://our.internmc.facebook.com/intern/diff/D42488022/)
ghstack-source-id: 177584342
Pull Request resolved: #92118
1 parent 949f25b commit ba23ecb
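
To make the consolidated flow concrete, here is a minimal usage sketch, not taken from this diff: MyModel is a hypothetical module, the optim_state_dict(model, optim) and optim_state_dict_to_load(osd, model, optim) argument orders are assumed from the description and may differ from the final signatures, and passing the optimizer config through set_state_dict_type is assumed to already be wired up at this point in the stack.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp.api import FullOptimStateDictConfig, FullStateDictConfig

# Assumes the process group is already initialized and MyModel is defined elsewhere.
model = FSDP(MyModel().cuda())
optim = torch.optim.Adam(model.parameters(), lr=1e-3)

# One state_dict_type now governs both the model and the optimizer state_dict.
FSDP.set_state_dict_type(
    model,
    StateDictType.FULL_STATE_DICT,
    FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    FullOptimStateDictConfig(rank0_only=True),
)

# Save: the public API that replaces the former private FSDP._optim_state_dict.
osd = FSDP.optim_state_dict(model, optim)

# Load: convert the saved optimizer state to match the current sharding,
# then hand it to the optimizer as usual.
flattened_osd = FSDP.optim_state_dict_to_load(osd, model, optim)
optim.load_state_dict(flattened_osd)

The point of the consolidation is that the single set_state_dict_type call (or the defaults) decides whether optim_state_dict produces a full, local, or sharded optimizer state_dict, mirroring how the model-side state_dict already works.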

File tree

5 files changed: +161 -61 lines changed


test/distributed/fsdp/test_fsdp_optim_state.py

Lines changed: 2 additions & 3 deletions
@@ -783,9 +783,8 @@ def test_flatten_sharded_optim_state_dict_transformer(self) -> None:
             num_iters=3,
         )

-    @unittest.skip("The test currently fails on CI.")
     @skip_if_lt_x_gpu(2)
-    def test_use_orig_params(self) -> None:
+    def ftest_use_orig_params(self) -> None:
         """Tests :meth:`optim_state_dict` for an FSDP-root nested model."""
         self._test_load_optim_state(
             _ModelClass.NESTED,
@@ -824,7 +823,7 @@ def _test_load_optim_state(
         """
         initializer = self._model_class[model_class]
         if osd_comm_method == _OSDCommMethod.OPTIM_STATE_DICT:
-            osd_method = FSDP._optim_state_dict
+            osd_method = FSDP.optim_state_dict
         elif osd_comm_method == _OSDCommMethod.FLATTEN_SHARDED_OSD:
             osd_method = FSDP.sharded_optim_state_dict
         else:

torch/distributed/fsdp/_common_utils.py

Lines changed: 2 additions & 1 deletion
@@ -24,7 +24,7 @@
     _CHECKPOINT_PREFIX,
 )

-from .api import FullStateDictConfig, ShardingStrategy, StateDictConfig, StateDictType
+from .api import FullStateDictConfig, ShardingStrategy, StateDictConfig, StateDictType, OptimStateDictConfig, FullOptimStateDictConfig

 FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module"
 FSDP_PREFIX = FSDP_WRAPPED_MODULE + "."
@@ -39,6 +39,7 @@ def __init__(self) -> None:
         self._unshard_params_ctx: Dict[nn.Module, Generator] = {}
         self._state_dict_type: StateDictType = StateDictType.FULL_STATE_DICT
         self._state_dict_config: StateDictConfig = FullStateDictConfig()
+        self._optim_state_dict_config: OptimStateDictConfig = FullOptimStateDictConfig()
         self._is_root: Optional[bool] = None
         self._handles: List[flat_param_file.FlatParamHandle] = []
         self._ignored_modules: Set[nn.Module] = set()

torch/distributed/fsdp/_init_utils.py

Lines changed: 2 additions & 0 deletions
@@ -37,6 +37,7 @@
     BackwardPrefetch,
     CPUOffload,
     FullStateDictConfig,
+    FullOptimStateDictConfig,
     MixedPrecision,
     ShardingStrategy,
     StateDictConfig,
@@ -374,6 +375,7 @@ def _init_prefetching_state(
 def _init_state_dict_state(state: _FSDPState) -> _FSDPState:
     state._state_dict_type = StateDictType.FULL_STATE_DICT
     state_dict_config: StateDictConfig = FullStateDictConfig()
+    state._optim_state_dict_config = FullOptimStateDictConfig()
     state._state_dict_config = state_dict_config
     unshard_params_ctx: Dict[nn.Module, Generator] = {}
     state._unshard_params_ctx = unshard_params_ctx

torch/distributed/fsdp/api.py

Lines changed: 32 additions & 0 deletions
@@ -284,3 +284,35 @@ class LocalStateDictConfig(StateDictConfig):
 @dataclass
 class ShardedStateDictConfig(StateDictConfig):
     pass
+
+
+@dataclass
+class OptimStateDictConfig:
+    """
+    ``OptimStateDictConfig`` is the base class for all optimizer state_dict
+    configuration classes. Users should instantiate a child class
+    (e.g. ``FullOptimStateDictConfig``) in order to configure settings for the
+    particular type of ``optim_state_dict`` implementation FSDP will use.
+    """
+    # TODO: actually use this flag in _optim_utils.py
+    offload_to_cpu: bool = True
+
+
+@dataclass
+class FullOptimStateDictConfig(OptimStateDictConfig):
+    rank0_only: bool = False
+
+@dataclass
+class LocalOptimStateDictConfig(OptimStateDictConfig):
+    offload_to_cpu: bool = False
+
+@dataclass
+class ShardedOptimStateDictConfig(OptimStateDictConfig):
+    pass
+
+
+@dataclass
+class StateDictSettings:
+    state_dict_type: StateDictType
+    state_dict_config: StateDictConfig
+    optim_state_dict_config: OptimStateDictConfig
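
The new config classes are plain dataclasses, so they compose directly with the existing model-side configs. A hedged sketch of constructing them by hand follows; StateDictSettings is just the bundle of the three knobs that now travel together, and exactly where FSDP surfaces it is not shown in this diff.

from torch.distributed.fsdp.api import (
    FullOptimStateDictConfig,
    FullStateDictConfig,
    StateDictSettings,
    StateDictType,
)

# Full (unsharded) optimizer state, gathered to rank 0 and moved to CPU to
# keep GPU memory free during checkpointing.
optim_cfg = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)

# Matching model-side config so model and optimizer checkpoints agree.
model_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)

# Bundle the state_dict_type with both configs.
settings = StateDictSettings(
    state_dict_type=StateDictType.FULL_STATE_DICT,
    state_dict_config=model_cfg,
    optim_state_dict_config=optim_cfg,
)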
