[PyTorch] Existing MHA: fuse the attn_mask addition #73219
Saw a report that this elementwise add is causing overhead. IIUC this is easy to fuse?

Differential Revision: [D34160547](https://our.internmc.facebook.com/intern/diff/D34160547/)
jbschlosser left a comment
LGTM!
Summary:
Pull Request resolved: #73219

Saw a report that this elementwise add is causing overhead. IIUC this is easy to fuse?

ghstack-source-id: 152549975

Test Plan: CI, review

Ran benchmark_transformers.par mha --batch-size 64 --max-sequence-length 128 --avg-sequence-length 256 --large --use-real-data-distribution --use-mask and looked at the PT time number

```
before:
B=64, T=128, Half=True, GPU=True, Seed=1234, Padded tokens=54.92%, Use Mask=True
PT Time: 1.24ms, NativePT Time: 1000000000.00ms, HF Time: 1.10ms, PT FLOPS: 59.07TFLOP/s, NativePT FLOPS: 0.00TFLOP/s, HF FLOPS: 66.46TFLOP/s
B=64, T=128, Half=True, GPU=True, Seed=1234, Padded tokens=54.92%, Use Mask=True
PT Time: 1.23ms, NativePT Time: 1000000000.00ms, HF Time: 1.09ms, PT FLOPS: 59.57TFLOP/s, NativePT FLOPS: 0.00TFLOP/s, HF FLOPS: 66.75TFLOP/s
B=64, T=128, Half=True, GPU=True, Seed=1234, Padded tokens=54.92%, Use Mask=True
PT Time: 1.24ms, NativePT Time: 1000000000.00ms, HF Time: 1.09ms, PT FLOPS: 58.87TFLOP/s, NativePT FLOPS: 0.00TFLOP/s, HF FLOPS: 66.77TFLOP/s

after:
B=64, T=128, Half=True, GPU=True, Seed=1234, Padded tokens=54.92%, Use Mask=True
PT Time: 1.22ms, NativePT Time: 1000000000.00ms, HF Time: 1.10ms, PT FLOPS: 60.07TFLOP/s, NativePT FLOPS: 0.00TFLOP/s, HF FLOPS: 66.51TFLOP/s
B=64, T=128, Half=True, GPU=True, Seed=1234, Padded tokens=54.92%, Use Mask=True
PT Time: 1.22ms, NativePT Time: 1000000000.00ms, HF Time: 1.09ms, PT FLOPS: 59.80TFLOP/s, NativePT FLOPS: 0.00TFLOP/s, HF FLOPS: 66.69TFLOP/s
B=64, T=128, Half=True, GPU=True, Seed=1234, Padded tokens=54.92%, Use Mask=True
PT Time: 1.21ms, NativePT Time: 1000000000.00ms, HF Time: 1.09ms, PT FLOPS: 60.21TFLOP/s, NativePT FLOPS: 0.00TFLOP/s, HF FLOPS: 66.86TFLOP/s
```

Inspected a Kineto trace and confirmed that an elementwise add was fused into baddbmm.

Additional opportunity: I see a copy_ inside baddbmm that wasn't happening with the bmm path and I'm not sure why. Perhaps something went wrong with the structured kernels port by ezyang?
Reviewed By: ezyang Differential Revision: D34160547 fbshipit-source-id: 78d406fb035e6f3bf13af2c9443a886eada35ac4
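For readers who have not seen the pattern, here is a minimal sketch of what fusing the attn_mask addition into baddbmm can look like. It is not the actual diff: the function names `scores_unfused` and `scores_fused`, the tensor sizes, and the finite mask value are all assumptions chosen for illustration. The only library calls used are `torch.bmm` and `torch.baddbmm`, where `torch.baddbmm(input, batch1, batch2)` computes `input + batch1 @ batch2` with the default `beta=1` and `alpha=1`.

```python
import torch


def scores_unfused(q, k, attn_mask):
    # Batched matmul followed by a separate elementwise add of the mask.
    return torch.bmm(q, k.transpose(-2, -1)) + attn_mask


def scores_fused(q, k, attn_mask):
    # Single call: attn_mask + q @ k^T (beta=1, alpha=1 by default).
    return torch.baddbmm(attn_mask, q, k.transpose(-2, -1))


torch.manual_seed(0)
B, T, D = 4, 8, 16  # hypothetical batch, sequence length, head dim
q = torch.randn(B, T, D)
k = torch.randn(B, T, D)

# Additive float mask: 0 where attention is allowed, a large negative
# value where it is blocked (a finite stand-in for -inf, an assumption
# made here to keep the comparison simple).
attn_mask = torch.zeros(B, T, T)
attn_mask[:, :, T // 2:] = -1e4

out_unfused = scores_unfused(q, k, attn_mask)
out_fused = scores_fused(q, k, attn_mask)
print(torch.allclose(out_unfused, out_fused, atol=1e-4))
```

Both paths should produce the same attention scores up to floating-point rounding; whether the fused call is faster depends on the backend and shapes, which is what the benchmark above measures.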