
Conversation

@awgu
Collaborator

@awgu awgu commented Feb 6, 2023

Stack from ghstack:

Overview
This refactors module materialization (i.e. meta device or torchdistX deferred initialization) to compute the parameter and buffer names as needed instead of pre-computing them. The names are needed to reacquire references to the states (e.g. via module.get_parameter(param_name)) after materialization, since materialization may create new variables.

This refactor simplifies _get_fully_sharded_module_to_states() (the core function for "pseudo auto wrapping") to better enable lowest common ancestor (LCA) module computation for shared parameters, for which tracking parameter and buffer names may complicate the already non-obvious implementation.

Discussion
The tradeoff is a worst-case quadratic traversal over modules if materializing all of them. However, since (1) the number of modules is relatively small, (2) the computation per module in the quadratic traversal is negligible, (3) this runs only once per training session, and (4) module materialization targets truly large models, I think this tradeoff is tolerable.

For Reviewers

  • _init_param_handle_from_module() initializes one FlatParamHandle from a fully sharded module and represents the module wrapper code path. For this code path, there is no need to reacquire references to the parameters/buffers for now since the managed parameters are only computed after materialization. This works because the managed parameters have a simple definition: any parameter in the local root module's tree excluding those already marked as flattened by FSDP. Similarly, FSDP marks buffers to indicate that they have already been processed (synced if sync_module_states).
  • _init_param_handles_from_module() initializes all FlatParamHandles from a fully sharded module and represents the composable code path. For this code path, we must reacquire references to parameters/buffers because each logical wrapping is specified by the parameter/buffer variables to group together, and materialization may create new variables (see the sketch below).
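
Roughly, the reacquire-by-name pattern looks like the following (a minimal sketch, not the actual FSDP helper; materialize_and_reacquire and the use of to_empty() are illustrative stand-ins for whatever materialization runs):

import torch
import torch.nn as nn

def materialize_and_reacquire(module: nn.Module, device: torch.device):
    # Record fully qualified names before materialization.
    param_names = [name for name, _ in module.named_parameters()]
    buffer_names = [name for name, _ in module.named_buffers()]
    # Materialization may rebind parameters/buffers to new Python variables,
    # so references held before this point can go stale.
    module.to_empty(device=device)
    # Reacquire live references by name.
    params = [module.get_parameter(name) for name in param_names]
    buffers = [module.get_buffer(name) for name in buffer_names]
    return params, buffers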

@pytorch-bot pytorch-bot bot added the release notes: distributed (fsdp) release notes category label Feb 6, 2023
@pytorch-bot

pytorch-bot bot commented Feb 6, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/94196

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 221ef73:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

awgu pushed a commit to awgu/pytorch that referenced this pull request Feb 6, 2023
ghstack-source-id: d064a75
Pull Request resolved: pytorch#94196
@awgu awgu added the topic: not user facing topic category label Feb 6, 2023
@awgu awgu marked this pull request as ready for review February 6, 2023 21:37
Contributor

@rohan-varma rohan-varma left a comment

LGTM thanks!

if is_meta_module or is_torchdistX_deferred_init:
    materialized_module = True
    # Save the parameter and buffer names to reacquire references
    # after materialization since their variables may change
Contributor

Even after reading the PR description, I'm not 100% sure why the variables may change after materialization? I thought materialization is all about filling in meta parameters with their actual values?

Collaborator Author

Yes, but at the implementation level, module.to(device) will replace all meta-device parameters with new Python parameter variables.

This function returns False for meta-device tensors:

def compute_should_use_set_data(tensor, tensor_applied):
    if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):

This leads to the else branch:
should_use_set_data = compute_should_use_set_data(param, param_applied)
if should_use_set_data:
    param.data = param_applied
    out_param = param
else:
    assert isinstance(param, Parameter)
    assert param.is_leaf
    out_param = Parameter(param_applied, param.requires_grad)
    self._parameters[key] = out_param
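
For a concrete illustration of the identity change (a minimal sketch, assuming a recent PyTorch with meta-device support):

import torch
import torch.nn as nn

lin = nn.Linear(4, 4, device="meta")
stale = lin.weight  # reference taken before materialization

# Materializing the meta-device module takes the else branch above and
# constructs a new Parameter, so the old reference no longer points at
# the module's parameter.
lin.to_empty(device="cpu")
print(stale is lin.weight)  # False: a new Parameter variable was created
print(lin.get_parameter("weight") is lin.weight)  # True: reacquired by name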

elif is_meta_module:
    _materialize_meta_module(fully_sharded_module, device_id)
elif is_torchdistX_deferred_init:
    deferred_init.materialize_module(
Contributor

do we have unittests covering deferred init for the composable path?

Collaborator Author

We do not because torchdistX does not support the latest PyTorch version, so I could not run torchdistX locally.
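
For context, the torchdistX deferred-init flow exercised on this path looks roughly like the following (a hedged sketch assuming torchdistx's deferred_init/fake/materialize_module API; it is not the missing composable-path unit test itself):

import torch.nn as nn
from torchdistx import deferred_init, fake

# Build the module without allocating real storage; parameters start out "fake".
module = deferred_init.deferred_init(nn.Linear, 1024, 1024)
print(fake.is_fake(module.weight))  # expected: True

# FSDP's deferred-init path calls materialize_module() internally, which
# replaces the fake parameters with real ones (again, new variables).
deferred_init.materialize_module(module)
print(fake.is_fake(module.weight))  # expected: False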

[ghstack-poisoned]
awgu pushed a commit to awgu/pytorch that referenced this pull request Feb 13, 2023
ghstack-source-id: 3563d70
Pull Request resolved: pytorch#94196
@awgu awgu requested a review from fegin as a code owner February 13, 2023 17:15
@awgu awgu added the ciflow/trunk Trigger trunk jobs on your pull request label Feb 13, 2023
@awgu
Collaborator Author

awgu commented Feb 13, 2023

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@facebook-github-bot facebook-github-bot deleted the gh/awgu/322/head branch June 8, 2023 15:35