[FSDP] Investigate param dtype change after FSDP constructor #90838

@awgu

Description


#90660 introduced an attribute _orig_param_dtype as the basis for telling if parameter/gradient reduction is enabled and what those default dtypes should be. However, we compute this attribute's value in the FSDP constructor, so if the user changes the model's dtype after the FSDP constructor, the saved value becomes stale. This may lead to incorrect results (e.g. FSDP casts the gradient to _orig_param_dtype in preparation for the optimizer step, but that dtype may not match the new parameter dtype).

The workaround is for the user to cast the model to the desired dtype before passing it to FSDP's constructor. As far as I can tell, there should be no functional difference between casting before and after wrapping. If anything, FSDP initialization may run faster when the model is already cast to low precision.
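A minimal, pure-Python sketch of both the failure mode and the workaround, under stated assumptions: it does not use PyTorch or FSDP at all. ToyParam, ToyModel, EagerDtypeWrapper, and dtypes_match are hypothetical stand-ins, dtypes are plain strings, and the wrapper only mimics the behavior described above (saving _orig_param_dtype once in the constructor and casting gradients to it before the optimizer step).

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyParam:
    """Hypothetical stand-in for an nn.Parameter that only tracks a dtype name."""
    dtype: str


@dataclass
class ToyModel:
    """Hypothetical stand-in for an nn.Module; all parameters share one dtype."""
    params: List[ToyParam] = field(
        default_factory=lambda: [ToyParam("float32"), ToyParam("float32")]
    )

    def to(self, dtype: str) -> "ToyModel":
        # Mirrors module.to(dtype): casts every parameter in place.
        for p in self.params:
            p.dtype = dtype
        return self


class EagerDtypeWrapper:
    """Hypothetical wrapper that saves the parameter dtype once, in __init__."""

    def __init__(self, model: ToyModel) -> None:
        self.model = model
        self._orig_param_dtype = model.params[0].dtype  # snapshot at construction

    def grad_dtype_for_optimizer(self) -> str:
        # Gradients are cast to the saved dtype before the optimizer step.
        return self._orig_param_dtype


def dtypes_match(wrapper: EagerDtypeWrapper) -> bool:
    """True if the gradient dtype equals every parameter's current dtype."""
    grad_dtype = wrapper.grad_dtype_for_optimizer()
    return all(p.dtype == grad_dtype for p in wrapper.model.params)


# Problem: the model is cast after wrapping, so the snapshot is stale.
late = EagerDtypeWrapper(ToyModel())
late.model.to("bfloat16")
print(late.grad_dtype_for_optimizer(), dtypes_match(late))  # float32 False

# Workaround: cast first, then wrap.
early = EagerDtypeWrapper(ToyModel().to("bfloat16"))
print(early.grad_dtype_for_optimizer(), dtypes_match(early))  # bfloat16 True
```

Because the saved dtype is read only once, any cast after construction is invisible to the wrapper; casting first sidesteps the problem without changing the wrapper.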

The long-term solution, if we want to support changing the dtype after the FSDP constructor, is to initialize or re-initialize the saved dtype attributes like _orig_param_dtype during lazy initialization. This would give the user until the first forward pass to change the model's dtype, widening the window slightly. However, since none of our current users train in pure low precision and instead use our native mixed precision API, we are not prioritizing this for now.
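The lazy-initialization idea can be sketched the same way. This is again a hypothetical toy, not FSDP's implementation: LazyDtypeWrapper (an assumed name) defers reading the parameter dtype until the first forward call, so a cast made between construction and the first forward is picked up, while a cast made after the first forward is still missed.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToyParam:
    """Hypothetical stand-in for an nn.Parameter that only tracks a dtype name."""
    dtype: str


@dataclass
class ToyModel:
    """Hypothetical stand-in for an nn.Module; all parameters share one dtype."""
    params: List[ToyParam] = field(
        default_factory=lambda: [ToyParam("float32"), ToyParam("float32")]
    )

    def to(self, dtype: str) -> "ToyModel":
        # Mirrors module.to(dtype): casts every parameter in place.
        for p in self.params:
            p.dtype = dtype
        return self


class LazyDtypeWrapper:
    """Hypothetical wrapper that records the parameter dtype on first forward."""

    def __init__(self, model: ToyModel) -> None:
        self.model = model
        self._orig_param_dtype: Optional[str] = None

    def _lazy_init(self) -> None:
        # Runs once: the dtype is read at the first forward, not in __init__.
        if self._orig_param_dtype is None:
            self._orig_param_dtype = self.model.params[0].dtype

    def forward(self, x):
        self._lazy_init()
        return x  # the toy model's computation is omitted

    def grad_dtype_for_optimizer(self) -> str:
        if self._orig_param_dtype is None:
            raise RuntimeError("forward() must run before the optimizer step")
        return self._orig_param_dtype


wrapper = LazyDtypeWrapper(ToyModel())
wrapper.model.to("bfloat16")  # cast after wrapping but before the first forward
wrapper.forward(1.0)
print(wrapper.grad_dtype_for_optimizer())  # bfloat16

wrapper.model.to("float16")  # cast after the first forward: still missed
print(wrapper.grad_dtype_for_optimizer())  # bfloat16
```

The window only moves from "before construction" to "before the first forward"; any later cast would still need the user to avoid it or FSDP to re-check the dtype.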

cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501

Labels: module: fsdp, oncall: distributed, triaged
