[xpu][feature] [3/3] Register the scaled_mm and scaled_mm_v2 for xpu #166056
Stonepia wants to merge 20 commits into pytorch:main
Conversation

🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/166056
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (2 Unrelated Failures)
As of commit bf9795f with merge base a7dc6da:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorchbot label "module: xpu"
Attention! native_functions.yaml was changed
If you are adding a new function or defaulted argument to native_functions.yaml, you cannot use it from pre-existing Python frontend code until our FC window passes (two weeks). Split your PR into two PRs, one which adds the new C++ functionality, and one that makes use of it from Python, and land them two weeks apart. See https://github.com/pytorch/pytorch/wiki/PyTorch's-Python-Frontend-Backward-and-Forward-Compatibility-Policy#forwards-compatibility-fc for more info. Caused by:

Attention! One of the PyTorch C-stable API files was changed
You MUST NOT change existing function declarations in this file, as this header defines a stable C ABI. If you need to change the signature of a function, introduce a new v2 version of the function and modify code generation to target the new version of the function. Caused by:

@pytorchbot rebase

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

Successfully rebased; updated from 817358f to 26189e3.

@pytorchbot rebase

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

Rebase failed due to Command. Raised by https://github.com/pytorch/pytorch/actions/runs/19121565673

Force-pushed from 9b9f12f to e940420.
This functionality is great! It's not a big difference in API; the main change is that scaling types are explicitly passed to the API, rather than inferred from the input & scale shapes. I'd be happy to talk through the necessary differences if you need.
Thanks for the suggestion! I will refactor the code to support the v2 version. Originally, I thought that v2 was not yet stable enough, so there would be ongoing syncing effort whenever the code changes.
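For context on the API difference described above, here is a minimal sketch (not code from this PR; the shapes, dtypes, and device selection are illustrative assumptions) of how the existing `torch._scaled_mm` call infers the scaling recipe from the scale shapes, which is exactly what the v2 API instead takes as explicit arguments:

```python
# Illustrative only: scalar scales imply tensorwise scaling, (M, 1)/(1, N) scales
# imply rowwise scaling. scaled_mm_v2 passes these scaling types explicitly.
import torch

device = "xpu" if torch.xpu.is_available() else "cuda"
M, K, N = 64, 128, 32

x = torch.randn(M, K, device=device).to(torch.float8_e4m3fn)
y = torch.randn(N, K, device=device).to(torch.float8_e4m3fn).t()  # mat2 must be column-major

# Tensorwise: a single fp32 scale per operand.
out_tensorwise = torch._scaled_mm(
    x, y,
    scale_a=torch.tensor(1.0, device=device),
    scale_b=torch.tensor(1.0, device=device),
    out_dtype=torch.bfloat16,
)

# Rowwise: per-row scales for x, per-column scales for y.
out_rowwise = torch._scaled_mm(
    x, y,
    scale_a=torch.ones(M, 1, device=device),
    scale_b=torch.ones(1, N, device=device),
    out_dtype=torch.bfloat16,
)
```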
test/test_scaled_matmul_cuda.py
Outdated
out_fp8.to(torch.float32), torch.full((M, N), K * (fill_value**2), device=device)
)

@skipXPU
Since this UT has been decorated with onlyCUDA, is skipXPU necessary?
Thanks for the suggestion! Removed.
test/test_scaled_matmul_cuda.py
Outdated
lambda: scaled_mm_wrap(x, y, scale_a, scale_b, out_dtype=torch.float32),
)

@skipXPU
SM100OrLater should have covered skipXPU, right?
Yes, I removed those skipXPU for smaller code changes.
test/test_scaled_matmul_cuda.py
Outdated
if not _device_supports_scaled_mm_fp8(device) or (not torch.xpu.is_available() and IS_WINDOWS):
    raise unittest.SkipTest(f8_msg)
if not torch.xpu.is_available() and not SM89OrLater:
    raise unittest.SkipTest("rowwise implementation is currently sm89-sm100 specific")
if torch.xpu.is_available() and use_fast_accum:
    raise unittest.SkipTest("XPU does not support fast accum yet")
These lines of code could be replaced by:
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@skipCUDAIf(IS_WINDOWS, f8_msg)
@skipCUDAIf(not SM89OrLater, "rowwise implementation is currently sm89-sm100 specific")
@skipXPUIf(use_fast_accum, "XPU does not support fast accum yet")
test/test_scaled_matmul_cuda.py
Outdated
if not _device_supports_scaled_mm_fp8(device) or (not torch.xpu.is_available() and IS_WINDOWS):
    raise unittest.SkipTest(f8_msg)
if not torch.xpu.is_available() and not SM89OrLater:
    raise unittest.SkipTest("rowwise implementation is currently sm89-sm100 specific")
ditto. Please refine the test a little bit.
test/test_scaled_matmul_cuda.py
Outdated
output_dtype
)

@skipXPU
@skipXPU
test/test_scaled_matmul_cuda.py
Outdated
def test_zero_dim_tensorwise(self, which_dim_zero, use_torch_compile, device) -> None:
    if not _device_supports_scaled_mm_fp8(device):
        raise unittest.SkipTest(f8_msg)
Are the code changes due to PLATFORM_SUPPORTS_FP8 not supporting XPU?
Yes, PLATFORM_SUPPORTS_FP8 only covers CUDA. So I wrapped all of these checks:
if device != "cpu" and torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8:
into a helper function, so that the change only affects this file:
def _device_supports_scaled_mm_fp8(device):
    # Only gate CUDA on PLATFORM_SUPPORTS_FP8; cpu and xpu are handled by their own checks.
    if device not in ['cpu', 'xpu'] and (torch.cuda.is_available() and not PLATFORM_SUPPORTS_FP8):
        return False
    return True
Because of this change, new tests needed to be added:
4b5b0d0
These tests mainly cover scaled_mm, but also other FP8 ops.

@pytorchbot rebase

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

Successfully rebased; updated from 2d93be5 to 8a17cf4.

To add the ciflow label, please first approve the pending workflows. This helps ensure we don't trigger CI on this PR until it is actually authorized to do so. Please ping one of the reviewers if you do not have access to approve and run workflows.

@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Current Failure:

Merge failed. Reason: 1 job has failed, first few of them are: xpu / linux-noble-xpu-n-py3.10 / test (default, 5, 12, linux.idc.xpu). Details for Dev Infra team: Raised by workflow job.

@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…xpu (#166056)

This PR registers the `scaled_mm` op for XPU support. It does the following:
1. Registered the `_scaled_mm` and `_scaled_mm_v2` op for XPU.
2. Enables XPU tests in `test_scaled_matmul_cuda.py`.
3. Update torch-xpu-ops pin to remove fallback `scaled_mm` to CPU implementation.

## PR Stack:
- #165978 : implementation of XPU scaled_mm and oneDNN kernel
- #167518 : implementation of XPU scaled_mm_v2
- -> #166056 : Op registration

## Task tracker:
We will track all the scaled_mm related tasks in: #167170

Pull Request resolved: #166056
Approved by: https://github.com/EikanWang, https://github.com/slayton58, https://github.com/drisspg
This PR registers the `scaled_mm` op for XPU support. It does the following:

1. Registered the `_scaled_mm` and `_scaled_mm_v2` op for XPU.
2. Enables XPU tests in `test_scaled_matmul_cuda.py`.
3. Update torch-xpu-ops pin to remove fallback `scaled_mm` to CPU implementation.

PR Stack:
- Register the scaled_mm and scaled_mm_v2 for xpu #166056 : Op registration

Task tracker:
We will track all the scaled_mm related tasks in: #167170
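As a hedged sanity-check sketch (not from the PR itself; it assumes an FP8-capable XPU build and uses illustrative shapes and unit scales), the registration means a call like the following should dispatch to the native XPU kernel instead of falling back to the CPU implementation:

```python
# Assumption-laden sketch: verify _scaled_mm runs natively on an XPU device.
import torch

if torch.xpu.is_available():
    a = torch.randn(32, 64, device="xpu").to(torch.float8_e4m3fn)
    b = torch.randn(16, 64, device="xpu").to(torch.float8_e4m3fn).t()  # column-major mat2
    scale_a = torch.tensor(1.0, device="xpu")  # tensorwise scales
    scale_b = torch.tensor(1.0, device="xpu")

    out = torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
    assert out.device.type == "xpu"
    assert out.shape == (32, 16)
```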
cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo @gujinghui @fengyuan14 @guangyey @chenyang78