
[FSDP2] idempotent reset_sharded_param: no-op if _local_tensor is already padded #163130

Closed
weifengpy wants to merge 7 commits into gh/weifengpy/31/base from gh/weifengpy/31/head

Conversation

@weifengpy (Contributor) commented Sep 17, 2025

resolves pytorch/torchtitan#1136

torchtitan uses a cached state dict for ft. `reset_sharded_param` should be idempotent if `model.parameters()` are already padded.

# pad DTensor._local_tensor
fully_shard(model)
sd = fsdp_model.state_dict()
# reset_sharded_param should be a no-op in lazy_init
loss = fsdp_model(inp).sum()

This PR makes `reset_sharded_param` idempotent by checking the storage data pointer and returning early.

Unit test:

pytest -s test/distributed/_composable/fsdp/test_fully_shard_state_dict.py -k test_cached_state_dict
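To make the idea concrete, below is a minimal standalone sketch of this kind of idempotence check (illustrative only, not the actual FSDP2 `reset_sharded_param`; the helper name `pad_local_tensor` and the explicit `padded_numel` argument are assumptions for the example). Padding copies the shard into larger storage and returns a view narrowed back to the original size; a later call notices the already-padded storage and returns early.

```
import torch


def pad_local_tensor(local_tensor: torch.Tensor, padded_numel: int) -> torch.Tensor:
    """Pad `local_tensor` into storage of `padded_numel` elements and return a
    view narrowed back to the original shape. A second call on the returned
    view is a no-op."""
    elem_bytes = local_tensor.element_size()
    # Early return: if the underlying storage already has the padded size,
    # this tensor was padded before -- nothing to do.
    if local_tensor.untyped_storage().size() == padded_numel * elem_bytes:
        return local_tensor
    padded = local_tensor.new_zeros(padded_numel)
    padded[: local_tensor.numel()].copy_(local_tensor.flatten())
    # Narrow back to the original numel; the visible size now equals the
    # original size, so a size-based check cannot tell "padded" from "not yet".
    return padded[: local_tensor.numel()].view(local_tensor.shape)


if __name__ == "__main__":
    shard = torch.randn(5)             # e.g. dim-0 not divisible by world size
    once = pad_local_tensor(shard, 8)  # first call pads into new storage
    twice = pad_local_tensor(once, 8)  # second call is a no-op
    assert once.data_ptr() == twice.data_ptr()
```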

Stack from ghstack (oldest at bottom):

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim @dcci

pytorch-bot bot commented Sep 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163130

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 8351a06 with merge base f6ea41e:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

weifengpy added a commit that referenced this pull request Sep 17, 2025
@pytorch-bot pytorch-bot bot added the ciflow/inductor, oncall: distributed, and release notes: distributed (fsdp) labels Sep 17, 2025
weifengpy added a commit that referenced this pull request Sep 17, 2025
@weifengpy weifengpy changed the title from "reset sharded params in state_dict() and make it idempotent" to "[FSDP2] reset sharded params in state_dict() and make it idempotent" Sep 17, 2025
@weifengpy weifengpy marked this pull request as draft September 17, 2025 01:07
@weifengpy (Contributor, Author)

adding unit tests

weifengpy added a commit that referenced this pull request Sep 17, 2025
weifengpy added a commit that referenced this pull request Sep 17, 2025
@weifengpy weifengpy marked this pull request as ready for review September 17, 2025 08:31
@weifengpy (Contributor, Author)

The CI error is not relevant.

@awgu (Collaborator) commented Sep 17, 2025

Sorry, I did not follow why we would need to reset the sharded param in the state dict pre-hook, and I could not figure it out from the unit test.

(E.g., where did the padding get lost in the first place? Why was the padding not re-added at that point?)

@fegin (Contributor) commented Sep 17, 2025

@awgu The current issue is that some training frameworks, including TorchTitan, call model.state_dict() before the first forward() and use that result throughout the entire training run without calling model.state_dict() again. This results in incorrect checkpoints being saved.
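For illustration, here is a rough sketch of the usage pattern described above (assumed names such as `dataloader` and `checkpointer`; this is not TorchTitan's actual code):

```
import torch


def train(fsdp_model: torch.nn.Module, dataloader, checkpointer, num_steps: int) -> None:
    # State dict captured once, before the first forward, and reused for the
    # whole run -- the caching pattern described above.
    cached_sd = fsdp_model.state_dict()
    for step, (inp, _) in zip(range(num_steps), dataloader):
        loss = fsdp_model(inp).sum()
        loss.backward()
        # ... optimizer step / zero_grad elided ...
        if step % 100 == 0:
            # Correctness relies on the tensors in `cached_sd` still aliasing
            # the live (padded) parameters; if lazy_init re-padded them into
            # fresh storage, this would save stale values.
            checkpointer.save(cached_sd)
```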

@tianyu-l (Contributor) left a comment

Thanks for the fix!

return
updated_local_tensor = False
# `reset_sharded_param` can be called twice
# 1st time in sd = model.state_dict()

Please update the comments based on offline discussions:

  • The first time should be during the fully_shard call.
  • The 2nd time could happen with or without a state dict load. If with a load, the 2nd time should not be a no-op.

@weifengpy (Contributor, Author)

Updated. Good catch!

weifengpy added a commit that referenced this pull request Sep 17, 2025
@weifengpy (Contributor, Author) commented Sep 17, 2025

Where did the padding get lost in the first place? Why was the padding not re-added at that point?

Good question! No padding is getting lost. Here is what I want to achieve:

fully_shard(model) with padded local_tensor -> model.state_dict() -> model(input), where reset_sharded_param should be a no-op.

I just need reset_sharded_param to be idempotent. Without the PR, we always create new padded tensors, because local_tensor.size() != padded_sharded_size is always true (local_tensor is narrowed to the original size after padding).

There is no need to call reset_sharded_param in the state dict hooks. I modified the PR accordingly.
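A tiny standalone demo of that point (plain tensors, not FSDP2 internals): after padding, the visible local tensor is narrowed back to its original size, so a size comparison always reports "not padded yet", while the storage data pointer does identify the already-padded case.

```
import torch

orig = torch.arange(5, dtype=torch.float32)    # unpadded shard
padded = torch.zeros(8)                        # padded sharded size
padded[:5].copy_(orig)
local = padded.narrow(0, 0, 5)                 # what the parameter exposes after padding

print(local.size() == padded.size())           # False: a size check would re-pad forever
print(local.data_ptr() == padded.data_ptr())   # True: the storage check detects the padding
print(local.untyped_storage().size())          # 32 bytes: the full padded storage
```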

weifengpy added a commit that referenced this pull request Sep 17, 2025
@weifengpy weifengpy changed the title from "[FSDP2] reset sharded params in state_dict() and make it idempotent" to "[FSDP2] idempotent reset_sharded_param: no-op if _local_tensor is already padded" Sep 17, 2025
@weifengpy (Contributor, Author)

@pytorchmergebot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label Sep 17, 2025
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator)

Merge failed

Reason: 1 job has failed; the first few are: trunk / linux-jammy-cuda12.8-py3.10-gcc11 / test (distributed, 3, 3, linux.g4dn.12xlarge.nvidia.gpu)

Details for Dev Infra team: raised by workflow job.

@fegin (Contributor) commented Sep 18, 2025

@weifengpy The test failure looks real.

weifengpy added a commit that referenced this pull request Sep 18, 2025
@weifengpy (Contributor, Author)

@pytorchmergebot merge

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@wwwjn commented Sep 18, 2025

model.state_dict() -> model(input) reset_sharded_param should be no-op

Thanks for the fix @weifengpy! From reading the code and comments, reset_sharded_param() should only pad again if the model's local tensor changed, right?

Also, this PR seems to be merged into another base branch instead of the main branch; is there any plan to upstream it to main?

@weifengpy (Contributor, Author)

@weifengpy The test failure looks real.

Right, I updated the PR to skip tensor subclasses.

@weifengpy (Contributor, Author)

Thanks for the fix @weifengpy! From reading the code and comments, reset_sharded_param() should only pad again if the model's local tensor changed, right?

That's right. For example, loading a state dict triggers padding again.

Also, this PR seems to be merged into another base branch instead of the main branch; is there any plan to upstream it to main?

This is ghstack, so the branch looks weird, but it is merged into main.

@weifengpy weifengpy added the release notes: distributed (fsdp2) label and removed the release notes: distributed (fsdp) label Sep 18, 2025
mansiag05 pushed a commit to mansiag05/pytorch that referenced this pull request Sep 22, 2025
cleonard530 pushed a commit to cleonard530/pytorch that referenced this pull request Sep 22, 2025
dsashidh pushed a commit to dsashidh/pytorch that referenced this pull request Sep 26, 2025
@github-actions github-actions bot deleted the gh/weifengpy/31/head branch October 19, 2025 02:19

Labels

ciflow/inductor, ciflow/trunk, Merged, oncall: distributed, release notes: distributed (fsdp2)


Development

Successfully merging this pull request may close these issues.

Inconsistent loss when resume training with vocab size that is not divisible by world size.

6 participants