Test FSDP with submodule non-reentrant checkpointing #89781
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/89781. Note: links to docs will display an error until the docs builds have completed.
✅ No failures as of commit 73c7752. This comment was automatically generated by Dr. CI and updates every 15 minutes.
awgu left a comment:
Thanks for adding these tests!
At a meta level, I am wondering how we should approach testing more interleavings in a systematic and complete way.
Also, I am not sure if you want to wait for the assert relaxation PR to land and then update test_checkpoint_submodule_reentrant() or not.
rohan-varma left a comment:
LGTM, stamping to unblock. Will file a follow-up issue to debug why this FSDP + AC structure does not work.
Updated the PR summary to include that. Due to the grad value issue, the new tests do not exercise that code path at the moment.
@pytorchbot merge -g
Merge started: your change will be merged once all checks on your PR pass since you used the green (-g) flag (ETA: 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
When combining FSDP with reentrant checkpointing, the post-backward hook might run twice and then hit [this error](https://github.com/pytorch/pytorch/blob/e20ec44544c17d6d3d411f88b870e05043bda731/torch/distributed/fsdp/_runtime_utils.py#L487). This happens because reentrant backward uses nested autograd GraphTasks: the inner GraphTask is not aware of the outer one, so it flushes pending `AccumulateGrad` invocations on exit, which in turn triggers the post-backward hooks registered by FSDP. Later, the outer GraphTask triggers them again, leading to the above error.
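To make the distinction concrete, here is a minimal sketch (illustrative only, not code from this PR or its tests) contrasting the two checkpoint flavors with `torch.utils.checkpoint`; the layer and tensor shapes are arbitrary assumptions:

```python
import torch
from torch.utils.checkpoint import checkpoint

lin = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)

# Reentrant flavor: backward through the checkpointed segment recomputes the
# forward and then runs a nested autograd engine call (a separate GraphTask).
# Per the description above, that inner GraphTask flushes pending
# AccumulateGrad nodes on exit, which is what fires FSDP's post-backward
# hooks a second time once the outer GraphTask finishes.
y_reentrant = checkpoint(lin, x, use_reentrant=True)
y_reentrant.sum().backward()

lin.zero_grad()

# Non-reentrant flavor: recomputation is driven from within the original
# GraphTask (via saved-tensor hooks), so there is no nested engine call and
# the post-backward hooks fire only once.
y_non_reentrant = checkpoint(lin, x, use_reentrant=False)
y_non_reentrant.sum().backward()
```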
PR #89791 relaxes the FSDP training state check, but we still occasionally run into grad value check failures. Therefore, this PR only lands the non-reentrant test; the reentrant test can be enabled once the accuracy issues are addressed.
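For reference, a minimal sketch of the pattern the new test exercises: non-reentrant activation checkpointing applied to a submodule that is then wrapped with FSDP. This is an assumed toy setup, not the actual test code from this PR; the model, sizes, and wrapping granularity are illustrative.

```python
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    checkpoint_wrapper,
)


class ToyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.seq = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
        self.head = nn.Linear(8, 8)

    def forward(self, x):
        return self.head(self.seq(x))


def build_model() -> FSDP:
    model = ToyModel()
    # Checkpoint only the submodule, using the non-reentrant implementation so
    # that backward stays within a single GraphTask and FSDP's post-backward
    # hooks run exactly once.
    model.seq = checkpoint_wrapper(
        model.seq, checkpoint_impl=CheckpointImpl.NO_REENTRANT
    )
    # Wrap the whole model with FSDP; this assumes torch.distributed has
    # already been initialized (e.g. via torchrun).
    return FSDP(model)
```

Calling `build_model()` requires an initialized process group, e.g. `torch.distributed.init_process_group("nccl")` under `torchrun`.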