Description
#90660 introduced an attribute `_orig_param_dtype` as the basis for telling if parameter/gradient reduction is enabled and what those default dtypes should be. However, we compute this attribute's value in the FSDP constructor. If the user changes the model's dtype after the FSDP constructor, the saved value becomes stale, which may lead to incorrect results (e.g. FSDP will cast the gradient to match `_orig_param_dtype` in preparation for the optimizer step, but that dtype may mismatch the new parameter dtype).
The workaround is for the user to cast the model to the desired dtype before passing it to FSDP's constructor. As far as I can tell, casting before rather than after construction should make no difference to the result. If anything, the FSDP initialization may run faster if the model is already cast to a low precision.
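The ordering in the workaround can be sketched as follows. This is a minimal illustration, not FSDP code: `cast_then_wrap` and its `wrap_fn` parameter are hypothetical names, and an identity function stands in for the FSDP constructor so the snippet runs in a single process without an initialized process group.

```python
import torch
import torch.nn as nn


def cast_then_wrap(model: nn.Module, dtype: torch.dtype, wrap_fn):
    """Cast `model` to `dtype` first, then hand it to `wrap_fn`.

    `wrap_fn` is a stand-in for the FSDP constructor (hypothetical parameter);
    in a real distributed job it would be
    `torch.distributed.fsdp.FullyShardedDataParallel`.
    """
    model = model.to(dtype)  # dtype is final before the wrapper sees the model
    return wrap_fn(model)


# Single-process demonstration with an identity wrapper.
model = nn.Linear(4, 2)
wrapped = cast_then_wrap(model, torch.bfloat16, wrap_fn=lambda m: m)
param_dtypes = {p.dtype for p in wrapped.parameters()}
```

Because the cast happens before the wrapper is constructed, any dtype the wrapper records at construction time already matches the parameters.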
The long-term solution, if we want to support this dtype change after the FSDP constructor, is to initialize or re-initialize the saved dtype attributes like `_orig_param_dtype` during lazy initialization. That gives the user until the first forward pass to change the model's dtype, increasing the window slightly. However, since none of our current users train in pure low precision and instead use our native mixed precision API, we are not prioritizing this for now.
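The idea of deferring the dtype capture to lazy initialization can be sketched with a toy wrapper. This is an assumption-laden illustration, not FSDP's implementation: `LazyDtypeWrapper` and `_lazy_init` are hypothetical names, and only the attribute name `_orig_param_dtype` comes from the issue.

```python
import torch
import torch.nn as nn


class LazyDtypeWrapper(nn.Module):
    """Toy wrapper (not FSDP) that records the parameter dtype lazily."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module
        self._orig_param_dtype = None  # deliberately not computed here

    def _lazy_init(self):
        # Capture the dtype once, on the first forward pass.
        if self._orig_param_dtype is None:
            self._orig_param_dtype = next(self.module.parameters()).dtype

    def forward(self, *args, **kwargs):
        self._lazy_init()
        return self.module(*args, **kwargs)


wrapper = LazyDtypeWrapper(nn.Linear(4, 2))
wrapper = wrapper.to(torch.float64)  # dtype change after construction
out = wrapper(torch.randn(3, 4, dtype=torch.float64))
```

In this sketch the dtype change made between construction and the first forward pass is picked up, whereas a value computed in `__init__` would still be `torch.float32`. A change made after the first forward pass would still go unnoticed.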
cc @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @H-Huang @kwen2501