[FSDP2] idempotent reset_sharded_param: no-op if _local_tensor is already padded #163130
weifengpy wants to merge 7 commits into gh/weifengpy/31/base
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163130
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit 8351a06 with merge base f6ea41e. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Adding unit tests.
…dempotent" resolves pytorch/torchtitan#1136 cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
…dempotent" resolves pytorch/torchtitan#1136 cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
The CI error is not relevant.
Sorry, I did not follow why we would need to reset the sharded param in the state dict pre-hook, and I could not figure it out from the unit test (e.g., where did the padding get lost in the first place, and why was the padding not re-added at that point?).
@awgu The current issue is that some training frameworks, including TorchTitan, call `model.state_dict()` before the first forward (to cache the state dict for fault tolerance), i.e. before lazy_init has run.
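(For context, a minimal repro sketch of that ordering; this is not code from the PR or from TorchTitan, and it assumes a CUDA machine, a script launched via torchrun, and a recent PyTorch where FSDP2's `fully_shard` lives under `torch.distributed.fsdp`.)

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import fully_shard  # FSDP2 entry point in recent PyTorch

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = nn.Linear(16, 16, device="cuda")
fully_shard(model)  # params become DTensors; reset_sharded_param pads the local shard (1st call)

# Fault-tolerance path: cache the state dict *before* the first forward,
# i.e. before lazy_init has run.
cached_sd = model.state_dict()

# The first forward triggers lazy_init, which calls reset_sharded_param again;
# with this PR that call is a no-op because _local_tensor is already padded.
loss = model(torch.randn(4, 16, device="cuda")).sum()
loss.backward()

dist.destroy_process_group()
```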
```
return
updated_local_tensor = False
# `reset_sharded_param` can be called twice
# 1st time in sd = model.state_dict()
```
Please update the comments based on our offline discussion:
- the first time should be during the `fully_shard` call
- the 2nd time could happen with or without a state dict load; if with a load, the 2nd time should not be a no-op
updated. good catch!
…dempotent" resolves pytorch/torchtitan#1136 torchtitan use cached state dict for ft. fsdp2 should run padding for sharded params ``` # should call reset sharded params for padding sd = fsdp_model.state_dict() # reset sharded params should be a no-op loss = fsdp_model(inp).sum() ``` this PR does two thing * reset sharded params in state dict pre hook * make sharded params idempotent by checking storage data ptr and return early unit test ``` pytest -s test/distributed/_composable/fsdp/test_fully_shard_state_dict.py -k test_cached_state_dict ``` cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
Good question! No padding is getting lost. Here is what I want to achieve: fully_shard(model) with a padded local_tensor -> model.state_dict() -> model(input), where reset_sharded_param should be a no-op. I just need reset_sharded_param to be idempotent. Without this PR, we always create new padded tensors, because reset_sharded_param never checks whether the local tensor is already padded. There is no need to call reset_sharded_param in the state dict hooks; I modified the PR accordingly.
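To make the idempotency idea concrete, here is a small standalone sketch with plain tensors (no FSDP or DTensor; `pad_local_tensor` is a hypothetical helper, not the PR's actual code): the first call pads the shard into a flat buffer, later calls return early when the shard already aliases that buffer's storage, and a shard swapped in with fresh storage (e.g. by a state dict load) falls through to re-padding.

```python
from typing import Optional, Tuple

import torch


def pad_local_tensor(
    local: torch.Tensor,
    padded_numel: int,
    padded_buf: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, bool]:
    """Hypothetical stand-in for the early-return check in reset_sharded_param."""
    if (
        padded_buf is not None
        and local.untyped_storage().data_ptr() == padded_buf.untyped_storage().data_ptr()
    ):
        # Already padded: the shard is a view into the padded buffer, so no-op.
        return padded_buf, local, False

    # Not padded yet (or the storage changed): allocate, copy, and return a view.
    padded_buf = torch.zeros(padded_numel, dtype=local.dtype, device=local.device)
    padded_buf[: local.numel()].copy_(local.reshape(-1))
    local_view = padded_buf[: local.numel()].view(local.shape)
    return padded_buf, local_view, True


shard = torch.randn(6)
buf, shard, updated = pad_local_tensor(shard, padded_numel=8, padded_buf=None)
print(updated)  # True: first call pads

buf, shard, updated = pad_local_tensor(shard, padded_numel=8, padded_buf=buf)
print(updated)  # False: second call is a no-op

shard = torch.randn(6)  # fresh storage, as after an assign-style state dict load
buf, shard, updated = pad_local_tensor(shard, padded_numel=8, padded_buf=buf)
print(updated)  # True: the data_ptr check fails, so padding runs again
```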
…dempotent" resolves pytorch/torchtitan#1136 torchtitan use cached state dict for ft. reset_sharded_param should be idempotent if model.parameters() are padded already ``` # pad DTensor._local_tensor fully_shard(model) sd = fsdp_model.state_dict() # reset_sharded_param should be a no-op in lazy_init loss = fsdp_model(inp).sum() ``` this PR make `reset_sharded_param` idempotent by checking storage data ptr and return early unit test ``` pytest -s test/distributed/_composable/fsdp/test_fully_shard_state_dict.py -k test_cached_state_dict ``` cc H-Huang awgu wanchaol fegin fduwjj wz337 wconstab d4l3k pragupta ezyang msaroufim dcci [ghstack-poisoned]
@pytorchmergebot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed; the first few are: trunk / linux-jammy-cuda12.8-py3.10-gcc11 / test (distributed, 3, 3, linux.g4dn.12xlarge.nvidia.gpu).
@weifengpy The test failure looks real.
…nsor is already padded" resolves pytorch/torchtitan#1136
@pytorchmergebot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Thanks for the fix @weifengpy! By reading the code and comments, the … And this PR seems to have been merged into a different base branch rather than main; is there any plan to upstream it to main?
Right, I updated the PR to skip tensor subclasses.
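Roughly, the guard can be as simple as restricting the data-pointer fast path to plain dense tensors; a hedged sketch below (`can_use_data_ptr_fast_path` is a made-up name, not the exact condition in the PR):

```python
import torch


def can_use_data_ptr_fast_path(local_tensor: torch.Tensor) -> bool:
    # Tensor subclasses may not back their data with a plain storage whose
    # pointer is meaningful to compare, so only exact torch.Tensor instances
    # take the early-return path; everything else re-pads as before.
    return type(local_tensor) is torch.Tensor


class MyTensorSubclass(torch.Tensor):
    pass


print(can_use_data_ptr_fast_path(torch.randn(4)))                                # True
print(can_use_data_ptr_fast_path(torch.randn(4).as_subclass(MyTensorSubclass)))  # False
```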
That's right. For example, loading a state dict triggers padding.
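A small illustration of why the storage-pointer check handles this case (plain `nn.Linear`, no FSDP involved): a load that assigns fresh tensors, e.g. `load_state_dict(..., assign=True)`, swaps out the parameter's storage, so the early-return check would not fire and padding would run again.

```python
import torch
import torch.nn as nn

m = nn.Linear(4, 4)
before = m.weight.untyped_storage().data_ptr()

# Simulate a checkpoint load that assigns new tensors instead of copying in place.
sd = {k: v.clone() for k, v in m.state_dict().items()}
m.load_state_dict(sd, assign=True)

after = m.weight.untyped_storage().data_ptr()
print(before == after)  # False: new storage, so an alias check correctly falls through
```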
This is ghstack, so the branch looks weird, but it's merged into main.
…eady padded (pytorch#163130)

resolves pytorch/torchtitan#1136

TorchTitan uses a cached state dict for fault tolerance. `reset_sharded_param` should be idempotent if `model.parameters()` are already padded:

```
# pad DTensor._local_tensor
fully_shard(model)
sd = fsdp_model.state_dict()
# reset_sharded_param should be a no-op in lazy_init
loss = fsdp_model(inp).sum()
```

This PR makes `reset_sharded_param` idempotent by checking the storage data pointer and returning early.

Unit test:
```
pytest -s test/distributed/_composable/fsdp/test_fully_shard_state_dict.py -k test_cached_state_dict
```

Pull Request resolved: pytorch#163130
Approved by: https://github.com/tianyu-l
resolves pytorch/torchtitan#1136
TorchTitan uses a cached state dict for fault tolerance. `reset_sharded_param` should be idempotent if `model.parameters()` are already padded.

This PR makes `reset_sharded_param` idempotent by checking the storage data pointer and returning early.

Unit test: `pytest -s test/distributed/_composable/fsdp/test_fully_shard_state_dict.py -k test_cached_state_dict`
Stack from ghstack (oldest at bottom):
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim @dcci