
fix: use grad div factor when fsdp_degree=1 #167178

Closed
garrett361 wants to merge 11 commits into pytorch:main from garrett361:fully-shard-divide-fix

Conversation

@garrett361 (Contributor) commented Nov 6, 2025

fully_shard's gradient_divide_factor isn't currently respected when the sharding degree = 1. This PR ensures the division factor applies also in this case.

This is a bit of an edge case, but it arises in torchtitan, e.g. with expert parallelism and ep_degree=world_size, where we still wrap the routed experts in fully_shard because:

  1. It lets us take advantage of its mixed-precision mechanisms.
  2. A specific gradient_divide_factor is needed for correctness.

This PR ensures correctness in the reduce_scatter_group.size()==1 case.

Reproducer and sample failures are in the gist here. The net effect is that the EP grads are too large by a factor of the world size in the case described above. I checked that the proposed fix makes these tests pass.
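Roughly, the setup in question looks like the sketch below. To be clear, this is just an illustration with made-up names (MyRoutedExperts, the (ep, fsdp) mesh layout, and using world_size as the factor are all assumptions); the gist has the actual reproducer.

    import torch
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

    # Assumes torch.distributed has already been initialized.
    world_size = torch.distributed.get_world_size()

    # With ep_degree == world_size, the FSDP shard mesh over the routed
    # experts has size 1.
    mesh = init_device_mesh("cuda", (world_size, 1), mesh_dim_names=("ep", "fsdp"))

    experts = MyRoutedExperts()  # hypothetical routed-experts module
    fully_shard(
        experts,
        mesh=mesh["fsdp"],
        mp_policy=MixedPrecisionPolicy(param_dtype=torch.bfloat16),
    )
    # Before this fix the factor below was ignored, because
    # reduce_scatter_group.size() == 1, leaving the expert grads too large
    # by a factor of the world size.
    experts.set_gradient_divide_factor(world_size)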

I guess I should add a test for this, too?

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci

@pytorch-bot bot commented Nov 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/167178

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit a88ea4f with merge base dc00842:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot bot added the oncall: distributed and release notes: distributed (fsdp) labels on Nov 6, 2025
@linux-foundation-easycla bot commented Nov 6, 2025

CLA Signed

The committers listed above are authorized under a signed CLA.

@tianyu-l (Contributor) commented Nov 6, 2025

Thanks a lot! Sounds right to me. I'll let @weifengpy review / stamp.

Yes, a test would be appreciated.

@garrett361 force-pushed the fully-shard-divide-fix branch from 7078973 to 0490875 on November 6, 2025 03:45
@weifengpy (Contributor) commented:

good catch!

@weifengpy (Contributor) commented:

triggering CI - waiting for signals

@garrett361 (Contributor, Author) commented:

Thanks @weifengpy! Figuring out the CLA stuff with work, then I'll add a test.

@garrett361 (Contributor, Author) commented:

@weifengpy still waiting on the CLA stuff, but I started looking at where I'd add a test. I'm not seeing gradient_divide_factor tested anywhere in the code base. Am I missing a test somewhere? The closest I can find is in TestFullyShardCollectiveOps, where it's mentioned but left as None.

Should I add a test here? Seems like I'd have to add a decent bit of infra code to get a proper test set up.

@weifengpy (Contributor) commented:

> Should I add a test here? Seems like I'd have to add a decent bit of infra code to get a proper test set up.

Adding a unit test would be great! I was mentioning CI to make sure it does not break shard world size 2+, but having a unit test for world size 1 is better.

cc @anshul-si for a bug fix in world size 1

@garrett361 force-pushed the fully-shard-divide-fix branch from 337f84f to 8862e11 on November 11, 2025 15:26
@garrett361 (Contributor, Author) commented:

I found some other issues in the code and tests related to this topic:

  1. There is a current edge case where, if the user calls model.set_gradient_divide_factor(factor) and factor happens to equal the data parallel size, then an (I believe) unintended code path is taken: we end up dividing grads by factor and then averaging over the allreduce group, rather than summing. From the code below:

    if not overflow_risk and not force_sum_reduction_for_comms:
        if factor == data_parallel_size:
            # Warning: NCCL ReduceOp.AVG may produce incorrect results with
            # world size 1.
            if data_parallel_size == 1:
                return None, None, ReduceOp.SUM, ReduceOp.SUM
            return None, None, ReduceOp.AVG, ReduceOp.AVG
        else:
            reduce_scatter_op = torch.distributed._make_nccl_premul_sum(1 / factor)
            return None, None, reduce_scatter_op, ReduceOp.SUM

    If factor != data_parallel_size, we end up at the final line and return None, None, reduce_scatter_op, ReduceOp.SUM. But if we happen to set factor = data_parallel_size, then we enter the first if statement and (assuming data_parallel_size > 1) return None, None, ReduceOp.AVG, ReduceOp.AVG. This is a change in semantics: the final ReduceOp.SUM return value is swapped for a ReduceOp.AVG in the latter case. I believe we should always want ReduceOp.SUM when a custom division factor is provided and changed the code to reflect that (see the sketch after this list), but let me know if that is not desired.

  2. I found the relevant _test_set_reduce_scatter_divide_factor test, and as written it wasn't sensitive enough to catch some of these issues. I updated this PR to increase the sensitivity of this test and cover more cases (a sketch of the kind of check I mean is at the end of this comment). This is how I caught the above edge case.

  3. There's also the test _test_set_reduce_scatter_divide_factor_mixed_prevision, but this doesn't seem to actually test anything about custom division factors, because we apply the division factor only to the outermost fully_shard-wrapped module (the whole model), which doesn't own any parameters itself since it's a Sequential:

    model = nn.Sequential(*[MLP(16) for _ in range(3)])
    ref_model = copy.deepcopy(model).to(device_type)
    ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype)
    ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
    for mlp in model:
        fully_shard(mlp, mp_policy=mp_policy)
    model = fully_shard(model, mp_policy=mp_policy)
    optim = torch.optim.AdamW(model.parameters(), lr=1e-2)
    model.set_reduce_scatter_divide_factor(divide_factor)
    Does that sound right? I didn't touch this test yet, but I can if desired; I'm trying to keep the PR small.
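For item 1, here is a minimal sketch of the behavior I'd expect. This is not the literal diff in this PR; custom_factor_was_set is a made-up flag standing in for however the user-provided divide factor is tracked.

    if not overflow_risk and not force_sum_reduction_for_comms:
        if factor == data_parallel_size and not custom_factor_was_set:
            # Default path: average over the data parallel group.
            # Warning: NCCL ReduceOp.AVG may produce incorrect results with
            # world size 1.
            if data_parallel_size == 1:
                return None, None, ReduceOp.SUM, ReduceOp.SUM
            return None, None, ReduceOp.AVG, ReduceOp.AVG
        # Custom factor: pre-multiply by 1 / factor and keep both the
        # reduce-scatter and all-reduce ops as SUMs, even when factor
        # happens to equal data_parallel_size.
        reduce_scatter_op = torch.distributed._make_nccl_premul_sum(1 / factor)
        return None, None, reduce_scatter_op, ReduceOp.SUM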

CC @anshul-si @weifengpy @tianyu-l

Tested locally that everything in test/distributed/_composable/fsdp/test_fully_shard_comm.py passes.
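Concretely, the kind of factor-sensitive check I have in mind for item 2 looks roughly like the following. This is only a sketch with hypothetical names (ref_model, model, divide_factor), not the literal test code, and it assumes both models already ran a backward on the same inputs.

    import torch
    import torch.distributed as dist

    # With a custom divide factor, the fully_shard grads should equal the
    # SUM of the per-rank reference grads divided by divide_factor.
    for ref_param, param in zip(ref_model.parameters(), model.parameters()):
        ref_grad = ref_param.grad.detach().clone()
        dist.all_reduce(ref_grad, op=dist.ReduceOp.SUM)
        ref_grad.div_(divide_factor)
        torch.testing.assert_close(ref_grad, param.grad.full_tensor())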

@weifengpy (Contributor) commented:

@garrett361 thanks for the detailed unit tests

@garrett361 (Contributor, Author) commented:

Thanks @weifengpy, do you want me to make any changes to _test_set_reduce_scatter_divide_factor_mixed_prevision as well?

@garrett361 (Contributor, Author) commented:

Hi @weifengpy, checking in on this. Anything else you need from my end? Thanks!

@weifengpy (Contributor) commented:

> Hi @weifengpy, checking in on this. Anything else you need from my end? Thanks!

@garrett361 I just took another look and it's safe. We can land once CI passes. No need to cover _test_set_reduce_scatter_divide_factor_mixed_prevision.

@weifengpy (Contributor) commented:

@pytorchmergebot merge

@pytorch-bot bot added the ciflow/trunk label on Nov 19, 2025
@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

