
[xpu][feature] [3/3] Register the scaled_mm and scaled_mm_v2 for xpu #166056

Merged
Stonepia wants to merge 20 commits into pytorch:main from Stonepia:xpu/register_scaled_mm_xpu

Conversation

@Stonepia
Contributor

@Stonepia Stonepia commented Oct 22, 2025

This PR registers the scaled_mm op for XPU support.

It does the following:

  1. Registers the _scaled_mm and _scaled_mm_v2 ops for XPU.
  2. Enables the XPU tests in test_scaled_matmul_cuda.py.
  3. Updates the torch-xpu-ops pin to remove the fallback of scaled_mm to the CPU implementation.

PR Stack:

- #165978 : implementation of XPU scaled_mm and oneDNN kernel
- #167518 : implementation of XPU scaled_mm_v2
- -> #166056 : op registration (this PR)

Task tracker:

We will track all the scaled_mm-related tasks in: #167170
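
For a quick sense of what this registration enables, here is a minimal usage sketch. It is illustrative only: it assumes an FP8-capable XPU device and uses tensorwise scales.

import torch

# Illustrative sketch, assuming an FP8-capable XPU device.
device = "xpu"
M, K, N = 32, 64, 16
x = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
# _scaled_mm expects the second operand in column-major layout.
y = torch.randn(N, K, device=device).to(torch.float8_e4m3fn).t()
scale_a = torch.tensor(1.0, device=device)  # tensorwise scale for x
scale_b = torch.tensor(1.0, device=device)  # tensorwise scale for y
out = torch._scaled_mm(x, y, scale_a, scale_b, out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([32, 16])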

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo @gujinghui @fengyuan14 @guangyey @chenyang78

@pytorch-bot

pytorch-bot bot commented Oct 22, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/166056

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit bf9795f with merge base a7dc6da:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@Stonepia
Contributor Author

@pytorchbot label "module: xpu"

@pytorch-bot pytorch-bot bot added the module: xpu (Intel XPU related issues) label Oct 22, 2025
@github-actions
Contributor

Attention! native_functions.yaml was changed

If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info.



@github-actions
Contributor

Attention! One of PyTorch's C-stable API files was changed

You MUST NOT change existing function declarations in this file, as this header defines a stable C ABI. If you need to change the signature of a function, introduce a new v2 version of the function and modify code generation to target the new version.



@Stonepia
Contributor Author

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased xpu/register_scaled_mm_xpu onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout xpu/register_scaled_mm_xpu && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the xpu/register_scaled_mm_xpu branch from 817358f to 26189e3 on October 27, 2025 at 08:32
@Stonepia
Contributor Author

Stonepia commented Nov 6, 2025

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Rebase failed: the command git -C /home/runner/work/pytorch/pytorch rebase refs/remotes/origin/viable/strict pull/166056/head returned non-zero exit code 1

Rebasing (1/3)
Rebasing (2/3)
Auto-merging test/test_scaled_matmul_cuda.py
CONFLICT (content): Merge conflict in test/test_scaled_matmul_cuda.py
error: could not apply 89eeddd8e63... Enable `scaled_mm` tests for xpu
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Could not apply 89eeddd8e63... Enable `scaled_mm` tests for xpu

Raised by https://github.com/pytorch/pytorch/actions/runs/19121565673

@Stonepia Stonepia force-pushed the xpu/register_scaled_mm_xpu branch from 9b9f12f to e940420 on November 6, 2025 at 01:26
@Stonepia Stonepia marked this pull request as ready for review November 6, 2025 01:51
@Stonepia Stonepia changed the title from "[XPU] [2/2] Register the scaled_mm for xpu" to "[XPU] [Feature] [2/2] Register the scaled_mm for xpu" on Nov 6, 2025
@slayton58
Contributor

This functionality is great!

The torch._scaled_mm API is somewhat deprecated at this point - it would be much much better to tie into torch._scaled_mm_v2 and the python front-end for it, torch.nn.functional.scaled_mm -- this is the public face of scaled gemms, and where work is expected to be in the future.

It's not a big difference in API, the main change is that scaling types are explicitly passed to the API, rather than inferred from the input & scale shapes. I'd be happy to talk through the necessary differences if you need.

@Stonepia
Contributor Author

Stonepia commented Nov 7, 2025

> [quoting @slayton58's comment above]

Thanks for the suggestion! I will refactor the code to support the v2 version. Originally I thought v2 was not yet stable enough, so tracking its code changes would have required ongoing syncing effort.
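
For reference, a minimal sketch of the v1 behavior being discussed, where the scaling type is inferred from the scale shapes rather than passed explicitly as in v2. Illustrative only; assumes an FP8-capable device (sm89+ for the rowwise case on CUDA).

import torch

device = "cuda"  # or "xpu" once this PR lands
M, K, N = 32, 64, 16
x = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
y = torch.randn(N, K, device=device).to(torch.float8_e4m3fn).t()

# Tensorwise: scalar scales, one per tensor.
out_tw = torch._scaled_mm(
    x, y,
    torch.tensor(1.0, device=device),
    torch.tensor(1.0, device=device),
    out_dtype=torch.bfloat16,
)

# Rowwise: (M, 1) and (1, N) scales. v1 infers "rowwise" purely from
# these shapes; v2 instead takes the scaling type as an explicit argument.
out_rw = torch._scaled_mm(
    x, y,
    torch.ones(M, 1, device=device),
    torch.ones(1, N, device=device),
    out_dtype=torch.bfloat16,
)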

@janeyx99 janeyx99 added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Nov 7, 2025
out_fp8.to(torch.float32), torch.full((M, N), K * (fill_value**2), device=device)
)

@skipXPU
Collaborator


Since this UT has been decorated with onlyCUDA, is skipXPU necessary?

Copy link
Contributor Author


Thanks for the suggestion! Removed.

lambda: scaled_mm_wrap(x, y, scale_a, scale_b, out_dtype=torch.float32),
)

@skipXPU
Collaborator


SM100OrLater should have covered skipXPU, right?

Contributor Author


Yes, I removed those skipXPU decorators to keep the code changes smaller.

Comment on lines 1148 to 1154
if not _device_supports_scaled_mm_fp8(device) or (not torch.xpu.is_available() and IS_WINDOWS):
raise unittest.SkipTest(f8_msg)
if not torch.xpu.is_available() and not SM89OrLater:
raise unittest.SkipTest("rowwise implementation is currently sm89-sm100 specific")
if torch.xpu.is_available() and use_fast_accum:
raise unittest.SkipTest("XPU does not support fast accum yet")

Collaborator


These skip checks could be replaced by decorators:

@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@skipCUDAIf(IS_WINDOWS, f8_msg)
@skipCUDAIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific")
@skipXPUIf(use_fast_accum, "XPU does not support fast accum yet")

Comment on lines 1273 to 1277
if not _device_supports_scaled_mm_fp8(device) or (not torch.xpu.is_available() and IS_WINDOWS):
raise unittest.SkipTest(f8_msg)
if not torch.xpu.is_available() and not SM89OrLater:
raise unittest.SkipTest("rowwise implementation is currently sm89-sm100 specific")

Collaborator


Ditto. Please refine the test a little bit.

output_dtype
)

@skipXPU
Collaborator


Suggested change: remove @skipXPU

Comment on lines 1692 to 1694
def test_zero_dim_tensorwise(self, which_dim_zero, use_torch_compile, device) -> None:
if not _device_supports_scaled_mm_fp8(device):
raise unittest.SkipTest(f8_msg)
Collaborator


Are the code changes due to PLATFORM_SUPPORTS_FP8 not supporting XPU?

Contributor Author


Yes, PLATFORM_SUPPORTS_FP8 only covers CUDA. So I wrapped this check:

if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8:

into a function, so that the change only affects this file:

def _device_supports_scaled_mm_fp8(device):
    # CPU and XPU are treated as supported here; for CUDA devices, defer
    # to PLATFORM_SUPPORTS_FP8 (which only knows about CUDA).
    if device not in ['cpu', 'xpu'] and (torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8):
        return False
    return True

Contributor Author


Because of this change, new tests needed to be added: 4b5b0d0

These tests mainly cover scaled_mm, but also other FP8 ops.
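
For a sense of the shape of these tests, a hedged sketch in the style of this file; the test name and body are illustrative, not the literal contents of 4b5b0d0.

def test_scaled_mm_tensorwise_basic(self, device):
    # Hypothetical test; mirrors the file's existing skip + assert pattern.
    if not _device_supports_scaled_mm_fp8(device):
        raise unittest.SkipTest(f8_msg)
    x = torch.ones(16, 32, device=device).to(torch.float8_e4m3fn)
    y = torch.ones(8, 32, device=device).to(torch.float8_e4m3fn).t()
    scale = torch.tensor(1.0, device=device)
    out = torch._scaled_mm(x, y, scale, scale, out_dtype=torch.float32)
    # Every output element is a dot product of 32 ones: expect 32.0.
    torch.testing.assert_close(out, torch.full((16, 8), 32.0, device=device))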

@Stonepia
Contributor Author

@pytorchbot rebase

@Stonepia Stonepia marked this pull request as draft November 11, 2025 05:15
@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased xpu/register_scaled_mm_xpu onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout xpu/register_scaled_mm_xpu && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the xpu/register_scaled_mm_xpu branch from 2d93be5 to 8a17cf4 on November 11, 2025 at 05:15
@EikanWang EikanWang added the ciflow/xpu (Run XPU CI tasks) label Dec 2, 2025
@pytorch-bot

pytorch-bot bot commented Dec 2, 2025

To add the ciflow label ciflow/trunk please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot

pytorch-bot bot commented Dec 2, 2025

To add the ciflow label ciflow/xpu please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorch-bot pytorch-bot bot removed the ciflow/trunk (Trigger trunk jobs on your pull request) and ciflow/xpu (Run XPU CI tasks) labels Dec 2, 2025
@EikanWang EikanWang added the ciflow/xpu (Run XPU CI tasks) label Dec 2, 2025
@pytorch-bot

pytorch-bot bot commented Dec 2, 2025

To add the ciflow label ciflow/xpu please first approve the workflows that are awaiting approval (scroll to the bottom of this page).

This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@EikanWang EikanWang added the ciflow/xpu (Run XPU CI tasks) label Dec 2, 2025
@EikanWang
Collaborator

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk (Trigger trunk jobs on your pull request) label Dec 2, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@Stonepia
Contributor Author

Stonepia commented Dec 2, 2025

Current failures:

  1. One XPU-related job was manually stopped, and another failed due to a timeout, so they need a re-trigger. Hi @EikanWang, may I ask a favor to rerun the failed XPU tests?
  2. The ROCm-related test error is unrelated.

@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: xpu / linux-noble-xpu-n-py3.10 / test (default, 5, 12, linux.idc.xpu)

Details for Dev Infra team. Raised by workflow job.

@EikanWang
Collaborator

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

@Stonepia Stonepia deleted the xpu/register_scaled_mm_xpu branch December 3, 2025 05:10
JacobSzwejbka pushed a commit that referenced this pull request Dec 8, 2025
…xpu (#166056)

This PR registers the `scaled_mm` op for XPU support.

It does the following:
1. Registers the `_scaled_mm` and `_scaled_mm_v2` ops for XPU.
2. Enables the XPU tests in `test_scaled_matmul_cuda.py`.
3. Updates the torch-xpu-ops pin to remove the fallback of `scaled_mm` to the CPU implementation.

## PR Stack:
- #165978 : implementation of XPU scaled_mm and oneDNN kernel
- #167518 : implementation of XPU scaled_mm_v2
- -> #166056 : Op registration

## Task tracker:
We will track all the scaled_mm related tasks in: #167170

Pull Request resolved: #166056
Approved by: https://github.com/EikanWang, https://github.com/slayton58, https://github.com/drisspg

Labels

ciflow/trunk, ciflow/xpu, keep-going, Merged, open source, release notes: inductor (aoti), topic: not user facing, triaged
