
[CUDA][Thor] Enable CUTLASS matmuls on Thor #164836

Closed

Aidyn-A wants to merge 1 commit into pytorch:main from Aidyn-A:add_sm_110a_to_cutlass_matmuls

Conversation

Aidyn-A (Collaborator) commented Oct 7, 2025

This PR enables specialized CUTLASS matmuls on Thor devices: row-wise scaled matmul on fp8 and grouped GEMM on bfloat16.
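
For context, a rough sketch of how the fp8 row-wise scaled path is typically reached is below. It is illustrative only: the shapes and scales are made up, the _scaled_mm signature shown is the current one and may differ between PyTorch releases, and a Thor (sm_110a) device is assumed. The bf16 grouped GEMM half of the change is the bf16bf16_grouped_gemm path quoted in the review thread further down.

#include <ATen/ATen.h>

// Hedged sketch: exercises the row-wise scaled fp8 matmul enabled by this PR.
// Assumes a recent libtorch/PyTorch build with CUDA running on a Thor GPU.
int main() {
  const auto opts = at::TensorOptions().device(at::kCUDA);
  const int64_t M = 128, K = 256, N = 64;

  // fp8 operands; B is transposed so it is column-major, as the kernel expects.
  auto a = at::randn({M, K}, opts).to(at::kFloat8_e4m3fn);
  auto b = at::randn({N, K}, opts).to(at::kFloat8_e4m3fn).t();

  // Row-wise scales: one per row of A and one per column of B.
  auto scale_a = at::ones({M, 1}, opts.dtype(at::kFloat));
  auto scale_b = at::ones({1, N}, opts.dtype(at::kFloat));

  auto out = at::_scaled_mm(a, b, scale_a, scale_b,
                            /*bias=*/{}, /*scale_result=*/{},
                            /*out_dtype=*/at::kBFloat16);
  return 0;
}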

@Aidyn-A Aidyn-A self-assigned this Oct 7, 2025
@Aidyn-A Aidyn-A requested review from eqy and syed-ahmed as code owners October 7, 2025 13:13
pytorch-bot (bot) commented Oct 7, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164836

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 6c569b8 with merge base 2f023bf:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

eqy added the ciflow/trunk (Trigger trunk jobs on your pull request), matrix multiplication, and release notes: cuda (release notes category) labels Oct 7, 2025
Skylion007 previously approved these changes Oct 7, 2025
+ const bool sm11x = properties != nullptr && properties->major == 11;

- if (sm10x) {
+ if (sm10x || sm11x) {
Collaborator

Should this also be enabled on sm120x or nah?

Aidyn-A (Collaborator, Author)

As far as I know, sm100 and sm110 are compatible, but sm120 is completely different from those two.
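
For illustration, the gating being discussed boils down to something like the sketch below. It is a minimal standalone version; the function name and structure are mine, not the actual PyTorch dispatch code.

#include <cuda_runtime.h>

// Minimal sketch of the compute-capability guard discussed above: sm100 and
// sm110 (Thor) share the same CUTLASS kernel path, while sm120 is left out
// because its kernels are not compatible with that path.
bool use_sm100_cutlass_path(int device) {
  cudaDeviceProp properties{};
  if (cudaGetDeviceProperties(&properties, device) != cudaSuccess) {
    return false;
  }
  const bool sm10x = properties.major == 10;
  const bool sm11x = properties.major == 11;  // Thor (sm_110a)
  return sm10x || sm11x;                      // major == 12 (sm120) excluded
}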

if (sm10x || sm11x) {
  if (small) {
    bf16bf16_grouped_gemm_impl_sm90_sm100<
        cutlass::arch::Sm100,
Collaborator

Wait, don't we need a separate instantiation here with

        cutlass::arch::Sm110,

?

Aidyn-A (Collaborator, Author) commented Nov 18, 2025

Nope, it does not exist in CUTLASS. The cutlass::arch::Sm101 tag technically exists: https://github.com/NVIDIA/cutlass/blob/a2439551c765c5393aebe557ee75d3a0412d2211/include/cutlass/arch/arch.h#L104-L106
but it is not used anywhere in CUTLASS. I was not able to compile anything with it.
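
For reference, CUTLASS architecture tags are just marker structs, roughly as below (a paraphrase of the linked arch.h, not a verbatim copy). Sm101 is declared, but nothing else in the library is specialized on it, which matches the observation above that nothing compiles with it.

namespace cutlass {
namespace arch {

// Tag for Blackwell-class SM100 kernels; the grouped GEMM quoted above reuses
// this tag for Thor because no dedicated Sm110 tag exists.
struct Sm100 {
  static int const kMinComputeCapability = 100;
};

// Declared in arch.h but not referenced by any kernel schedule or dispatch.
struct Sm101 {
  static int const kMinComputeCapability = 101;
};

}  // namespace arch
}  // namespace cutlass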

@Skylion007 Skylion007 dismissed their stale review October 10, 2025 17:53

withdraw until tests are fixed.

Aidyn-A (Collaborator, Author) commented Oct 13, 2025

@pytorchbot rebase

pytorchmergebot (Collaborator)

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

pytorchmergebot (Collaborator)

Successfully rebased add_sm_110a_to_cutlass_matmuls onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout add_sm_110a_to_cutlass_matmuls && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the add_sm_110a_to_cutlass_matmuls branch from 5cddc35 to c203e94 Compare October 13, 2025 14:04
@soulitzer soulitzer requested a review from ngimel October 13, 2025 15:35
@soulitzer soulitzer added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Oct 13, 2025
johnnynunez added a commit to dusty-nv/jetson-containers that referenced this pull request Nov 18, 2025
@Aidyn-A Aidyn-A force-pushed the add_sm_110a_to_cutlass_matmuls branch from c203e94 to 6c569b8 Compare November 18, 2025 10:57
Aidyn-A (Collaborator, Author) commented Nov 18, 2025

@pytorchbot merge

pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

eee4017 pushed a commit to eee4017/pytorch that referenced this pull request Nov 20, 2025

Labels

ciflow/trunk (Trigger trunk jobs on your pull request)
matrix multiplication
Merged
open source
release notes: cuda (release notes category)
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Projects

Status: Done


7 participants