
Add scaled_grouped_mm_v2 and python API#165154

Closed
slayton58 wants to merge 9 commits into gh/slayton58/27/base from gh/slayton58/27/head

Conversation

@slayton58
Contributor

@slayton58 slayton58 commented Oct 10, 2025

Stack from ghstack (oldest at bottom):


cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki

Summary:

* Add `torch._scaled_grouped_mm_v2` with more functionality and
  extensibility for future formats
* Add `torch.nn.functional.scaled_grouped_mm` as public entrypoint
* Test both original and v2 functionality

Test Plan:

```
pytest -svv -k grouped test/test_scaled_matmul_cuda.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>
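For readers landing here from the release notes, the semantics of a scaled, grouped matmul can be sketched in plain Python. This is illustrative only: it uses tensor-wise scales and nested lists instead of tensors, while the real op also supports row-wise and MX block-wise scales and runs fused kernels.

```python
# Pure-Python reference for a scaled, grouped matmul: one independent GEMM
# per group, with (here) a single scale per operand, folded in after the
# accumulation -- i.e. out[g] = (sa[g] * sb[g]) * (A[g] @ B[g]).
def scaled_grouped_mm_ref(groups_a, groups_b, scales_a, scales_b):
    results = []
    for a, b, sa, sb in zip(groups_a, groups_b, scales_a, scales_b):
        rows, inner, cols = len(a), len(b), len(b[0])
        out = [[sa * sb * sum(a[i][t] * b[t][j] for t in range(inner))
                for j in range(cols)] for i in range(rows)]
        results.append(out)
    return results

# Two groups of 2x2 matrices:
res = scaled_grouped_mm_ref(
    [[[1, 2], [3, 4]], [[1, 0], [0, 1]]],   # A groups
    [[[1, 0], [0, 1]], [[5, 6], [7, 8]]],   # B groups
    [2.0, 1.0],                             # per-group scales for A
    [1.0, 1.0],                             # per-group scales for B
)
assert res[0] == [[2.0, 4.0], [6.0, 8.0]]   # identity B, scale 2
assert res[1] == [[5.0, 6.0], [7.0, 8.0]]   # identity A
```

The grouping is what makes this useful for MoE-style workloads: each expert gets its own GEMM and its own quantization scales, without padding all groups to one shape.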

@pytorch-bot

pytorch-bot bot commented Oct 10, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/165154

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit 319da7a with merge base 3a110c9:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

slayton58 added a commit that referenced this pull request Oct 10, 2025
Summary:

* Add `torch._scaled_grouped_mm_v2` with more functionality and
  extensibility for future formats
* Add `torch.nn.functional.scaled_grouped_mm` as public entrypoint
* Test both original and v2 functionality

Test Plan:

```
pytest -svv -k grouped test/test_scaled_matmul_cuda.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
Signed-off-by: Simon Layton <simonlayton@meta.com>

ghstack-source-id: 7c9c6b3
Pull Request resolved: #165154
@slayton58
Contributor Author

@pytorchbot label "release notes: quantization"

@pytorch-bot pytorch-bot bot added the release notes: quantization release notes category label Oct 10, 2025
@github-actions
Contributor

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.


slayton58 added a commit that referenced this pull request Oct 10, 2025

ghstack-source-id: 92775cb
Pull Request resolved: #165154
slayton58 added a commit that referenced this pull request Oct 10, 2025

ghstack-source-id: 92775cb
Pull Request resolved: #165154
slayton58 added a commit that referenced this pull request Oct 10, 2025

ghstack-source-id: 70b4b2d
Pull Request resolved: #165154
@drisspg drisspg added module: nn Related to torch.nn topic: new features topic category labels Oct 10, 2025

// NOTE(slayton): For sub-1B formats want contraction_dim argument?
if (!a_is_2d || !b_is_2d) {
TORCH_CHECK_VALUE(mat_a.size(-1) == mat_b.size(-2), "contraction dimension of mat_a and mat_b must match");
Contributor Author
@danielvegamyhre In real-world use-cases, do you end up needing to pass any transposed inputs? If so, I think we want a contraction_dim argument, as .t() becomes non-"free" for sub-1B formats (like e2m1x2)

Contributor
(sorry missed this comment somehow) - discussed this offline but for posterity's sake:

  • fbgemm mx8mx8bf16 grouped mm API requires the B tensor be non-transposed (e.g., E,N,K)
  • torch._scaled_grouped_mm requires the B tensor be pre-transposed (e.g., E,K,N) so before dispatching to fbgemm we do a B.transpose(-2, -1)

> If so, I think we want a contraction_dim argument, as .t() becomes non-"free" for sub-1B formats (like e2m1x2)

Could we just update fbgemm to have a consistent API with torch._scaled_grouped_mm, accepting B as pre-transposed? Would need to take a look at how this affects perf, but I believe it should theoretically be fine.

Contributor Author
In this case that'd do it, but in general we do want to be able to support the full matrix of potential contractions (just like regular gemm does with its T/NT modes), so both cases should ideally work.
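The transpose question above can be made concrete with a tiny sketch (plain nested lists so it runs standalone; in the real code path this is just `B.transpose(-2, -1)` on a tensor, which is a free view for 1-byte-and-larger dtypes but implies an actual data shuffle for packed sub-byte formats like e2m1x2):

```python
def transpose_last_two(b):
    # (E, N, K) -> (E, K, N): fbgemm's mx8mx8bf16 grouped mm takes B as
    # (E, N, K), while torch._scaled_grouped_mm expects it pre-transposed
    # as (E, K, N), hence the transpose inserted before dispatch.
    return [[list(col) for col in zip(*mat)] for mat in b]

E, N, K = 2, 3, 4
b_fbgemm = [[[0.0] * K for _ in range(N)] for _ in range(E)]  # (E, N, K)
b_pre_t = transpose_last_two(b_fbgemm)                        # (E, K, N)
assert len(b_pre_t) == E
assert len(b_pre_t[0]) == K and len(b_pre_t[0][0]) == N
```

A `contraction_dim` argument would let the kernel consume either layout directly, avoiding the shuffle entirely for packed formats.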

// MXFP8 expects float8_e8m0fnu scales.
TORCH_CHECK_VALUE(scale_a[0].scalar_type() == at::kFloat8_e8m0fnu && scale_b[0].scalar_type() == at::kFloat8_e8m0fnu,
"For MXFP8 grouped gemm, both scales must be float8_e8m0fnu tensors.");
TORCH_CHECK_VALUE(swizzle_a_enum[0] == SwizzleType::SWIZZLE_32_4_4 && swizzle_b_enum[0] == SwizzleType::SWIZZLE_32_4_4,
Contributor Author
Note to self: ROCM doesn't need swizzle afaik
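For context on the check above: float8_e8m0fnu (the OCP MX E8M0 scale encoding) carries only an 8-bit biased exponent, so every representable scale is a power of two; there is no sign, zero, or mantissa, and 0xFF is the single NaN. A minimal stdlib decoder, as a sketch of that encoding:

```python
def e8m0_to_float(b: int) -> float:
    # E8M0: 8 exponent bits, bias 127, no sign or mantissa bits.
    # value = 2**(b - 127); the single NaN encoding is 0xFF.
    assert 0 <= b <= 255
    if b == 255:
        return float("nan")
    return 2.0 ** (b - 127)

assert e8m0_to_float(127) == 1.0    # 2**0
assert e8m0_to_float(128) == 2.0    # 2**1
assert e8m0_to_float(0) == 2.0 ** -127
```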

Contributor
@drisspg drisspg Oct 13, 2025
Contributor Author
My reading is the same as yours @drisspg, and I'll make the appropriate change in the code, but a confirmation would be very welcome!

slayton58 added a commit that referenced this pull request Oct 13, 2025

ghstack-source-id: 1e34dca
Pull Request resolved: #165154
@albanD albanD removed their request for review October 13, 2025 15:24
slayton58 added a commit that referenced this pull request Oct 13, 2025

ghstack-source-id: 458ff5e
Pull Request resolved: #165154
@drisspg drisspg added the ciflow/rocm Trigger "default" config CI on ROCm label Oct 13, 2025
slayton58 added a commit that referenced this pull request Oct 13, 2025

ghstack-source-id: 498808e
Pull Request resolved: #165154
slayton58 added a commit that referenced this pull request Oct 13, 2025

ghstack-source-id: 4e65740
Pull Request resolved: #165154
slayton58 added a commit that referenced this pull request Oct 13, 2025

ghstack-source-id: 4e65740
Pull Request resolved: #165154
slayton58 added a commit that referenced this pull request Oct 14, 2025

ghstack-source-id: c39fb9a
Pull Request resolved: #165154
@slayton58
Contributor Author

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

pytorchmergebot pushed a commit that referenced this pull request Oct 14, 2025

ghstack-source-id: 554cdd1
Pull Request resolved: #165154
@pytorchmergebot
Collaborator

Successfully rebased gh/slayton58/27/orig onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via ghstack checkout https://github.com/pytorch/pytorch/pull/165154)

@slayton58
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 15, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025

Pull Request resolved: pytorch#165154
Approved by: https://github.com/drisspg, https://github.com/danielvegamyhre
@github-actions github-actions bot deleted the gh/slayton58/27/head branch November 15, 2025 02:15

Labels

ciflow/b200, ciflow/h100, ciflow/rocm (Trigger "default" config CI on ROCm), ciflow/trunk (Trigger trunk jobs on your pull request), Merged, module: nn (Related to torch.nn), release notes: quantization (release notes category), topic: new features (topic category)
