Update on "Test FSDP with submodule non-reentrant checkpointing"
When combining FSDP with reentrant checkpointing, the post-backward
hook might run twice and then hit [this
error](https://github.com/pytorch/pytorch/blob/e20ec44544c17d6d3d411f88b870e05043bda731/torch/distributed/fsdp/_runtime_utils.py#L487).
This is because reentrant backward uses nested autograd GraphTasks.
The inner GraphTask is not aware of the outer one and therefore
flushes pending `AccumulateGrad` invocations on exit, which in
turn triggers the post-backward hooks registered by FSDP. Later,
the outer GraphTask triggers those hooks again, leading to the above
error.
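
For context, the problematic combination looks roughly like the sketch below: an FSDP-wrapped model whose submodule calls `torch.utils.checkpoint.checkpoint` with `use_reentrant=True`, so the backward pass re-runs the checkpointed forward inside a nested GraphTask. The module structure, tensor shapes, and single-rank process-group setup here are illustrative assumptions, not taken from the actual test.

```python
# Minimal sketch (illustrative only): FSDP around a model whose submodule uses
# *reentrant* activation checkpointing. Assumes a single CUDA device.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.checkpoint import checkpoint

# Single-rank process group for illustration.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("nccl", rank=0, world_size=1)


class CheckpointedBlock(nn.Module):
    def __init__(self, dim: int = 16):
        super().__init__()
        self.lin1 = nn.Linear(dim, dim)
        self.lin2 = nn.Linear(dim, dim)

    def forward(self, x):
        # use_reentrant=True re-runs this forward during backward inside a
        # nested autograd GraphTask, which is what can fire FSDP's
        # post-backward hook early.
        return checkpoint(
            lambda t: self.lin2(torch.relu(self.lin1(t))), x, use_reentrant=True
        )


class ToyModel(nn.Module):
    def __init__(self, dim: int = 16):
        super().__init__()
        self.block = CheckpointedBlock(dim)
        self.head = nn.Linear(dim, dim)

    def forward(self, x):
        return self.head(self.block(x))


model = FSDP(ToyModel().cuda())
# Reentrant checkpointing needs at least one input that requires grad,
# otherwise the checkpointed region is skipped during backward.
x = torch.randn(4, 16, device="cuda", requires_grad=True)
model(x).sum().backward()  # may trip the FSDP training-state assertion above
```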
PR #89791 relaxes the FSDP training-state check, but we still occasionally
run into gradient value check failures. Therefore, this PR lands only the
non-reentrant checkpointing test; the reentrant test can be enabled once
the accuracy issues are addressed.
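
The shape of the non-reentrant test is roughly the sketch below: wrap one submodule with `checkpoint_wrapper(..., checkpoint_impl=CheckpointImpl.NO_REENTRANT)` and check that gradients match an FSDP model without checkpointing. The toy model, shapes, and comparison are illustrative assumptions; the actual test is in this PR.

```python
# Illustrative sketch of an FSDP + non-reentrant submodule checkpointing test.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    checkpoint_wrapper,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Single-rank process group for illustration (assumes one CUDA device).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
if not dist.is_initialized():
    dist.init_process_group("nccl", rank=0, world_size=1)


def build(checkpoint_submodule: bool) -> nn.Module:
    model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
    if checkpoint_submodule:
        # Wrap only the first submodule with *non-reentrant* checkpointing.
        model[0] = checkpoint_wrapper(
            model[0], checkpoint_impl=CheckpointImpl.NO_REENTRANT
        )
    return model


# Same seed so both models start from identical parameters.
torch.manual_seed(0)
ref = FSDP(build(checkpoint_submodule=False).cuda())
torch.manual_seed(0)
ckpt = FSDP(build(checkpoint_submodule=True).cuda())

x = torch.randn(4, 16, device="cuda")
ref(x).sum().backward()
ckpt(x).sum().backward()

# Gradients with and without submodule checkpointing should agree.
for p_ref, p_ckpt in zip(ref.parameters(), ckpt.parameters()):
    assert torch.allclose(p_ref.grad, p_ckpt.grad)
```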
[ghstack-poisoned]