[FSDP][optim_state_dict][8/N] Enable fully_shard optim state_dict save and load #91234
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/91234
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit de7f826. This comment was automatically generated by Dr. CI and updates every 15 minutes.
rohan-varma
left a comment
LGTM! Super exciting stuff :D
    FSDPInitMode,
    FSDPTest,
    TransformerWithSharedParams,
)
nit: can we have the formatting changes in a separate PR?
I recognize this is tricky, and I think it's time to align on a formatting convention for the FSDP codebase and automate it. cc @awgu
My plan was to just get everyone on lintrunner and lintrunner f at the beginning of next half. I decided that since we are cranking out PRs with urgency right now, we can just not worry about it. The PR to achieve this looks like #90873. I have re-pushed recently, but the main change is just in the .lintrunner.toml file and making sure all relevant files are compliant.
I do think that unifying under lintrunner / lintrunner f is nice. Sometimes I add changes to a file that create long lines or add imports, and I want to just auto-format. However, without an agreed-upon auto-formatter, this becomes a problem and actually complicates the workflow.
Will rebase this PR on top of #91255.
        return 2

    @skip_if_lt_x_gpu(2)
    def _test_optim_state_dict_save_load(self):
Might be better to just have the test enabled instead of disabling it with the underscore prefix, and to add a skip decorator mentioning the reason it is disabled and filing a tracking issue. A sketch of what that could look like is below.
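A minimal sketch of that suggestion, assuming plain `unittest.skip` for the reason string; in the actual FSDP test suite this would sit on an `FSDPTest` subclass alongside the existing `@skip_if_lt_x_gpu(2)` decorator, and the issue reference is a placeholder:

```python
import unittest


class TestFullyShardOptimStateDict(unittest.TestCase):
    @unittest.skip(
        "Disabled until the all_gather_object issue preventing CI runs is "
        "fixed; see the tracking issue"
    )
    def test_optim_state_dict_save_load(self) -> None:
        # Test body would exercise optim state_dict save/load for fully_shard.
        ...


if __name__ == "__main__":
    unittest.main()
```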
):
    _insert_module_state(submodule, state)
# Insert all comm_modules to the module to state mapping.
for submodule in state._fully_sharded_module_to_handles.keys():
Is this change equivalent to the former code? If not, is there a reason we're changing the inserted states?
This is not equivalent to the former code. The reason behind the change is to only map the modules that actually have the handles -- the local root modules.
mapping between parameters and parameter IDs. Using ``optim_input`` is being
deprecated.
If the optimizer is a ``NamedOptimizer``, the optimizer state_dict does not
What if `optim_input` is provided but the optimizer is also a `NamedOptimizer`? Will that create an issue?
Yes, it will fail. Will add error handling for this; a sketch of the kind of guard is below.
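A minimal sketch of the kind of guard described here; the function name and the string-based type check are illustrative, not the actual code added to `_optim_utils.py`:

```python
from typing import Any, Optional


def _check_optim_input_conflict(optim: Any, optim_input: Optional[Any]) -> None:
    # NamedOptimizer keys its state_dict by parameter FQN, while the deprecated
    # ``optim_input`` implies the legacy parameter-ID mapping, so accepting
    # both at once would be ambiguous.
    if optim_input is not None and type(optim).__name__ == "NamedOptimizer":
        raise RuntimeError(
            "``optim_input`` is deprecated and cannot be used together with "
            "a ``NamedOptimizer``."
        )
```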
    composable_optim_state_dict["param_groups"],
):
    for key, value in group1.items():
        self.assertEqual(value, group2[key])
Is it worth adding tests for:
- non-root FSDP
- DDP / replicate root
- nested FSDP + non-root?
Added an extra test for non-root FSDP. Will add more tests after fixing the all_gather_object issue that prevents us from running tests on CI.
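For reference, a self-contained sketch of the `param_groups` comparison shown in the hunk above, using plain `unittest` and illustrative stand-in dictionaries rather than real FSDP/composable optimizer state_dicts:

```python
import unittest


class ParamGroupComparisonTest(unittest.TestCase):
    def test_param_groups_match(self) -> None:
        # Stand-in state_dicts; in the real test these come from the wrapper
        # FSDP model and the fully_shard (composable) model respectively.
        fsdp_optim_state_dict = {"param_groups": [{"lr": 1e-3, "params": [0, 1]}]}
        composable_optim_state_dict = {"param_groups": [{"lr": 1e-3, "params": [0, 1]}]}
        for group1, group2 in zip(
            fsdp_optim_state_dict["param_groups"],
            composable_optim_state_dict["param_groups"],
        ):
            for key, value in group1.items():
                self.assertEqual(value, group2[key])


if __name__ == "__main__":
    unittest.main()
```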
… folders (#91255) This PR applies ufmt to format `_composable`-related code. This is a request from #91234 to separate formatting changes into a new PR. Pull Request resolved: #91255. Approved by: https://github.com/awgu
@pytorchbot merge
Merge started: Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 2 additional jobs have failed; the first few of them are: trunk, trunk / linux-focal-rocm5.3-py3.8 / test (default, 2, 2, linux.rocm.gpu). Details for Dev Infra team: raised by workflow job.
@pytorchbot merge
Merge started: Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
What does this PR do?
This PR refactors `_optim_utils.py` to use `_FSDPState` instead of the `FullyShardedDataParallel` class. This change enables support of optim state_dict for `fully_shard`.
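For context, a minimal sketch of how the resulting feature is meant to be used; this is not code from the PR. It assumes a distributed process group is already initialized (e.g. via torchrun), that `fully_shard` is importable from `torch.distributed._composable`, and that the static `FSDP.optim_state_dict` / `FSDP.optim_state_dict_to_load` helpers are available (the exact argument order of the latter has varied across PyTorch releases):

```python
import torch
import torch.nn as nn
from torch.distributed._composable import fully_shard
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def save_and_reload_optim_state() -> None:
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
    fully_shard(model)  # composable FSDP: shards parameters in place, no wrapper class
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Run one step so the optimizer has state to save.
    model(torch.randn(2, 8)).sum().backward()
    optim.step()

    # Gather the sharded optimizer state into a consolidated state_dict
    # keyed by fully qualified parameter names.
    osd = FSDP.optim_state_dict(model, optim)

    # Convert the saved state back into a loadable form for a fresh optimizer.
    new_optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    loadable_osd = FSDP.optim_state_dict_to_load(model, new_optim, osd)
    new_optim.load_state_dict(loadable_osd)
```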