[CUDA][Thor] Enable CUTLASS matmuls on Thor #164836
Aidyn-A wants to merge 1 commit into pytorch:main
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164836
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure) As of commit 6c569b8 with merge base 2f023bf (one job is marked as unstable, possibly due to flakiness on trunk).
This comment was automatically generated by Dr. CI and updates every 15 minutes.
+ const bool sm11x = properties != nullptr && properties->major == 11;

- if (sm10x) {
+ if (sm10x || sm11x) {
Should this also be enabled on sm120x or nah?
As far as I know, sm100 and sm110 are compatible, but sm120 is completely different from those two.
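For context, a minimal sketch of how a guard like sm11x can be derived from the CUDA runtime; the function name and wiring here are illustrative, not the PR's actual code:

```cpp
#include <cuda_runtime.h>

// Illustrative helper: report whether the given device is in the
// compute-capability families this PR targets. major == 10 covers
// sm100-class parts, major == 11 covers Thor (sm11x); sm120 (major == 12)
// is excluded, matching the discussion above.
bool is_sm10x_or_sm11x(int device) {
  cudaDeviceProp prop{};
  if (cudaGetDeviceProperties(&prop, device) != cudaSuccess) {
    return false;
  }
  return prop.major == 10 || prop.major == 11;
}
```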
  if (sm10x || sm11x) {
    if (small) {
      bf16bf16_grouped_gemm_impl_sm90_sm100<
          cutlass::arch::Sm100,
Wait, don't we need a separate instantiation here with cutlass::arch::Sm110?
Nope, cutlass::arch::Sm110 does not exist in CUTLASS. cutlass::arch::Sm101 technically exists: https://github.com/NVIDIA/cutlass/blob/a2439551c765c5393aebe557ee75d3a0412d2211/include/cutlass/arch/arch.h#L104-L106
but it is not used anywhere in CUTLASS. I was not able to compile anything with it.
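To make the answer concrete, a small self-contained sketch of the pattern the diff lands on; only the launcher name and cutlass::arch::Sm100 come from the PR, everything else is a simplified stand-in:

```cpp
#include <cutlass/arch/arch.h>

// Simplified stand-in for the PR's launcher, which is also templated on tile
// shapes and takes tensors, strides, and a stream.
template <typename ArchTag>
void bf16bf16_grouped_gemm_impl_sm90_sm100() {
  // ... instantiate and launch the CUTLASS grouped GEMM for ArchTag ...
}

void dispatch(bool sm10x, bool sm11x) {
  if (sm10x || sm11x) {
    // Thor (sm11x) reuses the Sm100 arch tag: CUTLASS declares Sm101 but
    // ships no kernels for it, so there is no separate tag to instantiate.
    bf16bf16_grouped_gemm_impl_sm90_sm100<cutlass::arch::Sm100>();
  }
}
```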
@pytorchbot rebase

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict.

Successfully rebased: force-pushed 5cddc35 to c203e94, then c203e94 to 6c569b8.
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This PR enables special matmuls on Thor devices. This includes row-wise scaled matmul on fp8 and group gemm on bfloat16.
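For illustration, a hedged sketch of exercising the fp8 row-wise path through the C++ API on a Thor device; at::_scaled_mm's signature has shifted across releases, so the exact arguments below are assumptions rather than a canonical example:

```cpp
#include <ATen/ATen.h>

// Assumed usage: row-wise scaled fp8 matmul. scale_a is (M, 1) and scale_b
// is (1, N) in fp32, and mat2 must be column-major, hence the transpose.
int main() {
  auto opts = at::TensorOptions().device(at::kCUDA);
  auto a = at::randn({64, 128}, opts).to(at::kFloat8_e4m3fn);
  auto b = at::randn({32, 128}, opts).to(at::kFloat8_e4m3fn).t();
  auto scale_a = at::ones({64, 1}, opts);
  auto scale_b = at::ones({1, 32}, opts);
  auto out = at::_scaled_mm(a, b, scale_a, scale_b,
                            /*bias=*/{}, /*scale_result=*/{},
                            /*out_dtype=*/at::kBFloat16);
  return 0;
}
```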