[ROCm] Add scaled_mm v2 support.#165528
Add mx fp4 support in Blas.cpp. Modify the tests under test_scaled_matmul_cuda accordingly.
PYTORCH_TEST_WITH_ROCM=1 python test/test_scaled_matmul_cuda.py -v -k test_blockwise
115 tests passed.
Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
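For readers unfamiliar with the format, here is a minimal sketch of blockwise MXFP4 quantization. It assumes "mx fp4" refers to the OCP Microscaling FP4 format: blocks of 32 elements, each block sharing one power-of-two (E8M0) scale, with elements stored as FP4 E2M1. This is not the code in Blas.cpp. The function names are hypothetical, and the rounding is simplified to "nearest representable value" rather than the round-to-nearest-even a real kernel would use.

```python
import math

# Magnitudes representable by FP4 E2M1 (sign is stored separately).
FP4_E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
BLOCK_SIZE = 32   # elements sharing one scale in the MX formats
E2M1_EMAX = 2     # exponent of the largest power of two in E2M1 (4.0 = 2**2)


def quantize_block(block):
    """Return (shared_exponent, fp4_values) for one block (hypothetical helper)."""
    amax = max(abs(x) for x in block)
    if amax == 0.0:
        return 0, [0.0] * len(block)
    # Pick the power-of-two scale so the block maximum lands near the top of E2M1.
    shared_exp = math.floor(math.log2(amax)) - E2M1_EMAX
    shared_exp = max(-127, min(127, shared_exp))  # E8M0 exponent range
    scale = 2.0 ** shared_exp
    quantized = []
    for x in block:
        mag = min(abs(x) / scale, FP4_E2M1_VALUES[-1])  # saturate at 6.0
        nearest = min(FP4_E2M1_VALUES, key=lambda v: abs(v - mag))
        quantized.append(math.copysign(nearest, x))
    return shared_exp, quantized


def dequantize_block(shared_exp, quantized):
    """Rebuild approximate values from the shared exponent and FP4 values."""
    scale = 2.0 ** shared_exp
    return [q * scale for q in quantized]


if __name__ == "__main__":
    block = [float(i) for i in range(BLOCK_SIZE)]
    exp, q = quantize_block(block)
    print(exp, dequantize_block(exp, q))
```

A scaled matmul over such operands multiplies the FP4 values and applies the per-block scales to the accumulated products, which is why the scale tensors' dtype and shape matter for kernel selection.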
cc @jeffdaily
@pytorchbot label "topic: not user facing"
@pytorchbot merge
Pull workflow has not been scheduled for the PR yet. It could be because author doesn't have permissions to run those or skip-checks keywords were added to PR/commits, aborting merge. Please get/give approval for the workflows and/or remove skip ci decorators before next merge attempt. If you think this is a mistake, please contact PyTorch Dev Infra.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
Summary: PR #165528 changes some scale/swizzle inference behavior in scaled_mm tests - mxfp8 tests on Blackwell can get incorrectly classified, resulting in failures. Fix the inference code to prevent this.
Fixes #165743
Test Plan:
```
pytest -svv test/test_scaled_matmul_cuda.py
```
Reviewers: @jagadish-amd @jeffdaily @drisspg
Subscribers: @Aidyn-A
Signed-off-by: Simon Layton <simonlayton@meta.com>
Summary: PR #165528 changes some scale/swizzle inference behavior in scaled_mm tests - mxfp8 tests on Blackwell can get incorrectly classified, resulting in failures. Fix the scale/swizzle inference code to prevent this.
Fixes #165743
Test Plan:
```
pytest -svv test/test_scaled_matmul_cuda.py
```
Reviewers: @jagadish-amd @jeffdaily @drisspg
Subscribers: @Aidyn-A
Signed-off-by: Simon Layton <simonlayton@meta.com>
Pull Request resolved: #165747
Approved by: https://github.com/eqy, https://github.com/drisspg, https://github.com/jeffdaily
Add mx fp4 support in Blas.cpp.
Updated the scale_kernel_dispatch array and ScaledGemmImplementation enum to include MXFP4 support.
Modify the tests under test_scaled_matmul_cuda accordingly.
PYTORCH_TEST_WITH_ROCM=1 python test/test_scaled_matmul_cuda.py -v -k test_blockwise
115 tests passed.
Pull Request resolved: pytorch#165528
Approved by: https://github.com/jeffdaily
Add mx fp4 support in Blas.cpp.
Updated the scale_kernel_dispatch array and ScaledGemmImplementation enum to include MXFP4 support.
Modify the tests under test_scaled_matmul_cuda accordingly.
PYTORCH_TEST_WITH_ROCM=1 python test/test_scaled_matmul_cuda.py -v -k test_blockwise
115 tests passed.
cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd
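The description mentions extending the scale_kernel_dispatch array and the ScaledGemmImplementation enum. As a rough illustration of that pattern only, the sketch below models an ordered dispatch table in Python: each entry pairs an acceptance check with an enum value, and the first matching entry wins. The enum members, the `Operand` fields, and the predicates are simplified and hypothetical; they are not the actual definitions in Blas.cpp.

```python
from enum import Enum, auto
from typing import NamedTuple


class ScaledGemmImplementation(Enum):
    # Simplified, hypothetical member names for illustration.
    NONE = auto()
    TENSORWISE = auto()
    MXFP8 = auto()
    MXFP4 = auto()  # the kind of entry a change like this one adds


class Operand(NamedTuple):
    dtype: str        # e.g. "float8_e4m3fn" or "float4_e2m1fn_x2"
    scale_dtype: str  # e.g. "float32" or "float8_e8m0fnu"
    scale_numel: int  # number of scale elements supplied


def is_tensorwise(a, b):
    # One float32 scale per operand.
    return all(x.scale_dtype == "float32" and x.scale_numel == 1 for x in (a, b))


def is_mxfp8(a, b):
    # fp8 data with per-block e8m0 scales.
    return all(x.dtype.startswith("float8") and x.scale_dtype == "float8_e8m0fnu"
               for x in (a, b))


def is_mxfp4(a, b):
    # Packed fp4 data with per-block e8m0 scales.
    return all(x.dtype == "float4_e2m1fn_x2" and x.scale_dtype == "float8_e8m0fnu"
               for x in (a, b))


# Ordered table: (name, acceptance check, implementation).
SCALE_KERNEL_DISPATCH = [
    ("tensorwise", is_tensorwise, ScaledGemmImplementation.TENSORWISE),
    ("mxfp8", is_mxfp8, ScaledGemmImplementation.MXFP8),
    ("mxfp4", is_mxfp4, ScaledGemmImplementation.MXFP4),
]


def select_implementation(a, b):
    """Return the first implementation whose check accepts both operands."""
    for _name, accepts, impl in SCALE_KERNEL_DISPATCH:
        if accepts(a, b):
            return impl
    return ScaledGemmImplementation.NONE
```

Because the first match wins, checks that are too permissive can classify an input under the wrong recipe, so each predicate should test the data dtype as well as the scale layout.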