
Conversation

@awgu (Collaborator) commented Mar 29, 2022

Stack from ghstack:

**Overview**
This introduces a new static method `FSDP.rekey_optim_state_dict()` as a utility for interoperating between local/DDP (non-wrapped) models and FSDP (wrapped) models.

To load from a wrapped model to a non-wrapped model:
```
wrapped_model, wrapped_optim = ...
full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim)
nonwrapped_model, nonwrapped_optim = ...
rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model)
nonwrapped_optim.load_state_dict(rekeyed_osd)
```

To load from a non-wrapped model to a wrapped model:
```
nonwrapped_model, nonwrapped_optim = ...
osd = nonwrapped_optim.state_dict()
rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model)
wrapped_model, wrapped_optim = ...
sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model)
wrapped_optim.load_state_dict(sharded_osd)
```
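
For reference, a self-contained sketch of the second flow as a runnable script; the toy model, optimizer choice, and process-group setup are illustrative assumptions, not part of this PR (which defines `OptimStateKeyType` in `fully_sharded_data_parallel.py`):

```python
# Illustrative sketch only: the toy model, optimizer, and setup are assumptions.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import OptimStateKeyType

# Assumes launch via torchrun with one GPU per process.
dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

nonwrapped_model = nn.Linear(8, 8).cuda()
nonwrapped_optim = torch.optim.Adam(nonwrapped_model.parameters(), lr=1e-3)
nonwrapped_model(torch.randn(4, 8, device="cuda")).sum().backward()
nonwrapped_optim.step()  # populate some optimizer state to migrate

# Rekey the ID-keyed local optimizer state dict to parameter names ...
osd = nonwrapped_optim.state_dict()
rekeyed_osd = FSDP.rekey_optim_state_dict(
    osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model
)

# ... then shard the name-keyed state dict for the FSDP-wrapped model.
wrapped_model = FSDP(nn.Linear(8, 8).cuda())
wrapped_optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model)
wrapped_optim.load_state_dict(sharded_osd)
```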

**Test Plan**
`test_rekey_optim_state_dict_to_ids()` and `test_rekey_optim_state_dict_to_names()`.

Differential Revision: [D35225819](https://our.internmc.facebook.com/intern/diff/D35225819)

@facebook-github-bot (Contributor) commented Mar 29, 2022

💊 CI failures summary and remediations

As of commit d9b02e1 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.


@facebook-github-bot added the oncall: distributed label Mar 29, 2022
@awgu marked this pull request as ready for review March 29, 2022 17:19


```
-def _get_flat_param_id_to_param(
+def _get_param_id_to_param(
```
@awgu (Collaborator, Author):

Renaming this because it actually maps parameter ID to parameter in both the flattened and unflattened cases, as long as the keys and values are consistent.
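
For intuition, here is a minimal sketch of the kind of mapping this helper builds; the standalone function and its name are illustrative, not the PR's actual code:

```python
from typing import Dict
import torch

def param_id_to_param(optim: torch.optim.Optimizer) -> Dict[int, torch.Tensor]:
    # Parameter IDs follow the flattened order of the optimizer's param
    # groups, matching how Optimizer.state_dict() keys its state; this holds
    # whether the params are FSDP's flattened params or the original
    # unflattened ones, as long as keys and values stay consistent.
    params = [p for group in optim.param_groups for p in group["params"]]
    return dict(enumerate(params))
```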

```
}
return sharded_optim_state_dict

@staticmethod
```
@awgu (Collaborator, Author):

This functionality is actually not unique to FSDP; it could be used generally in PyTorch. However, we have not seen keying by parameter name anywhere else, so I have put this inside fully_sharded_data_parallel.py for now.

Contributor:

sounds good!
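
For intuition, a minimal, FSDP-independent sketch of what rekeying from parameter IDs to parameter names involves; the function name is illustrative, and it assumes the optimizer was constructed from `model.parameters()` in order, so this is not the PR's actual implementation:

```python
from typing import Any, Dict
import torch.nn as nn

def rekey_ids_to_names(osd: Dict[str, Any], model: nn.Module) -> Dict[str, Any]:
    # Optimizer.state_dict() keys state by integer parameter ID, where the
    # IDs follow the order params were passed to the optimizer; assuming that
    # order came from model.parameters(), it matches named_parameters().
    names = [name for name, _ in model.named_parameters()]
    return {
        "state": {names[pid]: s for pid, s in osd["state"].items()},
        "param_groups": [
            {**group, "params": [names[pid] for pid in group["params"]]}
            for group in osd["param_groups"]
        ],
    }
```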

```
state).
"""
non_none_tensors = [t for t in pos_dim_tensors if t is not None]
# Check that all are tensors on CPU with the same dtype
```
@awgu (Collaborator, Author):

Just removing this check, which is overly strict. We can move the tensors to CPU inside the function to avoid device mismatches; which device the tensors in the optimizer state are on is not actually semantically important.
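
A minimal sketch of the device-agnostic alternative; the helper name and standalone form are illustrative, not the PR's actual code:

```python
from typing import List, Optional
import torch

def to_cpu_same_dtype(
    pos_dim_tensors: List[Optional[torch.Tensor]],
) -> List[torch.Tensor]:
    # Move tensors to CPU up front so downstream logic never hits a device
    # mismatch; the device the optimizer state lives on is not semantically
    # meaningful here.
    non_none_tensors = [t.cpu() for t in pos_dim_tensors if t is not None]
    # Still require a single dtype, but no longer a particular device.
    dtypes = {t.dtype for t in non_none_tensors}
    assert len(dtypes) <= 1, f"expected a single dtype, got {dtypes}"
    return non_none_tensors
```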

@zhaojuanmao (Contributor) left a comment:

looks great!

```
}
return sharded_optim_state_dict

@staticmethod
```
Contributor:

sounds good!

```
SHARDED_STATE_DICT = auto()


class OptimStateKeyType(Enum):
```
Contributor:

nit: let's export it in the fsdp `__init__.py` file as well?
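
A sketch of the suggested export; the surrounding contents of `torch/distributed/fsdp/__init__.py` are an assumption, not copied from the repo:

```python
# torch/distributed/fsdp/__init__.py (sketch; other exports elided)
from .fully_sharded_data_parallel import (
    FullyShardedDataParallel,
    OptimStateKeyType,
)
```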

@awgu (Collaborator, Author) commented Mar 29, 2022

@awgu has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

awgu pushed a commit that referenced this pull request Mar 29, 2022
@awgu (Collaborator, Author) commented Mar 29, 2022

@awgu has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

facebook-github-bot pushed a commit that referenced this pull request Mar 30, 2022
Summary:
Pull Request resolved: #74912

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D35225819

Pulled By: awgu

fbshipit-source-id: fbbdbde8b595a9c65b17a9aecb4f22b2c9761a23
@github-actions (Contributor):

Hey @awgu.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.
