[max/kernels] Fix MXFP4 dequant matmul on MI300X (CDNA3): use FP8 fnuz dtype#6474

Open
ramineroane wants to merge 1 commit into modular:main from ramineroane:fix/mi300x-mxfp4-fp8-fnuz-dtype

Conversation

@ramineroane

Summary

mxfp4_dequant_matmul_amd and mxfp4_dequant_grouped_matmul_amd both
hardcode the FP8 working dtype to DType.float8_e4m3fn (the OCP type used
on CDNA4 / MI355X). On CDNA3 / MI300X (gfx942), the AMD MFMA stdlib only
matches DType.float8_e4m3fnuz, so any compile against an MI300X target
fails with an MMA constraint error and the kernels are unusable on that
arch today. This PR routes the FP8 dtype through get_amd_fp8_dtype()
(already exported from std.gpu.compute.mma and already imported by the
single-expert variant), restoring CDNA3 compatibility while keeping the
existing CDNA4 behavior unchanged.

Reproducer

Build the upstream test for either kernel against gfx942:

# Any host with ROCm + MAX nightly + an MI300X (gfx942) target:
pixi run mojo build max/kernels/test/linalg/test_mxfp4_dequant_grouped_matmul_amd.mojo

Observed (today, on MI300X, MAX 26.3.0.dev*):

constraint failed: no valid implementation of mma for
  a=8xfloat8_e4m3fn, b=8xfloat8_e4m3fn, c=4xfloat32, and d=4xfloat32

The single-expert variant (test_mxfp4_dequant_matmul_amd.mojo) compiles
because its FP8 GEMM goes through _matmul_gpu's vendor-BLAS dispatch,
but it then fails at runtime with:

Unhandled exception caught during execution: No algorithm was found!

(hipBLASLt has no algorithm for OCP-FP8 inputs on CDNA3 — the same root
cause from a different layer.)

After this patch both kernels build cleanly for gfx942. The grouped path
dispatches into the existing CDNA3 FP8 MFMA tile, and the single-expert
path produces FNUZ-typed buffers that hipBLASLt does support on CDNA3.

Root cause

The kernel docstring in mxfp4_dequant_matmul_amd.mojo already calls out
the issue:

MI355X (CDNA4) uses float8_e4m3fn; MI300X (CDNA3) uses float8_e4m3fnuz.
The FP8 type is selected at compile time based on the target architecture.

…and the file even imports get_amd_fp8_dtype from std.gpu.compute.mma.
But the next line still hardcodes comptime fp8_type = DType.float8_e4m3fn,
which only holds on CDNA4. The grouped variant makes the same assumption
without the corresponding import. This looks like a small piece of an
otherwise-CDNA4-targeted patch series that didn't quite make the round trip
back to CDNA3.
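
The policy that get_amd_fp8_dtype() encodes is small enough to model in plain Python. The sketch below is purely illustrative — the dict, function name, and error handling are invented here, not taken from the Mojo stdlib — but the gfx targets and dtype names are the real ones from the error above:

```python
# Hypothetical Python model of the policy get_amd_fp8_dtype() encodes in
# std.gpu.compute.mma -- NOT the actual Mojo implementation.
FP8_DTYPE_BY_ARCH = {
    "gfx942": "float8_e4m3fnuz",  # CDNA3 / MI300X: MFMA only matches FNUZ FP8
    "gfx950": "float8_e4m3fn",    # CDNA4 / MI355X: OCP FP8
}

def amd_fp8_dtype(arch: str) -> str:
    """Pick the FP8 working dtype for a given AMD GPU arch."""
    try:
        return FP8_DTYPE_BY_ARCH[arch]
    except KeyError:
        raise ValueError(f"no FP8 MFMA support modeled for {arch}")

# Hardcoding float8_e4m3fn (the bug) only agrees with the CDNA4 row:
assert amd_fp8_dtype("gfx950") == "float8_e4m3fn"
assert amd_fp8_dtype("gfx942") == "float8_e4m3fnuz"  # != float8_e4m3fn
```

Any call site that freezes the dtype to the gfx950 answer, as the two kernels did, is guaranteed to miss the gfx942 row.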

Fix

Replace the hardcoded DType.float8_e4m3fn in both kernels with a call to
the existing get_amd_fp8_dtype() helper from std.gpu.compute.mma. That
helper already encodes the right policy (FNUZ on CDNA3, OCP on CDNA4+) and
is what the AMD MMA stdlib itself uses to decide which intrinsics to emit,
so the dequant buffer dtype, the activation-cast destination dtype, and
the MMA dispatch dtype now agree by construction. The grouped variant gets
a one-line import added; the single-expert variant already had the import.
The change is +3 / -2 lines across two files.
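
"Agree by construction" can be made concrete with a small Python sketch: every dtype consumer calls one helper instead of hardcoding. The consumer names below are hypothetical labels for the three call sites described above; the real code is Mojo and resolves this at compile time:

```python
# Illustrative sketch (invented names): one helper feeds every dtype
# consumer, so the three dtypes cannot drift apart on any arch.
def amd_fp8_dtype(arch: str) -> str:
    return "float8_e4m3fnuz" if arch == "gfx942" else "float8_e4m3fn"

def kernel_dtypes(arch: str) -> dict:
    fp8 = amd_fp8_dtype(arch)       # single source of truth
    return {
        "dequant_buffer": fp8,      # MXFP4 -> FP8 dequant output
        "activation_cast": fp8,     # activation cast destination
        "mma_dispatch": fp8,        # dtype the MFMA stdlib matches on
    }

for arch in ("gfx942", "gfx950"):
    assert len(set(kernel_dtypes(arch).values())) == 1  # all three agree
```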

Testing

Locally I verified the patch builds cleanly against gfx942 and that the
grouped GEMM dispatches into the CDNA3 FP8 MFMA tile (the constraint
failure no longer triggers). I am not in a position to run Modular's full
internal correctness/perf harness, so I'd recommend that a maintainer
rerun:

  • max/kernels/test/linalg/test_mxfp4_dequant_grouped_matmul_amd.mojo
  • max/kernels/test/linalg/test_mxfp4_dequant_matmul_amd.mojo

on both gfx942 (MI300X) and gfx950 (MI355X) to confirm correctness on
both arches. No test changes should be needed — the tests already drive
the kernels through the public entry points and the dtype change is
internal.

Future work (out of scope for this PR)

Once the dispatch is functional on CDNA3, the bigger lever is fusing the
dequant into the GEMM tile loop rather than materializing the full FP8
weight buffer first. In a quick standalone bench on MI300X I measured
dequant_mxfp4 running at ~0.6–0.7 TB/s (≈12–14% of the chip's ~5.3 TB/s
peak HBM3) at typical MoE-expert shapes — i.e., the bandwidth-bound
prologue is currently a meaningful fraction of total wall time and the
materialize-then-GEMM strategy is paying for those bytes twice. A fused
"load packed → unpack/scale → MMA" tile would avoid the round trip and
should close most of that gap on CDNA3, where there's no native
mfma.scale.f32.*.f8f6f4 intrinsic to lean on. Happy to follow up in a
separate issue / PR if there's interest; this PR keeps things minimal and
just unblocks the existing path on MI300X.
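
A back-of-the-envelope model shows why materialize-then-GEMM pays for the FP8 weight bytes twice. The expert weight shape below is a hypothetical example (not from the PR); the throughput numbers are the ones reported above:

```python
# Illustrative arithmetic only: shape is a plausible MoE expert weight,
# not a measured configuration from this PR.
K, N = 4096, 14336                  # hypothetical expert weight shape
fp8_bytes = K * N                   # materialized FP8 buffer, 1 B/elt

# Materialize-then-GEMM: dequant writes the buffer, the GEMM reads it
# back, so a fused kernel would skip both trips through HBM.
extra_traffic = 2 * fp8_bytes

dequant_tbps = 0.65                 # reported dequant throughput, TB/s
peak_tbps = 5.3                     # reported MI300X peak HBM3, TB/s
print(f"extra HBM traffic per expert: {extra_traffic / 2**20:.0f} MiB")
print(f"dequant runs at {dequant_tbps / peak_tbps:.0%} of peak")
```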

Checklist

  • PR is small and focused — consider splitting larger changes into a
    sequence of smaller PRs
  • I ran ./bazelw run format to format my changes — I did not run
    bazel locally; the diff is +3 / -2 and stylistically identical to
    surrounding code, so I'm relying on CI to catch any formatter nits.
    Happy to rebase a formatted version on request.
  • I added or updated tests to cover my changes — no new tests; the
    existing test_mxfp4_dequant_matmul_amd.mojo and
    test_mxfp4_dequant_grouped_matmul_amd.mojo already exercise the
    affected code paths and should now pass on gfx942.
  • If AI tools assisted with this contribution, I have included an
    Assisted-by: trailer in my commit message or this PR description
    (see AI Tool Use Policy)

Assisted-by: AI (per Modular AI Tool Use Policy). The diff (2 files, +3 / -2) was reviewed by the human author before submission; root cause analysis and validation were done on actual MI300X hardware.

…z dtype

The mxfp4_dequant_matmul_amd and mxfp4_dequant_grouped_matmul_amd kernels
both hardcode `DType.float8_e4m3fn`, which only works on CDNA4. On CDNA3
(MI300X / gfx942) the AMD MMA stdlib only matches `float8_e4m3fnuz`, so
both kernels fail to compile/dispatch on that arch.

Route the FP8 working dtype through `get_amd_fp8_dtype()` (already
exported from `std.gpu.compute.mma` and already imported by the
single-expert variant) so the dequant buffer dtype, the activation cast
dtype, and the MMA dispatch dtype agree on both CDNA3 and CDNA4.

Assisted-by: AI
@ramineroane ramineroane requested a review from a team as a code owner April 30, 2026 05:11
@github-actions


Thank you for your submission, we really appreciate it. Like many open-source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution. You can sign the CLA by posting a Pull Request comment in the format below.


I have read the CLA Document and I hereby sign the CLA


You can retrigger this bot by commenting recheck in this Pull Request. Posted by the CLA Assistant Lite bot.

@ramineroane
Author

I have read the CLA Document and I hereby sign the CLA

@ramineroane
Author

recheck

@abduld
Contributor

abduld commented May 13, 2026

!sync

@modularbot modularbot added the imported-internally Signals that a given pull request has been imported internally. label May 13, 2026

Labels

imported-internally Signals that a given pull request has been imported internally. waiting-on-review


4 participants