[FSDP][1/N] Refactor module materialization #94196
Conversation
[ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/94196
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 221ef73.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: d064a75 Pull Request resolved: pytorch#94196
rohan-varma left a comment:
LGTM thanks!
```python
if is_meta_module or is_torchdistX_deferred_init:
    materialized_module = True
    # Save the parameter and buffer names to reacquire references after
    # materialization since their variables may change
```
Even after reading the PR description, I'm not 100% sure why the variables may change after materialization? I thought materialization is all about filling in meta parameters with their actual values?
Yes, but at the implementation level, `module.to(device)` will replace all meta-device parameters with new Python parameter variables.
The following function returns `False` for meta-device tensors:
pytorch/torch/nn/modules/module.py, lines 799 to 800 (at commit a064ce1):

```python
def compute_should_use_set_data(tensor, tensor_applied):
    if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
```
This leads to the `else` branch:

pytorch/torch/nn/modules/module.py, lines 821 to 829 (at commit a064ce1):
```python
should_use_set_data = compute_should_use_set_data(param, param_applied)
if should_use_set_data:
    param.data = param_applied
    out_param = param
else:
    assert isinstance(param, Parameter)
    assert param.is_leaf
    out_param = Parameter(param_applied, param.requires_grad)
    self._parameters[key] = out_param
```
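To make this concrete, here is a minimal sketch (not from the PR) showing that the old `Parameter` reference goes stale after materialization. It assumes `to_empty()` as the materialization step, since `.to()` cannot copy out of a meta tensor:

```python
import torch.nn as nn

lin = nn.Linear(4, 4, device="meta")
weight_before = lin.weight            # reference to the meta-device Parameter

lin.to_empty(device="cpu")            # materialize: allocate real (uninitialized) storage

print(weight_before is lin.weight)    # False: a new Parameter variable was created
print(weight_before.is_meta)          # True: the old reference still points at the meta tensor
```

This is why the parameter and buffer names are recorded and the references are reacquired after materialization.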
```python
elif is_meta_module:
    _materialize_meta_module(fully_sharded_module, device_id)
elif is_torchdistX_deferred_init:
    deferred_init.materialize_module(
```
do we have unittests covering deferred init for the composable path?
We do not because torchdistX does not support the latest PyTorch version, so I could not run torchdistX locally.
**Overview**

This refactors module materialization (i.e. meta device or `torchdistX` deferred initialization) to compute the parameter and buffer names as needed instead of pre-computing them. These are needed to reacquire references to the states (e.g. `module.get_parameter(param_name)`) after materialization since the materialization may create new variables.

This refactor simplifies `_get_fully_sharded_module_to_states()` (the core function for "pseudo auto wrapping") to better enable lowest common ancestor (LCA) module computation for shared parameters, for which tracking parameter and buffer names may complicate the already non-obvious implementation.

**Discussion**

The tradeoff is a worst case quadratic traversal over modules if materializing all of them. However, since (1) the number of modules is relatively small, (2) the computation per module in the quadratic traversal is negligible, (3) this runs only once per training session, and (4) module materialization targets truly large models, I think this tradeoff is tolerable.

**For Reviewers**

- `_init_param_handle_from_module()` initializes _one_ `FlatParamHandle` from a fully sharded module and represents the module wrapper code path. For this code path, there is no need to reacquire references to the parameters/buffers for now since the managed parameters are only computed after materialization. This works because the managed parameters have a simple definition: any parameter in the local root module's tree excluding those already marked as flattened by FSDP. Similarly, FSDP marks buffers to indicate that they have already been processed (synced if `sync_module_states`).
- `_init_param_handles_from_module()` initializes _all_ `FlatParamHandle`s from a fully sharded module and represents the composable code path. For this code path, we must reacquire references to parameters/buffers because each logical wrapping is specified as a list of parameters/buffers to group together by those variables and because materialization may create new variables.

[ghstack-poisoned]
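To illustrate the "compute names as needed, then reacquire by name" pattern described above, here is a minimal sketch. `materialize_and_reacquire` and the use of `to_empty()` are illustrative stand-ins, not the actual FSDP helpers:

```python
import torch
import torch.nn as nn

def materialize_and_reacquire(module: nn.Module, device: torch.device):
    # Record fully qualified names before materialization; the Parameter and
    # buffer objects may be replaced, but their names stay stable.
    param_names = [name for name, _ in module.named_parameters()]
    buffer_names = [name for name, _ in module.named_buffers()]

    # Stand-in materialization step for a meta-device module (the actual FSDP
    # logic also handles torchdistX deferred init and sync_module_states).
    module.to_empty(device=device)

    # Reacquire references by name, since the pre-materialization references
    # may now point at stale meta-device tensors.
    params = [module.get_parameter(name) for name in param_names]
    buffers = [module.get_buffer(name) for name in buffer_names]
    return params, buffers

# Example usage with a meta-device module:
params, buffers = materialize_and_reacquire(nn.Linear(8, 8, device="meta"), torch.device("cpu"))
```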
ghstack-source-id: 3563d70 Pull Request resolved: pytorch#94196
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack:

- #94198 [FSDP][3/N] Add LCA logic to `fully_shard`