feat: Add FP4 (E2M1) KV Cache Support for MHA #12612
Open
Summary
This PR introduces support for FP4 (`float4_e2m1fn_x2`) KV caching in Multi-Head Attention (MHA) models, e.g., Qwen and GPT-OSS. See #10083, points 1-2, for more context.
Co-authored-by: Yichen Wang (@yicwang) <yichen.wang@bytedance.com>
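For intuition, the sketch below shows how E2M1 values plus a shared scale approximate higher-precision KV entries, and how two 4-bit codes pack into a single byte (the `_x2` in `float4_e2m1fn_x2`). It is a minimal NumPy illustration, not the PR's actual kernel code; the 16-element block size and the rounding rule are simplifications.

```python
import numpy as np

# The eight non-negative magnitudes representable by E2M1 (1 sign, 2 exponent, 1 mantissa bit).
E2M1_LEVELS = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32)

def quantize_e2m1(block: np.ndarray):
    """Quantize a 1-D float block to 4-bit sign/magnitude codes plus one scale."""
    scale = float(np.abs(block).max()) / 6.0 + 1e-12     # map the block into E2M1's [-6, 6] range
    scaled = block / scale
    # Nearest representable magnitude for each element; sign stored separately in bit 3.
    mag_idx = np.abs(np.abs(scaled)[:, None] - E2M1_LEVELS[None, :]).argmin(axis=1)
    codes = mag_idx.astype(np.uint8) | ((scaled < 0).astype(np.uint8) << 3)
    return codes, scale

def dequantize_e2m1(codes: np.ndarray, scale: float) -> np.ndarray:
    sign = np.where(codes & 0x8, -1.0, 1.0)
    return sign * E2M1_LEVELS[codes & 0x7] * scale

kv_block = np.random.randn(16).astype(np.float32)       # e.g., 16 contiguous K or V values
codes, scale = quantize_e2m1(kv_block)
packed = (codes[0::2] << 4) | codes[1::2]               # two E2M1 codes per byte -> the "_x2" packing
print("packed bytes:", packed.nbytes,
      "max abs error:", np.abs(dequantize_e2m1(codes, scale) - kv_block).max())
```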
Usage
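The exact launch flags are not spelled out here, so the snippet below is only a hypothetical sketch using the offline `Engine` API: the `kv_cache_dtype` string `"fp4_e2m1"` is an assumed name for the new FP4 (E2M1) KV dtype, and the model path is just an example; `attention_backend="torch_native"` matches the backend used for the throughput reference below.

```python
# Hypothetical usage sketch; the kv_cache_dtype value "fp4_e2m1" is an assumption,
# not a confirmed flag value from this PR.
import sglang as sgl

llm = sgl.Engine(
    model_path="Qwen/Qwen3-235B-A22B",   # one of the models covered by the accuracy tests
    kv_cache_dtype="fp4_e2m1",           # assumed name of the FP4 (E2M1) KV cache dtype
    attention_backend="torch_native",    # backend used for the throughput reference below
)
print(llm.generate("Hello, my name is", {"max_new_tokens": 16}))
llm.shutdown()
```

The server-mode equivalent would pass the same value through `--kv-cache-dtype` on `sglang.launch_server`.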
Motivation and Benefits
Large models are often constrained by GPU memory when storing the KV cache. By introducing FP4 quantization with scale buffers, this PR significantly reduces KV cache memory usage, freeing capacity for longer contexts or larger batch sizes.
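As a rough back-of-the-envelope example (the scale granularity is an assumption, not taken from this PR): a BF16 KV element takes 2 bytes and an FP8 element 1 byte, while an FP4 (E2M1) element takes 0.5 byte plus scale overhead. With, say, one FP8 scale per 16-element block, that is 0.5 + 1/16 ≈ 0.56 bytes per element, i.e. roughly 3.6x less KV memory than BF16 and about 1.8x less than FP8.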
Key Changes
Accuracy tests for KV4 MHA
Qwen3-235B-A22B
gpt-oss-120b
Observation:
Performance Results
Although speed is not the main goal of this PR (it will be addressed in #10083, point 3-2), we ran throughput tests using `torch_native` to provide a reference.

Reason for `torch_native`:
- Other backends (e.g., `trtllm_mha`, Triton attention) have fused kernels for FP8 only, making FP8 faster there.
- KV8 lacks a fused kernel on `torch_native`, so both KV4 and KV8 are measured on the same backend.

Test configuration (a reproduction sketch follows the list):
- `--num-prompts`: 100–400
- `--max-concurrency`: 50–200
- Unit: output token throughput (tok/s)
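The runs can be reproduced with a sketch like the one below once a server with the FP4 KV cache is up (see the usage sketch above). The `sglang.bench_serving` entry point and `--backend sglang` are assumptions based on SGLang's standard benchmarking flow; only `--num-prompts` and `--max-concurrency` come from the configuration above, and the two (prompts, concurrency) pairs are just the endpoints of those ranges.

```python
# Reproduction sketch: sweep the endpoints of the configuration ranges above
# against an already-running server (assumed default host/port).
import subprocess

for num_prompts, max_concurrency in [(100, 50), (400, 200)]:
    subprocess.run(
        [
            "python", "-m", "sglang.bench_serving",
            "--backend", "sglang",
            "--num-prompts", str(num_prompts),
            "--max-concurrency", str(max_concurrency),
        ],
        check=True,
    )
```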
Observation:
Checklist