Skip to content

Conversation

@zhaojuanmao
Copy link
Contributor

Make input casting in root module only in default, meanwhile allowing to set different mixed precisions for different submodules

@pytorch-bot
Copy link

pytorch-bot bot commented Dec 23, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/91365

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 Failures

As of commit 8fa7531:

FLAKY - The following jobs failed but were likely due to flakiness present on master:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: distributed (fsdp) release notes category label Dec 23, 2022
Copy link
Collaborator

@awgu awgu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall, the PR looks good to me. I left a few nits about clarifying comments, and also we should fix the cast_root_foward_inputs typo by changing to cast_root_forward_inputs.

Feel free to re-request review when ready.

@awgu awgu self-requested a review December 27, 2022 22:46
@zhaojuanmao zhaojuanmao force-pushed the mixedPrcesionConversion branch 2 times, most recently from 0a861cc to 0b0a7dd Compare December 28, 2022 04:21
@facebook-github-bot
Copy link
Contributor

@zhaojuanmao has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@zhaojuanmao zhaojuanmao force-pushed the mixedPrcesionConversion branch from 0b0a7dd to 8fa7531 Compare December 28, 2022 05:02
@facebook-github-bot
Copy link
Contributor

@zhaojuanmao has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

Copy link
Collaborator

@awgu awgu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

FSDP(FSDP(model.c2, MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True)),
model.c1, MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True)),
model.c1 should be the first one executed, so that its inputs could be casted
as expected inside the root FSDP instance.see examples in unit tests
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure if we want to reference the unit tests like this from public docs since (1) users should not need to dig into our unit tests to understand the note and (2) the unit test may change without us remembering to change this public note.

We do not need to change address this in this PR. I can submit a follow-up if you do not mind, or you can do it as well.

@zhaojuanmao
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Dec 28, 2022
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@pytorchmergebot
Copy link
Collaborator

Merge failed

Reason: 2 additional jobs have failed, first few of them are: trunk ,trunk / linux-focal-rocm5.3-py3.8 / test (default, 1, 2, linux.rocm.gpu)

Details for Dev Infra team Raised by workflow job

@zhaojuanmao
Copy link
Contributor Author

@pytorchbot merge -f "failures are not related"

@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

@github-actions github-actions bot deleted the mixedPrcesionConversion branch June 28, 2024 01:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/trunk Trigger trunk jobs on your pull request Merged release notes: distributed (fsdp) release notes category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants