fix: use grad div factor when fsdp_degree=1 #167178
garrett361 wants to merge 11 commits into pytorch:main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/167178
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit a88ea4f with merge base dc00842:
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Thanks a lot! Sounds right to me. I'll let @weifengpy review / stamp. Yes, a test would be appreciated.
Force-pushed from 7078973 to 0490875
good catch!
triggering CI - waiting for signals
Thanks @weifengpy! Figuring out CLA stuff with work, then will add a test.
@weifengpy still waiting on CLA stuff, but I started looking at where I'd add a test. I'm not seeing an existing test that covers this case. Should I add a test here? Seems like I'd have to add a decent bit of infra code to get a proper test set up.
Adding a unit test would be great! I was mentioning CI to make sure it does not break shard world size 2+, but having a unit test for world size 1 is better. cc @anshul-si for a bug fix in world size 1.
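For illustration, here is a minimal sketch of what a shard-world-size-1 unit test could look like. It is not the test added in this PR; it assumes the internal FSDPTest harness and a set_gradient_divide_factor setter on the sharded module (older versions expose set_reduce_scatter_divide_factor instead, and fully_shard may live under torch.distributed._composable.fsdp depending on the torch version):

```python
# A minimal sketch, not the test added in this PR. Assumes the internal
# FSDPTest harness and the set_gradient_divide_factor setter; older torch
# versions expose set_reduce_scatter_divide_factor instead.
import torch
from torch.distributed.fsdp import fully_shard
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests


class TestGradDivideFactorWorldSize1(FSDPTest):
    @property
    def world_size(self) -> int:
        # Exercise the reduce_scatter_group.size() == 1 path this PR fixes.
        return 1

    @skip_if_lt_x_gpu(1)
    def test_gradient_divide_factor_applied(self):
        torch.manual_seed(0)
        model = torch.nn.Linear(16, 16, device="cuda")
        ref = torch.nn.Linear(16, 16, device="cuda")
        ref.load_state_dict(model.state_dict())

        fully_shard(model)
        factor = 4.0
        # Assumed setter name; see the note above about version differences.
        model.set_gradient_divide_factor(factor)

        x = torch.randn(8, 16, device="cuda")
        model(x).sum().backward()
        ref(x).sum().backward()

        for p, p_ref in zip(model.parameters(), ref.parameters()):
            # With a shard world size of 1, the local shard is the full
            # parameter, so the grads should simply be divided by the factor.
            torch.testing.assert_close(p.grad.to_local(), p_ref.grad / factor)


if __name__ == "__main__":
    run_tests()
```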
Signed-off-by: Garrett Goon <goon@ibm.com>
Force-pushed from 337f84f to 8862e11
I found some other issues in the code and tests related to this topic.
CC @anshul-si @weifengpy @tianyu-l. Tested locally that everything in the affected test file passes.
@garrett361 thanks for the detailed unit tests
Thanks @weifengpy, do you want me to make any changes to the tests?
Hi @weifengpy, checking in on this. Anything else you need from my end? Thanks!
@garrett361 just took another look and it's safe. We can land once CI passes. No need to cover _test_set_reduce_scatter_divide_factor_mixed_prevision.
@pytorchmergebot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
fully_shard's gradient_divide_factor isn't currently respected when the sharding degree is 1. This PR ensures the division factor also applies in this case.
This is a bit of an edge case, but it arises in torchtitan: e.g. with expert parallelism and ep_degree=world_size, we still wrap the routed experts in fully_shard. This PR ensures correctness in the reduce_scatter_group.size() == 1 case.
Reproducer and sample failures are in the gist here. The net effect is that the EP grads are too large by a factor of the world size in the case described above. I checked that the proposed fix makes these tests pass.
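To make the failure mode concrete, here is a schematic sketch of the intended gradient-scaling behavior. It is an illustrative stand-in, not the actual FSDP internals, and the helper name reduce_gradient is hypothetical:

```python
# Schematic sketch of the intended behavior; an illustrative stand-in for
# FSDP's gradient-reduction step, not the actual implementation. The helper
# name `reduce_gradient` is hypothetical.
from typing import Optional

import torch


def reduce_gradient(
    grad: torch.Tensor,
    reduce_scatter_group_size: int,
    gradient_divide_factor: Optional[float],
) -> torch.Tensor:
    # Default behavior divides by the shard world size; a user-set factor
    # overrides it (this is the gradient_divide_factor discussed above).
    factor = (
        gradient_divide_factor
        if gradient_divide_factor is not None
        else float(reduce_scatter_group_size)
    )
    if reduce_scatter_group_size == 1:
        # The reduce-scatter itself is a no-op with a single rank. Before this
        # PR the factor was effectively skipped here, leaving grads too large
        # by `factor`; the fix applies the division in this branch as well.
        return grad / factor
    # With more than one rank, the division is folded into the collective
    # (e.g. pre/post-divide) in the real implementation; shown plainly here.
    return grad / factor


# With sharding degree 1 and factor 8 (as in torchtitan EP with
# ep_degree == world_size), grads should come out 8x smaller.
g = torch.ones(4)
print(reduce_gradient(g, reduce_scatter_group_size=1, gradient_divide_factor=8.0))
```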
I guess I should add a test for this, too?
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci