RowwiseMoments: use float as acc type for bfloat16 inputs #81850
Conversation
Originally `utils::RowwiseMoments<BFloat16>` would still accumulate in BFloat16, which is not only slow but also introduces additional rounding errors. This patch performs the accumulation in float for bfloat16 inputs: each bfloat16 vector (size 16) is converted to two float vectors (size 8) and accumulated into the m1 (mean) and m2 (rstd) vectors, which are all float vectors. [ghstack-poisoned]
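For illustration only, here is a minimal, self-contained scalar sketch of the idea. It is not the actual ATen kernel (which works on `Vectorized<float>` lanes, splitting each 16-lane bfloat16 vector into two 8-lane float vectors, and uses a chunked update); the helper names `bf16_to_float` and `rowwise_moments_bf16` and the naive sum-of-squares formulation are assumptions made for the sketch. The point it shows is that every bfloat16 value is widened to float before anything is accumulated, so the running m1/m2 accumulators never round back to bfloat16.

```cpp
#include <cmath>
#include <cstdint>
#include <cstring>

// Reinterpret a bfloat16 bit pattern as the upper 16 bits of a float32.
static inline float bf16_to_float(uint16_t v) {
  uint32_t bits = static_cast<uint32_t>(v) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}

// Rowwise mean / rstd with float accumulators (simplified, scalar sketch):
// each bf16 element is widened to float *before* it is added into
// m1 (sum) and m2 (sum of squares), so no intermediate result is
// rounded back to bfloat16 during accumulation.
void rowwise_moments_bf16(const uint16_t* row, int64_t n, float eps,
                          float* mean, float* rstd) {
  float m1 = 0.0f;  // running sum of x
  float m2 = 0.0f;  // running sum of x * x
  for (int64_t i = 0; i < n; ++i) {
    float x = bf16_to_float(row[i]);
    m1 += x;
    m2 += x * x;
  }
  *mean = m1 / static_cast<float>(n);
  float var = m2 / static_cast<float>(n) - (*mean) * (*mean);
  *rstd = 1.0f / std::sqrt(var + eps);
}
```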
✅ No failures (0 pending) as of commit 8d6edf6 (more details on the Dr. CI page). Looks good so far! There are no failures yet. This comment was automatically generated by Dr. CI.
This PR fixes #77507. Accumulating bfloat16 in float32 also brings a performance improvement, since we avoid redundant dtype conversions, which are very time consuming.
This PR has no effect on fp32 performance. avx512 results:
avx2 results:
To fix #77507

Originally `utils::RowwiseMoments<BFloat16>` would still accumulate in BFloat16, which is not only slow but also introduces additional rounding errors. This patch performs the accumulation in float for bfloat16 inputs: each bfloat16 vector (size 16) is converted to two float vectors (size 8) and accumulated into the m1 (mean) and m2 (rstd) vectors, which are all float vectors.

No effect on float performance; bfloat16 performance improves:

* avx512 single socket:
```
before: LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 0.210 ms; bf16: 0.770 ms
after:  LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 0.215 ms; bf16: 0.178 ms
```
* avx512 single core:
```
before: LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 2.661 ms; bf16: 12.267 ms
after:  LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 2.618 ms; bf16: 2.309 ms
```
* avx2 single socket:
```
before: LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 0.540 ms; bf16: 2.030 ms
after:  LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 0.527 ms; bf16: 0.458 ms
```
* avx2 single core:
```
before: LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 4.349 ms; bf16: 19.252 ms
after:  LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 4.416 ms; bf16: 3.524 ms
```

[ghstack-poisoned]
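The numbers above come from the PR author's own measurements; the exact benchmark script is not part of this thread. As a hedged sketch only, something like the following libtorch harness could be used to collect comparable fp32/bf16 timings for the `LayerNorm((1024,), eps=1e-05, elementwise_affine=True)` shape (the function name `time_layer_norm`, iteration counts, and the choice to cast the module itself to bf16 are assumptions):

```cpp
#include <torch/torch.h>
#include <chrono>
#include <iostream>

// Hypothetical timing harness: average forward-pass time of LayerNorm
// over a 32x128x1024 input in the given dtype.
double time_layer_norm(torch::ScalarType dtype, int iters = 100) {
  torch::NoGradGuard no_grad;
  auto ln = torch::nn::LayerNorm(
      torch::nn::LayerNormOptions({1024}).eps(1e-5).elementwise_affine(true));
  ln->to(dtype);
  auto x = torch::randn({32, 128, 1024}).to(dtype);
  // Warm-up so one-time setup costs are not measured.
  for (int i = 0; i < 10; ++i) ln->forward(x);
  auto start = std::chrono::steady_clock::now();
  for (int i = 0; i < iters; ++i) ln->forward(x);
  auto end = std::chrono::steady_clock::now();
  return std::chrono::duration<double, std::milli>(end - start).count() / iters;
}

int main() {
  std::cout << "fp32: " << time_layer_norm(torch::kFloat) << " ms\n";
  std::cout << "bf16: " << time_layer_norm(torch::kBFloat16) << " ms\n";
}
```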
ezyang left a comment:
thanks
@pytorchbot merge -g

@pytorchbot successfully started a merge job. Check the current status here

Merge failed due to Matched rule superuser, but PR #81849 has not been reviewed yet

@pytorchbot merge

@pytorchbot successfully started a merge job. Check the current status here

Merge failed due to This PR is too stale; the last push date was more than 3 days ago. Please rebase and try again.

@mingfeima looks like you'll need to rebase and ping @ezyang for a land

One can rebase using the rebase command of the mergebot

@pytorchbot merge -f

❌ 🤖 pytorchbot command failed: Try

@pytorchbot merge -f "This codepath is unlikely to change recently"

@pytorchbot successfully started a merge job. Check the current status here

Hey @mingfeima.

@pytorchbot revert -c weird "Revert as caused perf regression, see pytorch/benchmark#1099"

❌ 🤖 pytorchbot command failed: Try

@pytorchbot revert -c weird -m "Revert as caused perf regression, see pytorch/benchmark#1099"

@pytorchbot successfully started a revert job. Check the current status here.

@mingfeima your PR has been successfully reverted.
…1850)" This reverts commit 2fe3ea6. Reverted #81850 on behalf of https://github.com/malfet due to Revert as caused perf regression, see pytorch/benchmark#1099
…81850)

Summary: To fix #77507

Originally `utils::RowwiseMoments<BFloat16>` would still accumulate in BFloat16, which is not only slow but also introduces additional rounding errors. This patch performs the accumulation in float for bfloat16 inputs: each bfloat16 vector (size 16) is converted to two float vectors (size 8) and accumulated into the m1 (mean) and m2 (rstd) vectors, which are all float vectors.

No effect on float performance; bfloat16 performance improves:

* avx512 single socket:
```
before: LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 0.210 ms; bf16: 0.770 ms
after:  LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 0.215 ms; bf16: 0.178 ms
```
* avx512 single core:
```
before: LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 2.661 ms; bf16: 12.267 ms
after:  LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 2.618 ms; bf16: 2.309 ms
```
* avx2 single socket:
```
before: LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 0.540 ms; bf16: 2.030 ms
after:  LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 0.527 ms; bf16: 0.458 ms
```
* avx2 single core:
```
before: LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 4.349 ms; bf16: 19.252 ms
after:  LayerNorm((1024,), eps=1e-05, elementwise_affine=True) : 32x128x1024: fp32: 4.416 ms; bf16: 3.524 ms
```

Pull Request resolved: #81850
Approved by: https://github.com/ezyang, https://github.com/malfet
Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/2fe3ea65c2b9147077ea3a3dc4757f1768483ba4
Reviewed By: seemethere
Differential Revision: D38600344
fbshipit-source-id: 63929b302c9c0adc1ec7fc2ecd3416e3cff72cb5
…1850)" Summary: This reverts commit 2fe3ea6. Reverted #81850 on behalf of https://github.com/malfet due to Revert as caused perf regression, see pytorch/benchmark#1099 Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/7e6da2fb1048392cec2eb163c8ebf98625a0d468 Reviewed By: seemethere Differential Revision: D38643463 fbshipit-source-id: bf4069be8487591a83b0b4f619e03286142a6698
Stack from ghstack:
To fix #77507
Originally `utils::RowwiseMoments<BFloat16>` would still accumulate in BFloat16, which is not only slow but also introduces additional rounding errors.
This patch performs the accumulation in float for bfloat16 inputs:
each bfloat16 vector (size 16) is converted to two float vectors (size 8),
and accumulated into the m1 (mean) and m2 (rstd) vectors, which are all float vectors.
No effect on float performance; bfloat16 performance improves: