[max/kernels] Fix MXFP4 dequant matmul on MI300X (CDNA3): use FP8 fnuz dtype #6474
Open
ramineroane wants to merge 1 commit into
Conversation
…z dtype

The `mxfp4_dequant_matmul_amd` and `mxfp4_dequant_grouped_matmul_amd` kernels both hardcode `DType.float8_e4m3fn`, which only works on CDNA4. On CDNA3 (MI300X / gfx942) the AMD MMA stdlib only matches `float8_e4m3fnuz`, so both kernels fail to compile/dispatch on that arch. Route the FP8 working dtype through `get_amd_fp8_dtype()` (already exported from `std.gpu.compute.mma` and already imported by the single-expert variant) so that the dequant buffer dtype, the activation cast dtype, and the MMA dispatch dtype agree on both CDNA3 and CDNA4.

Assisted-by: AI
CLA Assistant Lite bot: I have read the CLA Document and I hereby sign the CLA. You can retrigger this bot by commenting `recheck` in this Pull Request.

Author: I have read the CLA Document and I hereby sign the CLA

Author: recheck

Contributor: !sync
Summary
`mxfp4_dequant_matmul_amd` and `mxfp4_dequant_grouped_matmul_amd` both hardcode the FP8 working dtype to `DType.float8_e4m3fn` (the OCP type used on CDNA4 / MI355X). On CDNA3 / MI300X (gfx942), the AMD MFMA stdlib only matches `DType.float8_e4m3fnuz`, so any compile against an MI300X target fails with an MMA constraint error and the kernels are unusable on that arch today. This PR routes the FP8 dtype through `get_amd_fp8_dtype()` (already exported from `std.gpu.compute.mma` and already imported by the single-expert variant), restoring CDNA3 compatibility while keeping the existing CDNA4 behavior unchanged.
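The arch-dependent dtype policy described above can be modeled in a short sketch. This is a hypothetical Python stand-in for illustration only, not the Mojo stdlib: the dict, function names, and the toy `mma_dispatch` check are assumptions that mirror the behavior this PR describes.

```python
# Hypothetical model of the policy `get_amd_fp8_dtype()` encodes
# (FNUZ FP8 on CDNA3, OCP FP8 on CDNA4+) and of the MFMA-side
# constraint that rejects a mismatched FP8 dtype.

AMD_FP8_BY_ARCH = {
    "gfx942": "float8_e4m3fnuz",  # CDNA3 / MI300X
    "gfx950": "float8_e4m3fn",    # CDNA4 / MI355X (OCP)
}

def get_amd_fp8_dtype(arch: str) -> str:
    """Mirror of the stdlib helper's policy, as described in this PR."""
    return AMD_FP8_BY_ARCH[arch]

def mma_dispatch(arch: str, fp8_dtype: str) -> str:
    """Toy stand-in for the MFMA constraint check."""
    if fp8_dtype != get_amd_fp8_dtype(arch):
        raise ValueError(f"no MFMA match for {fp8_dtype} on {arch}")
    return f"mfma<{fp8_dtype}>"

# Hardcoding the OCP dtype (the pre-patch behavior) fails on gfx942 ...
try:
    mma_dispatch("gfx942", "float8_e4m3fn")
except ValueError as e:
    print(e)  # no MFMA match for float8_e4m3fn on gfx942

# ... while routing through the helper works on both arches.
for arch in AMD_FP8_BY_ARCH:
    print(mma_dispatch(arch, get_amd_fp8_dtype(arch)))
```

The key point the sketch captures: once the working dtype is looked up per-arch rather than hardcoded, the same call sites are valid on both CDNA3 and CDNA4.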
Reproducer
Build the upstream test for either kernel against `gfx942`:

```
# Any host with ROCm + MAX nightly + an MI300X (gfx942) target:
pixi run mojo build max/kernels/test/linalg/test_mxfp4_dequant_grouped_matmul_amd.mojo
```

Observed (today, on MI300X, MAX 26.3.0.dev*): the single-expert variant (`test_mxfp4_dequant_matmul_amd.mojo`) compiles because its FP8 GEMM goes through `_matmul_gpu`'s vendor-BLAS dispatch, but it then fails at runtime: hipBLASLt has no algorithm for OCP-FP8 inputs on CDNA3, the same root cause surfacing at a different layer.

After this patch both kernels build cleanly for `gfx942`. The grouped path dispatches into the existing CDNA3 FP8 MFMA tile, and the single-expert path produces FNUZ-typed buffers that hipBLASLt does support on CDNA3.
Root cause
The kernel docstring in `mxfp4_dequant_matmul_amd.mojo` already calls out the issue, and the file even imports `get_amd_fp8_dtype` from `std.gpu.compute.mma`. But the next line still hardcodes `comptime fp8_type = DType.float8_e4m3fn`, which only holds on CDNA4. The grouped variant makes the same assumption without the corresponding import. This looks like a small piece of an otherwise-CDNA4-targeted patch series that didn't quite make the round trip back to CDNA3.
Fix
Replace the hardcoded `DType.float8_e4m3fn` in both kernels with a call to the existing `get_amd_fp8_dtype()` helper from `std.gpu.compute.mma`. That helper already encodes the right policy (FNUZ on CDNA3, OCP on CDNA4+) and is what the AMD MMA stdlib itself uses to decide which intrinsics to emit, so the dequant buffer dtype, the activation-cast destination dtype, and the MMA dispatch dtype now agree by construction. The grouped variant gets a one-line import added; the single-expert variant already had the import.
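The "agree by construction" argument can be made concrete with a tiny sketch. This is illustrative Python, not the kernel's actual API; `kernel_dtypes` and its three sites are hypothetical names standing in for the dequant buffer, the activation cast, and the MMA dispatch.

```python
# Hypothetical sketch: deriving every FP8-typed site from one helper
# means the three dtypes cannot drift apart, on any arch.

def get_amd_fp8_dtype(arch: str) -> str:
    # Same policy as the stdlib helper described in this PR.
    return "float8_e4m3fnuz" if arch == "gfx942" else "float8_e4m3fn"

def kernel_dtypes(arch: str) -> dict:
    fp8 = get_amd_fp8_dtype(arch)  # single source of truth
    return {
        "dequant_buffer": fp8,   # dtype of the materialized weight buffer
        "activation_cast": fp8,  # dtype activations are cast to
        "mma_dispatch": fp8,     # dtype the MFMA tile is selected by
    }

for arch in ("gfx942", "gfx950"):
    d = kernel_dtypes(arch)
    assert len(set(d.values())) == 1  # all three sites agree by construction
    print(arch, d["mma_dispatch"])
```

The pre-patch bug is exactly the failure mode this pattern rules out: one site (the hardcoded constant) disagreeing with the dtype the MMA layer expects.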
The change is +3 / -2 lines across two files.

Testing
Locally I verified the patch builds cleanly against `gfx942` and that the grouped GEMM dispatches into the CDNA3 FP8 MFMA tile (the constraint failure no longer triggers). I am not in a position to run Modular's full internal correctness/perf harness, so I'd recommend that a maintainer rerun:

- `max/kernels/test/linalg/test_mxfp4_dequant_grouped_matmul_amd.mojo`
- `max/kernels/test/linalg/test_mxfp4_dequant_matmul_amd.mojo`

on both `gfx942` (MI300X) and `gfx950` (MI355X) to confirm correctness on both arches. No test changes should be needed: the tests already drive the kernels through the public entry points and the dtype change is internal.
Future work (out of scope for this PR)
Once the dispatch is functional on CDNA3, the bigger lever is fusing the dequant into the GEMM tile loop rather than materializing the full FP8 weight buffer first. In a quick standalone bench on MI300X I measured `dequant_mxfp4` running at ~0.6–0.7 TB/s (≈11–13% of the chip's ~5.3 TB/s peak HBM3) at typical MoE-expert shapes; in other words, the bandwidth-bound prologue is currently a meaningful fraction of total wall time, and the materialize-then-GEMM strategy is paying for those bytes twice. A fused "load packed → unpack/scale → MMA" tile would avoid the round trip and should close most of that gap on CDNA3, where there's no native `mfma.scale.f32.*.f8f6f4` intrinsic to lean on. Happy to follow up in a separate issue / PR if there's interest; this PR keeps things minimal and just unblocks the existing path on MI300X.
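For readers unfamiliar with the format being dequantized: OCP MXFP4 packs 32 E2M1 (4-bit) elements per block with one shared E8M0 power-of-two scale. The sketch below models the unpack/scale step a fused tile would perform in-register instead of writing a full FP8 buffer to HBM; the function names are illustrative, not the kernel's API, and the layout (two elements per byte, low nibble first) is an assumption for the example.

```python
# Hypothetical reference decode of one MXFP4 block (E2M1 elements, E8M0 scale).

def decode_e2m1(nibble: int) -> float:
    """Decode one 4-bit E2M1 value (1 sign bit, 2 exponent bits, 1 mantissa bit)."""
    sign = -1.0 if nibble & 0x8 else 1.0
    exp = (nibble >> 1) & 0x3
    man = nibble & 0x1
    if exp == 0:                      # subnormal: 0 or 0.5
        return sign * 0.5 * man
    return sign * (1.0 + 0.5 * man) * 2.0 ** (exp - 1)

def dequant_block(packed: bytes, scale_e8m0: int) -> list[float]:
    """Dequantize one MXFP4 block: two 4-bit elements per byte, one shared scale."""
    scale = 2.0 ** (scale_e8m0 - 127)  # E8M0: biased power-of-two exponent
    out = []
    for b in packed:
        out.append(decode_e2m1(b & 0xF) * scale)   # low nibble first (assumed)
        out.append(decode_e2m1(b >> 4) * scale)
    return out

# The positive E2M1 code points: 0, 0.5, 1, 1.5, 2, 3, 4, 6
print([decode_e2m1(n) for n in range(8)])

# Why materialize-then-GEMM hurts: at ~0.65 TB/s dequant throughput on a
# ~5.3 TB/s part, the prologue alone consumes a nontrivial slice of peak
# HBM bandwidth, and the FP8 buffer it writes is then re-read by the GEMM.
print(f"{0.65 / 5.3:.1%}")  # ~12.3%
```

A fused tile would run `decode_e2m1` and the scale multiply on registers between the packed load and the MMA, so the intermediate FP8 values never touch HBM.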
Checklist
- sequence of smaller PRs
- `./bazelw run format` to format my changes: I did not run bazel locally; the diff is +3 / -2 and stylistically identical to surrounding code, so I'm relying on CI to catch any formatter nits. Happy to rebase a formatted version on request.
- Tests: the existing `test_mxfp4_dequant_matmul_amd.mojo` and `test_mxfp4_dequant_grouped_matmul_amd.mojo` already exercise the affected code paths and should now pass on `gfx942`.
- `Assisted-by:` trailer in my commit message or this PR description (see AI Tool Use Policy)
Assisted-by: AI (per Modular AI Tool Use Policy). The diff (2 files, +3 / -2) was reviewed by the human author before submission; root cause analysis and validation were done on actual MI300X hardware.