Add BFloat16 support and optimization for mish, hardtanh backward, and silu on CPU #82460
Conversation
Dr. CI: ✅ No failures (1 pending) as of commit 4bede7f549. 💚 Looks good so far! There are no failures yet. (This comment was automatically generated by Dr. CI.)
Force-pushed from 7c70c00 to 0ee7dd6.
Force-pushed from 4bede7f to 1272d52.
Dr. CI: ✅ No failures as of commit 534c73b. See artifacts and rendered test results at hud.pytorch.org/pr/82460. Note: links to docs will display an error until the docs builds have been completed. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Force-pushed from bca91b0 to 3e2ba41.
@pytorchbot merge
Merge failed. Reason: PR #82460 has not been reviewed yet (Rule `superuser`). Details for Dev Infra team: raised by workflow job.
Hi @frank-wei, could you please review this PR? Thank you.
/easycla As part of the transition to the PyTorch Foundation, this project now requires contributions be covered under the new CLA. See #85559 for additional details. This comment will trigger a new check of this PR. If you are already covered, you will simply see a new "EasyCLA" check that passes. If you are not covered, a bot will leave a new comment with a link to sign.
@CaoE Do you mind making the PR title more specific about the changes you made? The word "activation" sounds too general.
Changed the title, and I will provide more performance numbers later.
Hi @kit1980, could you please review this PR? Thank you.
malfet left a comment:
LGTM, but see nits
Nit
Current:   `return (float(self_val) <= min_val || float(self_val) >= max_val) ? float(0) : float(grad_val);`
Suggested: `return (float(self_val) <= min_val || float(self_val) >= max_val) ? BFloat16(0) : grad_val;`
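For reference, the hardtanh backward is just a masked pass-through of the incoming gradient, which is what both variants above compute; the suggested `BFloat16(0)` only changes the type of the zero constant, not the math:

$$
\frac{\partial\,\mathrm{hardtanh}(x)}{\partial x} =
\begin{cases}
1, & \text{min\_val} < x < \text{max\_val} \\
0, & \text{otherwise}
\end{cases}
$$

so `grad_input` equals `grad_output` inside the clamp range and zero outside it.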
Nit
Current:   `const Vectorized<float> kOneVec(float(1));`
Suggested: `const Vectorized<float> kOneVec(1.0f);`
Nit
Current:   `return float(x) / (float(1) + std::exp(-float(x)));`
Suggested: `return float(x) / (1.0f + std::exp(-float(x)));`
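For reference, this scalar path computes the SiLU forward after widening the input to float:

$$
\mathrm{silu}(x) = x \cdot \sigma(x) = \frac{x}{1 + e^{-x}}
$$

where σ is the logistic sigmoid; the suggested `1.0f` literal is purely cosmetic.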
Nit
Current:   `const Vectorized<float> kOneVec(float(1));`
Suggested: `const Vectorized<float> kOneVec(1.0f);`
Current:
`float(1) / (float(1) + std::exp(-float(x)));`
`return dy * sigmoid * (float(1) + x * (float(1) - sigmoid));`
Suggested:
`1.0f / (1.0f + std::exp(-float(x)));`
`return dy * sigmoid * (1.0f + x * (1.0f - sigmoid));`
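For reference, the two lines above are the SiLU gradient written with an explicit sigmoid; differentiating x·σ(x) gives

$$
\frac{d}{dx}\bigl(x\,\sigma(x)\bigr) = \sigma(x) + x\,\sigma(x)\bigl(1 - \sigma(x)\bigr) = \sigma(x)\bigl(1 + x\,(1 - \sigma(x))\bigr)
$$

which matches `dy * sigmoid * (1.0f + x * (1.0f - sigmoid))` once scaled by the upstream gradient `dy`.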
Current:   `const Vec kOneVec(float(1));`
Suggested: `const Vec kOneVec(1.0f);`
Current:   `float(1) / (float(1) + std::exp(-float(x)));`
Suggested: `1.0f / (1.0f + std::exp(-float(x)));`
Current:   `return dy * (tanh_softplus + x * sigmoid * (float(1) - tanh_softplus * tanh_softplus));`
Suggested: `return dy * (tanh_softplus + x * sigmoid * (1.0f - tanh_softplus * tanh_softplus));`
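For reference, this line evaluates the Mish gradient. With sp(x) = softplus(x) = ln(1 + eˣ), sp′(x) = σ(x), and mish(x) = x·tanh(sp(x)), differentiation gives

$$
\frac{d}{dx}\,\mathrm{mish}(x) = \tanh\bigl(\mathrm{sp}(x)\bigr) + x\,\sigma(x)\,\bigl(1 - \tanh^2(\mathrm{sp}(x))\bigr)
$$

which is exactly `tanh_softplus + x * sigmoid * (1.0f - tanh_softplus * tanh_softplus)`, scaled by the upstream gradient `dy`.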
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Add BFloat16 support and optimization for mish, hardtanh backward, and silu on CPU (pytorch#82460)

### Description
* add BFloat16 support for mish and hardtanh backward on CPU.
* optimize the performance for silu

### Testing
- optimize the performance for silu: bfloat16

single socket (28 cores):
```
before:
1x128x1024   forward 0.090 s   backward 0.218 s
10x128x1024  forward 0.146 s   backward 0.314 s
after:
1x128x1024   forward 0.064 s   backward 0.100 s
10x128x1024  forward 0.085 s   backward 0.133 s
```

single core:
```
before:
1x128x1024   forward 0.300 s   backward 0.606 s
10x128x1024  forward 2.825 s   backward 5.834 s
after:
1x128x1024   forward 0.156 s   backward 0.239 s
10x128x1024  forward 1.447 s   backward 2.165 s
```

- Add BFloat16 support for mish and backward of hardtanh on CPU.

single socket (20 cores):

| op | shape | fp32 forward / s | fp32 backward / s | bf16 forward / s | bf16 backward / s |
| -- | -- | -- | -- | -- | -- |
| silu | [10, 128, 10, 10] | 4.41E-05 | 7.67E-05 | 5.32E-05 | 9.38E-05 |
| silu | [10, 128, 80, 80] | 0.0008 | 0.001788 | 0.00067 | 0.001031 |
| mish | [10, 128, 10, 10] | 0.000356 | 0.000427 | 0.000367 | 0.000436 |
| mish | [10, 128, 80, 80] | 0.004527 | 0.005807 | 0.004757 | 0.005393 |
| hardtanh | [10, 128, 10, 10] | / | 3.97E-05 | / | 4.45E-05 |
| hardtanh | [10, 128, 80, 80] | / | 0.001748 | / | 0.000645 |

single core:

| op | shape | fp32 forward / s | fp32 backward / s | bf16 forward / s | bf16 backward / s |
| -- | -- | -- | -- | -- | -- |
| silu | [10, 128, 10, 10] | 1.17E-04 | 1.91E-04 | 1.35E-04 | 2.23E-04 |
| silu | [10, 128, 80, 80] | 0.007434 | 0.013141 | 0.008464 | 0.013044 |
| mish | [10, 128, 10, 10] | 0.00103 | 0.00122 | 0.00106 | 0.001227 |
| mish | [10, 128, 80, 80] | 0.065629 | 0.078418 | 0.067779 | 0.077214 |
| hardtanh | [10, 128, 10, 10] | / | 1.18E-04 | / | 9.30E-05 |
| hardtanh | [10, 128, 80, 80] | / | 0.010773 | / | 0.005834 |

Pull Request resolved: pytorch#82460
Approved by: https://github.com/mingfeima, https://github.com/malfet
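A minimal sketch of how timings like those above could be reproduced from Python; the iteration count, thread settings, and benchmarking loop here are assumptions rather than the exact harness used for this PR:

```python
import torch
import torch.nn.functional as F
from time import perf_counter

def bench(fn, x, iters=1000):
    # Warm up, then accumulate forward and backward wall time over `iters` runs.
    for _ in range(10):
        fn(x).sum().backward()
        x.grad = None
    fwd = bwd = 0.0
    for _ in range(iters):
        t0 = perf_counter()
        y = fn(x)
        fwd += perf_counter() - t0
        loss = y.sum()
        t0 = perf_counter()
        loss.backward()
        bwd += perf_counter() - t0
        x.grad = None
    return fwd, bwd

# bfloat16 CPU input, using one of the shapes from the tables above
x = torch.randn(10, 128, 80, 80, dtype=torch.bfloat16, requires_grad=True)
for name, fn in [("silu", F.silu), ("mish", F.mish), ("hardtanh", F.hardtanh)]:
    fwd, bwd = bench(fn, x)
    print(f"{name}: forward {fwd:.6f} s  backward {bwd:.6f} s")
```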
cc @VitalyFedyunin @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10