Add scaled_grouped_mm_v2 and python API #165154
slayton58 wants to merge 9 commits into gh/slayton58/27/base from
Conversation
Summary: * Add `torch._scaled_grouped_mm_v2` with more functionality and extensibility for future formats * Add `torch.nn.functional.scaled_grouped_mm` as public entrypoint * Test both original and v2 functionality Test Plan: ``` pytest -svv -k grouped test/test_scaled_matmul_cuda.py ``` Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Simon Layton <simonlayton@meta.com> [ghstack-poisoned]
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/165154
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (2 Unrelated Failures)
As of commit 319da7a with merge base 3a110c9:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Summary: * Add `torch._scaled_grouped_mm_v2` with more functionality and extensibility for future formats * Add `torch.nn.functional.scaled_grouped_mm` as public entrypoint * Test both original and v2 functionality Test Plan: ``` pytest -svv -k grouped test/test_scaled_matmul_cuda.py ``` Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Simon Layton <simonlayton@meta.com> ghstack-source-id: 7c9c6b3 Pull Request resolved: #165154
@pytorchbot label "release notes: quantization"
Attention! native_functions.yaml was changed
If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs: one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.
Caused by:
Summary: * Add `torch._scaled_grouped_mm_v2` with more functionality and extensibility for future formats * Add `torch.nn.functional.scaled_grouped_mm` as public entrypoint * Test both original and v2 functionality Test Plan: ``` pytest -svv -k grouped test/test_scaled_matmul_cuda.py ``` Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Simon Layton <simonlayton@meta.com> ghstack-source-id: 92775cb Pull Request resolved: #165154 Signed-off-by: Simon Layton <simonlayton@meta.com>
Summary: * Add `torch._scaled_grouped_mm_v2` with more functionality and extensibility for future formats * Add `torch.nn.functional.scaled_grouped_mm` as public entrypoint * Test both original and v2 functionality Test Plan: ``` pytest -svv -k grouped test/test_scaled_matmul_cuda.py ``` Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Simon Layton <simonlayton@meta.com> ghstack-source-id: 70b4b2d Pull Request resolved: #165154 Signed-off-by: Simon Layton <simonlayton@meta.com>
aten/src/ATen/native/cuda/Blas.cpp
Outdated
```cpp
// NOTE(slayton): For sub-1B formats want contraction_dim argument?
if (!a_is_2d || !b_is_2d) {
  TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
```
@danielvegamyhre In real-world use-cases, do you end up needing to pass any transposed inputs? If so, I think we want a contraction_dim argument, as .t() becomes non-"free" for sub-1B formats (like e2m1x2)
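To make the packing point concrete, here is a minimal sketch using only standard tensor ops (editorial illustration, not code from this PR; shapes are arbitrary). For byte-or-wider dtypes, `.t()` is just a stride swap, but with two fp4 values packed per byte along the last dim, a logical transpose forces the nibbles to be unpacked and re-interleaved:

```python
import torch

# A logical (M, K) e2m1 (fp4) tensor stored packed as (M, K // 2) bytes,
# two fp4 codes per uint8, packed along the last (contraction) dim.
M, K = 4, 8
packed = torch.randint(0, 256, (M, K // 2), dtype=torch.uint8)

# Transposing to logical (K, M) means packing along M instead, so the
# nibbles must be unpacked and repacked -- real data movement, which is
# why an explicit contraction_dim argument is attractive here.
lo = packed & 0x0F                                        # even-position fp4 codes
hi = packed >> 4                                          # odd-position fp4 codes
unpacked = torch.stack((lo, hi), dim=-1).reshape(M, K)    # (M, K) nibble codes
t = unpacked.t().contiguous()                             # (K, M)
repacked = t[:, 0::2] | (t[:, 1::2] << 4)                 # (K, M // 2) bytes
assert repacked.shape == (K, M // 2)
```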
(sorry missed this comment somehow) - discussed this offline but for posterity's sake:
- fbgemm's mx8mx8bf16 grouped mm API requires the B tensor be non-transposed (e.g., E,N,K)
- torch._scaled_grouped_mm requires the B tensor be pre-transposed (e.g., E,K,N), so before dispatching to fbgemm we do a B.transpose(-2, -1)
> If so, I think we want a contraction_dim argument, as .t() becomes non-"free" for sub-1B formats (like e2m1x2)
Could we just update fbgemm to have a consistent API with torch._scaled_grouped_mm, accepting B as pre-transposed? We'd need to take a look at how this affects perf, but I believe it should theoretically be fine.
In this case that'd do it, but in general we do want to be able to support the full matrix of potential contractions (just like regular gemm does with its T/NT modes), so both cases should ideally work.
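For reference, a small sketch of the layout mismatch discussed in this thread (editorial, with illustrative shapes): an fbgemm-style (E, N, K) B tensor versus the (E, K, N) layout torch._scaled_grouped_mm expects. For byte-or-wider dtypes the adapter transpose is a zero-copy view, which is exactly what stops working for packed fp4:

```python
import torch

# fbgemm-style layout: one (N, K) weight per expert E.
E, N, K = 8, 256, 128
B = torch.randn(E, N, K, dtype=torch.bfloat16)

# Adapter used before dispatching: stride swap only, no copy.
B_t = B.transpose(-2, -1)  # (E, K, N)
assert B_t.shape == (E, K, N) and B_t.data_ptr() == B.data_ptr()
# For e2m1x2 (two fp4 values per byte) this view trick is no longer free,
# hence the argument for an explicit contraction_dim instead.
```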
aten/src/ATen/native/cuda/Blas.cpp
Outdated
```cpp
// MXFP8 expects float8_e8m0fnu scales.
TORCH_CHECK_VALUE(scale_a[0].scalar_type() == at::kFloat8_e8m0fnu && scale_b[0].scalar_type() == at::kFloat8_e8m0fnu,
    "For MXFP8 grouped gemm, both scales must be float8_e8m0fnu tensors.");
TORCH_CHECK_VALUE(swizzle_a_enum[0] == SwizzleType::SWIZZLE_32_4_4 && swizzle_b_enum[0] == SwizzleType::SWIZZLE_32_4_4,
```
Note to self: ROCm doesn't need swizzle, AFAIK
cc @jeffdaily @petrex, can you confirm? I read through https://rocm.blogs.amd.com/software-tools-optimization/matrix-cores-cdna/README.html but still wasn't 100% sure.
My reading is the same as yours @drisspg, and I'll make the appropriate change in the code, but a confirmation would be very welcome!
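For readers following along, a hedged Python-side mirror of what the hunk above enforces (assumes a PyTorch recent enough to have torch.float8_e8m0fnu; the helper name is hypothetical, and the swizzle requirement is the CUDA-specific part in question for ROCm):

```python
import torch

# MXFP8 carries one e8m0 (power-of-two) scale per 32-element block; the C++
# check above additionally requires a swizzled scale layout on CUDA.
def check_mxfp8_scales(scale_a: torch.Tensor, scale_b: torch.Tensor) -> None:
    if scale_a.dtype != torch.float8_e8m0fnu or scale_b.dtype != torch.float8_e8m0fnu:
        raise ValueError("For MXFP8 grouped gemm, both scales must be float8_e8m0fnu tensors.")

# Illustrative shape: (M, K // 32) block scales, cast to the e8m0 dtype.
scales = torch.full((128, 8), 1.0).to(torch.float8_e8m0fnu)
check_mxfp8_scales(scales, scales)
```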
Summary: * Add `torch._scaled_grouped_mm_v2` with more functionality and extensibility for future formats * Add `torch.nn.functional.scaled_grouped_mm` as public entrypoint * Test both original and v2 functionality Test Plan: ``` pytest -svv -k grouped test/test_scaled_matmul_cuda.py ``` Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Simon Layton <simonlayton@meta.com> ghstack-source-id: 1e34dca Pull Request resolved: #165154 Signed-off-by: Simon Layton <simonlayton@meta.com>
Summary: * Add `torch._scaled_grouped_mm_v2` with more functionality and extensibility for future formats * Add `torch.nn.functional.scaled_grouped_mm` as public entrypoint * Test both original and v2 functionality Test Plan: ``` pytest -svv -k grouped test/test_scaled_matmul_cuda.py ``` Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Simon Layton <simonlayton@meta.com> ghstack-source-id: 458ff5e Pull Request resolved: #165154 Signed-off-by: Simon Layton <simonlayton@meta.com>
Summary: * Add `torch._scaled_grouped_mm_v2` with more functionality and extensibility for future formats * Add `torch.nn.functional.scaled_grouped_mm` as public entrypoint * Test both original and v2 functionality Test Plan: ``` pytest -svv -k grouped test/test_scaled_matmul_cuda.py ``` Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Simon Layton <simonlayton@meta.com> ghstack-source-id: 498808e Pull Request resolved: #165154 Signed-off-by: Simon Layton <simonlayton@meta.com>
Summary: * Add `torch._scaled_grouped_mm_v2` with more functionality and extensibility for future formats * Add `torch.nn.functional.scaled_grouped_mm` as public entrypoint * Test both original and v2 functionality Test Plan: ``` pytest -svv -k grouped test/test_scaled_matmul_cuda.py ``` Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Simon Layton <simonlayton@meta.com> ghstack-source-id: 4e65740 Pull Request resolved: #165154 Signed-off-by: Simon Layton <simonlayton@meta.com>
Summary: * Add `torch._scaled_grouped_mm_v2` with more functionality and extensibility for future formats * Add `torch.nn.functional.scaled_grouped_mm` as public entrypoint * Test both original and v2 functionality Test Plan: ``` pytest -svv -k grouped test/test_scaled_matmul_cuda.py ``` Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Simon Layton <simonlayton@meta.com> ghstack-source-id: c39fb9a Pull Request resolved: #165154 Signed-off-by: Simon Layton <simonlayton@meta.com>
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here
Summary: * Add `torch._scaled_grouped_mm_v2` with more functionality and extensibility for future formats * Add `torch.nn.functional.scaled_grouped_mm` as public entrypoint * Test both original and v2 functionality Test Plan: ``` pytest -svv -k grouped test/test_scaled_matmul_cuda.py ``` Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Simon Layton <simonlayton@meta.com> ghstack-source-id: 554cdd1 Pull Request resolved: #165154 Signed-off-by: Simon Layton <simonlayton@meta.com>
Successfully rebased
@pytorchbot merge

Merge started
Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
Summary: * Add `torch._scaled_grouped_mm_v2` with more functionality and extensibility for future formats * Add `torch.nn.functional.scaled_grouped_mm` as public entrypoint * Test both original and v2 functionality Test Plan: ``` pytest -svv -k grouped test/test_scaled_matmul_cuda.py ``` Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Simon Layton <simonlayton@meta.com> Pull Request resolved: pytorch#165154 Approved by: https://github.com/drisspg, https://github.com/danielvegamyhre
Stack from ghstack (oldest at bottom):
Summary:

* Add `torch._scaled_grouped_mm_v2` with more functionality and extensibility for future formats
* Add `torch.nn.functional.scaled_grouped_mm` as public entrypoint
* Test both original and v2 functionality

(A pure-torch reference for what a grouped mm computes is sketched at the end of this description.)

Test Plan:

```
pytest -svv -k grouped test/test_scaled_matmul_cuda.py
```
Reviewers:
Subscribers:
Tasks:
Tags:
Signed-off-by: Simon Layton simonlayton@meta.com
cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki
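For readers new to grouped gemm, a hedged pure-torch reference of what the 2D-3D grouped mm computes (editorial sketch with illustrative shapes; the real kernels fuse this loop into one launch and apply the fp8 scales, which are omitted here). The cumulative offsets play the role of the `offs` argument of the scaled grouped ops, marking variable-sized group boundaries along M:

```python
import torch

def grouped_mm_ref(a: torch.Tensor, b: torch.Tensor, offs: torch.Tensor) -> torch.Tensor:
    # a: (M_total, K); b: (G, K, N); offs: (G,) cumulative row offsets along M.
    out = torch.empty(a.shape[0], b.shape[-1], dtype=a.dtype)
    start = 0
    for g, end in enumerate(offs.tolist()):
        out[start:end] = a[start:end] @ b[g]  # one gemm per group
        start = end
    return out

a = torch.randn(96, 64)
b = torch.randn(3, 64, 32)
offs = torch.tensor([16, 48, 96])  # three variable-sized groups
out = grouped_mm_ref(a, b, offs)
assert out.shape == (96, 32)
```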