temp fix state dict loading: avoid cache_state_dict #1702
weifengpy wants to merge 1 commit into pytorch:main
Conversation
Keeping this as a draft; it is only meant to unblock customers who urgently need it. I still need to understand how cache_state_dict is related to uneven sharding.
This is an excellent finding, but I don't understand why. We cache during the Checkpointer constructor, which should be after the model is wrapped and before training. However, the timing of DCP.load should be the same, so caching or not caching shouldn't make a big difference.
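For context, a rough sketch of the ordering being discussed (names and structure are assumptions, not torchtitan's actual code):

```python
# Rough sketch (assumed, not torchtitan's code) of the timeline discussed above:
# the state dict is cached in the checkpointer's constructor, after the model is
# wrapped/sharded and before the (optional) checkpoint load and the training loop.

import torch.nn as nn

model = nn.Linear(8, 8)                 # stand-in for the wrapped/sharded model

cached_state_dict = model.state_dict()  # 1. Checkpointer ctor caches here

# 2. On resume, DCP.load would also run here, still before training,
#    so caching vs. not caching happens at roughly the same point in time.

for step in range(3):                   # 3. training loop starts after both
    pass
```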
Discussed with @fegin: the core issue is that padded parameters are returned when the user calls model.state_dict(). There are two ways to address it.
A minimal repro is:
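The repro itself isn't captured in the thread as shown here. As a rough, standalone illustration (an assumption, not the author's repro) of how uneven sharding leads to padded shards and why a cached padded copy can go stale:

```python
# Standalone illustration of uneven-sharding padding (plain torch, no FSDP needed).
# A parameter with 10 elements split across 4 ranks cannot be sharded evenly, so
# the flat tensor is padded to equal-sized shards; the padding must be stripped
# when reassembling a state dict, and a cached padded copy does not track later
# updates to the real parameter.

import torch

world_size = 4
param = torch.arange(10, dtype=torch.float32)  # 10 elements, not divisible by 4

# Pad to the next multiple of world_size so every rank holds an equal-sized shard.
padded_len = ((param.numel() + world_size - 1) // world_size) * world_size
padded = torch.cat([param, torch.zeros(padded_len - param.numel())])
shards = padded.chunk(world_size)  # each rank holds one of these

# Correct reassembly strips the padding back off.
full = torch.cat(shards)[: param.numel()]
assert torch.equal(full, param)

# If a padded copy is cached and reused later, it no longer reflects updates
# made to the real parameter, and its trailing entries are padding.
cached = torch.cat(shards)          # padded, cached copy
param += 1.0                        # parameter changes during training
print(cached[: param.numel()])      # stale values
print(param)                        # up-to-date values
```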
Abandoning this PR; landing the FSDP2-side fix instead: pytorch/pytorch#163130
Temporary fix for #1136, where loss is worse when resuming training from a checkpoint.
Loss becomes exactly the same after disabling cache_state_dict. I still need to understand why this only happens with uneven sharding, but am providing this workaround to unblock customers.
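For context, a minimal sketch of what avoiding cache_state_dict amounts to: re-fetching the model's state dict right before save/load instead of reusing one cached at construction. The Checkpointer class and cache_state_dict flag below are illustrative placeholders (torchtitan's real CheckpointManager differs), and the calls assume a recent PyTorch with torch.distributed.checkpoint.state_dict:

```python
# Illustrative sketch (not torchtitan's actual CheckpointManager): with
# cache_state_dict=False, the state dict is fetched fresh at save/load time,
# so DCP works against the model's current tensors rather than a dict cached
# at construction time.

import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    set_model_state_dict,
)


class Checkpointer:  # hypothetical name
    def __init__(self, model, cache_state_dict: bool = False):
        self.model = model
        self.cache_state_dict = cache_state_dict
        self._cached = get_model_state_dict(model) if cache_state_dict else None

    def _state_dict(self):
        # Workaround: skip the cache and ask the model for a fresh state dict.
        if self.cache_state_dict and self._cached is not None:
            return self._cached
        return get_model_state_dict(self.model)

    def save(self, checkpoint_id: str) -> None:
        dcp.save({"model": self._state_dict()}, checkpoint_id=checkpoint_id)

    def load(self, checkpoint_id: str) -> None:
        state_dict = {"model": self._state_dict()}
        dcp.load(state_dict, checkpoint_id=checkpoint_id)  # loads in place
        set_model_state_dict(self.model, state_dict["model"])
```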
1st run, to save checkpoints at steps 10 and 20:
NGPU=4 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh
2nd run, to load the checkpoint from step 10:
rm -rf outputs/checkpoint/step-20 && NGPU=4 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh