
[ROCm] Enable group gemm through CK#166334

Closed
jagadish-amd wants to merge 17 commits into pytorch:main from jagadish-amd:group_gemm_ck

Conversation

@jagadish-amd
Contributor

@jagadish-amd jagadish-amd commented Oct 27, 2025

Fixes #161366
All four dimension combinations are supported: 2d-2d, 2d-3d, 3d-2d, and 3d-3d. The corresponding test cases in test_matmul_cuda pass for both the forward and backward pass.
The CK path is enabled for gfx942 and gfx950.
ToDo: enable support on gfx90a. The CK kernel used in this commit produces a GPU error there and may need a different CK kernel config, based on the profiler results on gfx90a.
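
For reference, a minimal sketch of the 2d-3d grouped GEMM semantics in plain C++ (illustrative shapes and names only; this is not the CK kernel or the ATen entry point): mat_a is a single 2D matrix whose rows are partitioned into groups by cumulative offsets, and each group of rows is multiplied by its own slice of the 3D mat_b.

#include <cstdio>
#include <vector>

// Reference semantics for the 2d-3d grouped GEMM case.
// a: (total_M x K) row-major, b: G weight matrices, each (K x N) row-major,
// offs: cumulative row offsets into a, one entry per group.
std::vector<float> grouped_mm_2d_3d(const std::vector<float>& a,
                                    const std::vector<std::vector<float>>& b,
                                    const std::vector<int>& offs,
                                    int K, int N) {
  std::vector<float> out(static_cast<size_t>(offs.back()) * N, 0.0f);
  int start = 0;
  for (size_t g = 0; g < offs.size(); ++g) {   // one ordinary GEMM per group
    for (int m = start; m < offs[g]; ++m)
      for (int n = 0; n < N; ++n)
        for (int k = 0; k < K; ++k)
          out[m * N + n] += a[m * K + k] * b[g][k * N + n];
    start = offs[g];
  }
  return out;                                  // (total_M x N)
}

int main() {
  const int K = 4, N = 3;
  std::vector<int> offs = {2, 5};              // two groups: rows [0,2) and [2,5)
  std::vector<float> a(5 * K, 1.0f);           // total_M = 5
  std::vector<std::vector<float>> b = {std::vector<float>(K * N, 1.0f),
                                       std::vector<float>(K * N, 2.0f)};
  auto out = grouped_mm_2d_3d(a, b, offs, K, N);
  std::printf("out[0]=%g out[%d]=%g\n", out[0], 2 * N, out[2 * N]);  // prints 4 and 8
  return 0;
}

The CK path replaces this per-group loop with a single fused device-side grouped kernel; the sketch above only pins down the offsets and shape conventions being tested.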

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd

Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
Also added comments.

Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
python test/test_matmul_cuda.py -v -k test_grouped_gemm_2d_3d
Ran 24 tests in 5.566s
OK

python test/test_matmul_cuda.py -v -k test_grouped_gemm_2d_2d
Ran 24 tests in 5.537s
OK

Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
All test cases are passing for both the forward and backward pass.

Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
@pytorch-bot

pytorch-bot bot commented Oct 27, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/166334

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures

As of commit 5df10f1 with merge base b4e4ee8:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@jagadish-amd jagadish-amd marked this pull request as draft October 27, 2025 20:07
@jagadish-amd jagadish-amd changed the title from "Group gemm ck" to "[ROCm] Enable group gemm through CK" Oct 27, 2025
@pytorch-bot pytorch-bot bot added the module: rocm AMD GPU support for Pytorch label Oct 27, 2025
@jagadish-amd
Contributor Author

@pytorchbot label "topic: not user facing"

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label Oct 27, 2025
@jagadish-amd
Contributor Author

cc @jeffdaily @pruthvistony

@jeffdaily jeffdaily added release notes: rocm mandatorylabel ciflow/rocm Trigger "default" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 and removed topic: not user facing topic category labels Oct 27, 2025
@jagadish-amd
Contributor Author

Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
@pytorch-bot pytorch-bot bot removed ciflow/rocm Trigger "default" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 labels Oct 28, 2025
@atalman
Contributor

atalman commented Oct 29, 2025

Here are the errors:

ck_group_gemm.hip:129:13: error: missing field 'stride_Ds_' initializer [-Werror,-Wmissing-field-initializers]
  129 |             });
      |             ^

ck_group_gemm.hip:420:13: note: in instantiation of function template specialization 'at::hip::detail::launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, unsigned short>' requested here
  420 |             launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
      |             ^

error: missing field 'stride_Ds_' initializer [-Werror,-Wmissing-field-initializers]

Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
@pytorch-bot pytorch-bot bot removed the ciflow/trunk Trigger trunk jobs on your pull request label Oct 29, 2025
@jagadish-amd
Contributor Author

Here are the errors:

ck_group_gemm.hip:129:13: error: missing field 'stride_Ds_' initializer [-Werror,-Wmissing-field-initializers]
  129 |             });
      |             ^

ck_group_gemm.hip:420:13: note: in instantiation of function template specialization 'at::hip::detail::launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, unsigned short>' requested here
  420 |             launch_grouped_bgemm_ck_impl_dispatch<ck::tensor_layout::gemm::RowMajor, ck::tensor_layout::gemm::RowMajor, CkTypes::BF16>(mat_a, mat_b, offs, out);
      |             ^

Looking at the errors, the cause is the missing 'stride_Ds_' field initializer.
Pushed a commit which initializes the field with an empty vector.
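
For context, a self-contained illustration of this class of error and of the fix (a hypothetical stand-in struct: only stride_Ds_ mirrors the field named in the compiler error, everything else is invented, and this is not the actual CK descriptor used in ck_group_gemm.hip). With -Wmissing-field-initializers promoted to an error, every field of the aggregate must be initialized explicitly, so the empty stride_Ds_ vector is spelled out:

#include <vector>

// Hypothetical descriptor struct; not the real CK GemmDesc.
struct GemmDescLike {
  int M_, N_, K_;
  int stride_A_, stride_B_, stride_C_;
  std::vector<int> stride_Ds_;
};

int main() {
  // Before the fix, GemmDescLike{128, 128, 64, 64, 128, 128} left stride_Ds_
  // out of the aggregate initializer and tripped
  // -Werror,-Wmissing-field-initializers.
  GemmDescLike desc{128, 128, 64, 64, 128, 128, /*stride_Ds_=*/{}};  // fixed: explicit empty vector
  return desc.stride_Ds_.empty() ? 0 : 1;
}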

@jerryzh168 jerryzh168 added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Oct 30, 2025
@jeffdaily
Collaborator

@atalman would you like to do a Meta import to verify our changes and prevent more reverts?

@meta-codesync

meta-codesync bot commented Oct 30, 2025

@atalman has imported this pull request. If you are a Meta employee, you can view this in D85904819.

BoyuanFeng pushed a commit that referenced this pull request Oct 31, 2025
Pull Request resolved: #166334
Approved by: https://github.com/jeffdaily, https://github.com/pruthvistony
BoyuanFeng pushed a commit that referenced this pull request Oct 31, 2025
This reverts commit 1fa520e.

Reverted #166334 on behalf of https://github.com/atalman due to internal build failures.
@jeffdaily
Collaborator

@atalman how did the internal import go?

Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
Contributor

@atalman atalman left a comment


lgtm. Internal change is clear. Please fix lint:
Lint / lintrunner-pyrefly-all / linux-job (gh)

Lint for torch/utils/tensorboard/writer.py:

@jagadish-amd
Contributor Author

@jeffdaily the lint errors are not related to my changes; can we rerun the CI and attempt to merge?

@jeffdaily
Collaborator

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Rebase failed due to Command git -C /home/runner/work/pytorch/pytorch rebase refs/remotes/origin/viable/strict pull/166334/head returned non-zero exit code 1

Rebasing (1/15)
Auto-merging aten/src/ATen/native/cuda/Blas.cpp
CONFLICT (content): Merge conflict in aten/src/ATen/native/cuda/Blas.cpp
error: could not apply a5934fc219e... Initial group gemm through CK code
hint: Resolve all conflicts manually, mark them as resolved with
hint: "git add/rm <conflicted_files>", then run "git rebase --continue".
hint: You can instead skip this commit: run "git rebase --skip".
hint: To abort and get back to the state before "git rebase", run "git rebase --abort".
hint: Disable this message with "git config set advice.mergeConflict false"
Could not apply a5934fc219e... # Initial group gemm through CK code

Raised by https://github.com/pytorch/pytorch/actions/runs/19087534422

@jeffdaily jeffdaily added ciflow/rocm Trigger "default" config CI on ROCm ciflow/rocm-mi300 Trigger "default" config CI on ROCm MI300 ciflow/rocm-mi355 Trigger "default" config CI on ROCm MI355 runners labels Nov 5, 2025
@jeffdaily
Collaborator

@pytorchbot merge -f "lint good, rocm good, upstream broke mi355 tests, merging anyway, fixing separately in #167066"

@pytorchmergebot
Collaborator

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

Comment on lines +680 to +684
#ifndef USE_ROCM
at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
#else
at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
#endif
Contributor


This breaks builds on Windows, where CK is currently not enabled: ROCm/TheRock#2054

2025-11-07T07:51:12.4977670Z [7076/7084] Linking CXX shared library bin\torch_hip.dll
2025-11-07T07:51:12.4977863Z FAILED: [code=4294967295] bin/torch_hip.dll lib/torch_hip.lib 
2025-11-07T07:51:12.4979510Z C:\Windows\system32\cmd.exe /C "cd . && C:\home\runner\_work\_tool\Python\3.12.10\x64\Lib\site-packages\cmake\data\bin\cmake.exe -E vs_link_dll --msvc-ver=1944 --intdir=caffe2\CMakeFiles\torch_hip.dir --rc="C:\Program Files (x86)\Windows Kits\10\bin\10.0.26100.0\x64\rc.exe" --mt="C:\Program Files (x86)\Windows Kits\10\bin\10.0.26100.0\x64\mt.exe" --manifests  -- C:\home\runner\_work\_tool\Python\3.12.10\x64\Lib\site-packages\_rocm_sdk_devel\lib\llvm\bin\lld-link.exe /nologo @CMakeFiles\torch_hip.rsp  /out:bin\torch_hip.dll /implib:lib\torch_hip.lib /pdb:bin\torch_hip.pdb /dll /version:0.0 /machine:x64 /ignore:4049 /ignore:4217 /ignore:4099 /INCREMENTAL:NO && cd ."
2025-11-07T07:51:12.4980262Z LINK: command "C:\home\runner\_work\_tool\Python\3.12.10\x64\Lib\site-packages\_rocm_sdk_devel\lib\llvm\bin\lld-link.exe /nologo @CMakeFiles\torch_hip.rsp /out:bin\torch_hip.dll /implib:lib\torch_hip.lib /pdb:bin\torch_hip.pdb /dll /version:0.0 /machine:x64 /ignore:4049 /ignore:4217 /ignore:4099 /INCREMENTAL:NO /MANIFEST:EMBED,ID=2" failed (exit code 1) with the following output:
2025-11-07T07:51:12.4980745Z lld-link: error: undefined symbol: void __cdecl at::hip::detail::group_gemm_ck(class at::Tensor const &, class at::Tensor const &, class std::optional<class at::Tensor> const &, class std::optional<class at::Tensor> const &, class at::Tensor &)
2025-11-07T07:51:12.4980868Z 
2025-11-07T07:51:12.4981526Z >>> referenced by caffe2\CMakeFiles\torch_hip.dir\__\aten\src\ATen\native\hip\GroupedBlas.cpp.obj:(class at::Tensor __cdecl at::native::_grouped_mm_cuda(class at::Tensor const &, class at::Tensor const &, class std::optional<class at::Tensor> const &, class std::optional<class at::Tensor> const &, class std::optional<enum c10::ScalarType>))
2025-11-07T07:51:12.4981649Z 
2025-11-07T07:51:12.4981729Z ninja: build stopped: subcommand failed.

Can this be conditioned on CK being enabled too?

pytorch/CMakeLists.txt, lines 252 to 253 in 724cd32:

cmake_dependent_option(USE_ROCM "Use ROCm" ON "LINUX OR WIN32" OFF)
cmake_dependent_option(USE_ROCM_CK_GEMM "Use ROCm Composable Kernel for GEMMs" ON "USE_ROCM;NOT WIN32" OFF)

Perhaps this code style?

#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
  else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
    at::native::bgemm_internal_ck<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
  }
#endif
  else {
    bgemm_internal_cublas<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
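
For illustration, a sketch of how the grouped-mm dispatch from lines +680 to +684 could be guarded along the same lines. This is an assumption about the shape of such a fix, not the actual change that landed in #167403:

#if defined(USE_ROCM) && defined(USE_ROCM_CK_GEMM)
  // ROCm build with Composable Kernel enabled: take the CK grouped-GEMM path.
  at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
#elif defined(USE_ROCM)
  // ROCm build without CK (e.g. current Windows builds): fail at runtime
  // instead of leaving an unresolved symbol for the linker.
  TORCH_CHECK(false, "grouped GEMM on ROCm requires USE_ROCM_CK_GEMM");
#else
  at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out);
#endif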

Contributor Author


Sure @ScottTodd, I will raise a PR soon.

Contributor


A fix-forward is in review at #167403. Can we revert first though? I'm not sure if I have permission to trigger the bot for that...

Contributor


Confirmed that the fix-forward was successful and our nightly release builds are functional again. Latest status update at ROCm/TheRock#2054 (comment).


Labels

ci-no-td, ciflow/rocm, ciflow/rocm-mi300, ciflow/rocm-mi355, Merged, module: rocm, open source, release notes: rocm, Reverted, triaged


Development

Successfully merging this pull request may close these issues: [ROCm] FP8 _scaled_grouped_mm & BF16 _grouped_mm support (#161366)

8 participants