[FSDP] Add re-key btw param names/IDs for optim state dict #74912
Conversation
[ghstack-poisoned]
💊 CI failures summary and remediations: As of commit d9b02e1 (more details on the Dr. CI page): 💚 Looks good so far! There are no failures yet. 💚 (This comment was automatically generated by Dr. CI. Please report bugs/suggestions to the internal Dr. CI Users group.)
- def _get_flat_param_id_to_param(
+ def _get_param_id_to_param(
Renaming this because it actually maps parameter ID to parameter in both the flattened and unflattened cases, as long as the keys and values are consistent.
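As a rough illustration of the comment above, a minimal sketch of what such a renamed helper could look like (the signature and the positional-ID assumption are illustrative, not the actual FSDP implementation):

```python
import torch.nn as nn

def _get_param_id_to_param(model: nn.Module) -> list:
    # Hypothetical sketch: assumes the optimizer was built over
    # model.parameters() in registration order, so parameter ID i is
    # simply the i-th parameter. The same mapping works whether the
    # parameters are flattened FSDP parameters or ordinary unflattened
    # ones, as long as keys and values stay consistent.
    return list(model.parameters())

model = nn.Linear(4, 2)
id_to_param = _get_param_id_to_param(model)
```

Indexing `id_to_param` with an optimizer's integer parameter ID then recovers the corresponding parameter tensor.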
    }
    return sharded_optim_state_dict

    @staticmethod
This functionality is actually not unique to FSDP. It could be used generally for PyTorch. However, nowhere else have we seen keying by parameter name, so I have put this inside fully_sharded_data_parallel.py for now.
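To illustrate why this is not FSDP-specific, here is a minimal sketch of rekeying an optimizer state dict from parameter IDs to parameter names using only `nn.Module` naming; the helper name and exact dict handling are assumptions for illustration, not PyTorch's implementation:

```python
import torch
import torch.nn as nn

def rekey_ids_to_names(osd: dict, model: nn.Module) -> dict:
    # Hypothetical sketch: assumes the optimizer covers
    # model.parameters() in registration order, so positional ID i
    # corresponds to the i-th named parameter. Nothing here depends
    # on FSDP; only nn.Module's naming is used.
    id_to_name = {i: name for i, (name, _) in enumerate(model.named_parameters())}
    return {
        "state": {id_to_name[pid]: s for pid, s in osd["state"].items()},
        "param_groups": [
            {**g, "params": [id_to_name[pid] for pid in g["params"]]}
            for g in osd["param_groups"]
        ],
    }

model = nn.Linear(3, 1)
optim = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
loss = model(torch.randn(2, 3)).sum()
loss.backward()
optim.step()  # populates per-parameter state (momentum buffers)
named_osd = rekey_ids_to_names(optim.state_dict(), model)
```

After rekeying, the state is keyed by `"weight"` and `"bias"` rather than by the positional IDs `0` and `1`.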
sounds good!
        state).
        """
    non_none_tensors = [t for t in pos_dim_tensors if t is not None]
    # Check that all are tensors on CPU with the same dtype
Just removing this check, which is overly strict. We can just move the tensors to CPU inside the function to avoid device mismatch. It is not actually semantically important which device the tensors in the optimizer state are on.
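A minimal sketch of the suggested alternative, moving tensors to CPU inside the function instead of asserting their device (the function name is an assumption; this is not the actual FSDP code):

```python
import torch

def _flatten_optim_state_tensors(pos_dim_tensors):
    # Hypothetical sketch: rather than checking that every tensor is
    # already on CPU with a matching dtype, normalize to CPU here. The
    # device an optimizer state tensor lives on is not semantically
    # important, so this avoids spurious device-mismatch failures.
    non_none_tensors = [t for t in pos_dim_tensors if t is not None]
    cpu_tensors = [t.detach().cpu() for t in non_none_tensors]
    if not cpu_tensors:
        return torch.tensor([])
    return torch.cat([t.reshape(-1) for t in cpu_tensors])

flat = _flatten_optim_state_tensors([torch.ones(2, 2), None, torch.zeros(3)])
```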
zhaojuanmao left a comment:
looks great!
    SHARDED_STATE_DICT = auto()


class OptimStateKeyType(Enum):
nit: let's export it in the fsdp `__init__.py` file as well?
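For context, a self-contained sketch of the enum under discussion; the member names follow the PR description, but the actual definition in `fully_sharded_data_parallel.py` may differ, and the re-export suggested above would be a one-line import in the package `__init__.py`:

```python
from enum import Enum, auto

class OptimStateKeyType(Enum):
    # Sketch only: member names taken from the PR description's usage
    # (OptimStateKeyType.PARAM_NAME / OptimStateKeyType.PARAM_ID).
    PARAM_NAME = auto()
    PARAM_ID = auto()

# The nit suggests making this importable from the package root, e.g.
#   from torch.distributed.fsdp import OptimStateKeyType
key_type = OptimStateKeyType.PARAM_NAME
```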
@awgu has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
**Overview**

This introduces a new static method `FSDP.rekey_optim_state_dict()` as a utility for interoperating between local/DDP (non-wrapped) models and FSDP (wrapped) models.

To load from a wrapped model to a non-wrapped model:

```python
wrapped_model, wrapped_optim = ...
full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim)
nonwrapped_model, nonwrapped_optim = ...
rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model)
nonwrapped_optim.load_state_dict(rekeyed_osd)
```

To load from a non-wrapped model to a wrapped model:

```python
nonwrapped_model, nonwrapped_optim = ...
osd = nonwrapped_optim.state_dict()
rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model)
wrapped_model, wrapped_optim = ...
sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model)
wrapped_optim.load_state_dict(sharded_osd)
```

**Test Plan**

`test_rekey_optim_state_dict_to_ids()` and `test_rekey_optim_state_dict_to_names()`.

Differential Revision: [D35225819](https://our.internmc.facebook.com/intern/diff/D35225819)

[ghstack-poisoned]
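To make the name-to-ID rekeying direction concrete, here is a minimal sketch under the assumption that the target optimizer is built over `model.parameters()` in registration order; the helper and the toy state dict are illustrative, not the actual FSDP implementation:

```python
import torch.nn as nn

def rekey_names_to_ids(osd: dict, model: nn.Module) -> dict:
    # Hypothetical sketch of the OptimStateKeyType.PARAM_ID direction:
    # map parameter names back to positional parameter IDs so a plain
    # optimizer built over model.parameters() can load the state.
    name_to_id = {name: i for i, (name, _) in enumerate(model.named_parameters())}
    return {
        "state": {name_to_id[n]: s for n, s in osd["state"].items()},
        "param_groups": [
            {**g, "params": sorted(name_to_id[n] for n in g["params"])}
            for g in osd["param_groups"]
        ],
    }

model = nn.Linear(3, 1)
# Toy name-keyed optimizer state dict for illustration.
name_keyed = {
    "state": {"weight": {"step": 1}, "bias": {"step": 1}},
    "param_groups": [{"lr": 0.1, "params": ["weight", "bias"]}],
}
id_keyed = rekey_names_to_ids(name_keyed, model)
```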
Summary: Pull Request resolved: #74912. Test Plan: Imported from OSS. Reviewed By: ngimel. Differential Revision: D35225819. Pulled By: awgu. fbshipit-source-id: fbbdbde8b595a9c65b17a9aecb4f22b2c9761a23