
@JackChuang

Summary

This PR introduces support for FP4 (float4_e2m1fn_x2) KV caching in Multi-Head Attention (MHA) models, e.g., Qwen and GPT-OSS. See #10083, points 1-2, for more context.

Co-authored-by: Yichen Wang (@yicwang) <yichen.wang@bytedance.com>

Usage

$ python3 -m sglang.launch_server --kv-cache-dtype fp4_e2m1 ... 

Motivation and Benefits

Large models often face GPU memory constraints when storing the KV cache.
By introducing FP4 quantization with scale buffers, this PR significantly reduces KV memory usage and improves efficiency:

  • Supports significantly more tokens than KV8 (≈1.78×) and KV16 (≈3.56×) thanks to FP4 quantization with block_size = 16 (see the footprint sketch after this list).
  • Improves scalability for longer context windows and throughput for large-batch requests.
  • Enables inference of larger models or longer context windows on memory-limited GPUs.
  • Integrates seamlessly with existing inference pipelines without breaking KV16/KV8 workflows.
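
For reference, here is a back-of-the-envelope sketch of the per-element footprint behind the ≈1.78× / ≈3.56× figures. It assumes one 1-byte scale per 16-element block, which reproduces the quoted ratios; the actual buffer layout is defined by MHATokenToKVPool and the k_scale_buffer / v_scale_buffer introduced in this PR.

```python
# Back-of-the-envelope KV-cache footprint per element (illustrative only; the
# real layout lives in MHATokenToKVPool plus the FP4 scale buffers).
FP4_BITS = 4          # float4_e2m1fn_x2 packs two 4-bit values per uint8
SCALE_BITS = 8        # assumption: one 1-byte scale per quantization block
BLOCK_SIZE = 16       # block_size used for FP4 quantization in this PR

kv4_bits = FP4_BITS + SCALE_BITS / BLOCK_SIZE   # 4.5 bits per element
kv8_bits = 8
kv16_bits = 16

print(f"tokens vs KV8 : {kv8_bits / kv4_bits:.2f}x")    # ~1.78x
print(f"tokens vs KV16: {kv16_bits / kv4_bits:.2f}x")   # ~3.56x
```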

Key Changes

  • MHATokenToKVPool
    • Added FP4 KV cache support with uint8 storage format.
    • Introduced k_scale_buffer and v_scale_buffer for per-block scaling factors.
    • Integrated batched quantization (on update) and dequantization (on access) using KVFP4QuantizeUtil (a conceptual sketch follows this list).
  • ModelRunner
    • Updated GPU memory estimation logic to account for FP4 cache and scale buffers.
  • Compatibility
    • Preserves existing FP16/FP8 KV cache behavior without changes.
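
To make the per-block scheme concrete, below is a minimal, self-contained sketch of absmax quantization and dequantization over 16-element blocks. It is illustrative only: it does not call the actual KVFP4QuantizeUtil kernels, it stores one 4-bit code per uint8 instead of packing two codes per byte as float4_e2m1fn_x2 does, and the helpers quantize_blocks / dequantize_blocks are hypothetical names.

```python
import torch

BLOCK_SIZE = 16
# Representable magnitudes of the e2m1 (FP4) format; the sign is a separate bit.
E2M1_MAGS = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])


def quantize_blocks(x: torch.Tensor):
    """Hypothetical absmax quantization of 16-element blocks to 4-bit codes.

    Returns (codes, scales): each uint8 code holds a sign bit plus a 3-bit index
    into E2M1_MAGS (a real FP4 kernel would pack two such codes per byte).
    """
    blocks = x.float().reshape(-1, BLOCK_SIZE)
    # One scale per block so the largest magnitude maps onto the top FP4 value (6.0).
    scales = blocks.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12) / E2M1_MAGS[-1]
    scaled = blocks / scales
    # Round each magnitude to the nearest representable e2m1 value.
    idx = (scaled.abs().unsqueeze(-1) - E2M1_MAGS).abs().argmin(dim=-1)
    sign = (scaled < 0).to(torch.uint8)
    codes = (sign * 8 + idx.to(torch.uint8)).to(torch.uint8)
    return codes, scales.squeeze(-1)


def dequantize_blocks(codes: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    """Inverse of quantize_blocks: look up magnitudes, reapply sign and scale."""
    mags = E2M1_MAGS[(codes & 7).long()]
    signs = 1.0 - 2.0 * (codes >= 8).float()
    return signs * mags * scales.unsqueeze(-1)


# Round-trip check on random data shaped like (num_blocks, BLOCK_SIZE).
x = torch.randn(4, BLOCK_SIZE)
codes, scales = quantize_blocks(x)
x_hat = dequantize_blocks(codes, scales)
print("max abs error:", (x - x_hat).abs().max().item())
```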

Accuracy Tests for KV4 MHA

  • FP4 KV cache is well-suited for large-scale models, providing memory savings with minimal accuracy impact.
  • For smaller models, careful evaluation is needed to balance memory efficiency and accuracy.

Qwen3-235B-A22B

| Model | Dataset | Metric | Subset | Num | Score | Cat.0 |
|---|---|---|---|---|---|---|
| KV4 (fp4_e2m1) | gsm8k | mean_acc | main | 6595 | 0.9186 | default |
| KV4 (fp4_e2m1) | aime25 | mean_acc | OVERALL | 150 | 0.6 | - |
| KV4 (fp4_e2m1) | gpqa_diamond | mean_acc | default | 990 | 0.6778 | default |
| KV8 (fp8_e4m3) | gsm8k | mean_acc | main | 6595 | 0.9181 | default |
| KV8 (fp8_e4m3) | aime25 | mean_acc | OVERALL | 150 | 0.7333 | - |
| KV8 (fp8_e4m3) | gpqa_diamond | mean_acc | default | 990 | 0.6899 | default |
| KV16 | gsm8k | mean_acc | main | 6595 | 0.9168 | default |
| KV16 | aime25 | mean_acc | OVERALL | 150 | 0.7733 | - |
| KV16 | gpqa_diamond | mean_acc | default | 990 | 0.701 | default |

gpt-oss-120b

| Model | Dataset | Metric | Subset | Num | Score | Cat.0 |
|---|---|---|---|---|---|---|
| KV4 (fp4_e2m1) | aime25 | mean_acc | OVERALL | 150 | 0.3533 | - |
| KV4 (fp4_e2m1) | gsm8k | mean_acc | main | 6595 | 0.9152 | default |
| KV4 (fp4_e2m1) | gpqa_diamond | mean_acc | default | 990 | 0.3202 | default |
| KV8 (fp8_e4m3) | aime25 | mean_acc | OVERALL | 150 | 0.7667 | - |
| KV8 (fp8_e4m3) | gsm8k | mean_acc | main | 6595 | 0.9163 | default |
| KV8 (fp8_e4m3) | gpqa_diamond | mean_acc | default | 990 | 0.5434 | default |
| KV16 | aime25 | mean_acc | OVERALL | 150 | 0.7533 | - |
| KV16 | gsm8k | mean_acc | main | 6595 | 0.9161 | default |
| KV16 | gpqa_diamond | mean_acc | default | 990 | 0.5081 | default |

Observations:

  • On large models (Qwen3-235B-A22B), FP4 maintains accuracy close to FP8/FP16.  
  • On smaller models (gpt-oss-120b), FP4 shows more pronounced accuracy drops on difficult datasets.  
  • Trend: Accuracy degradation is more significant in smaller models.

Performance Results

Although speed is not the main goal of this PR (it will be addressed in #10083, point 3-2), we ran throughput tests with the torch_native attention backend for reference:

  • Reason for choosing torch_native:
      - Other backends (e.g., trtllm_mha, Triton attention) have fused kernels for FP8 only, making FP8 faster there.
      - KV8 lacks a fused kernel on torch_native, so KV4 and KV8 are measured on the same backend.

    Note: KV8 originally could not run when the attention backend was set to torch_native; this has been fixed in PR #12596 (Support kv8 (FP8) with torch_native attention backend).

  • Test configuration:  
      - --num-prompts: 100–400  
      - --max-concurrency: 50–200  
      - Unit: Output token throughput (tok/s)

| Num Prompts | Concurrency | KV8 (tok/s) | KV4 (tok/s) | Gain | TTFT (ms) | TPOT (ms) |
|---|---|---|---|---|---|---|
| 100 | 50 | 62.43 | 60.35 | -3.33% | 5323 | 798 |
| 200 | 100 | 67.34 | 68.02 | +1.0% | 9378 | 1480 |
| 300 | 150 | 68.81 | 71.63 | +4.1% | 13500 | 2172 |
| 400 | 200 | 69.75 | 74.19 | +6.36% | 19595 | 2685 |

Observations:

  • KV4 starts slightly behind KV8 at low concurrency (-3.33% at 100 prompts / 50 concurrency) but overtakes it as load increases, reaching +6.36% at 400 prompts / 200 concurrency.

Checklist

Based on PR sgl-project#10078, this patch
- introduces FP4 KV cache support in MHATokenToKVPool with uint8 storage.
- adds k_scale_buffer and v_scale_buffer to store FP4 scaling factors.
- implements batched quantization on cache update and dequantization on access.
- updates ModelRunner memory estimation to account for FP4 scale buffers.
- maintains backward compatibility with FP16/FP8 KV cache.

Signed-off-by: Ho-Ren (Jack) Chuang <horenchuang@bytedance.com>
Co-authored-by: Yichen Wang <yichen.wang@bytedance.com>

@JackChuang

Hi @Fridge003 @AniZpZ @zhyncs,
Thank you very much for helping review and merge the PR for MLA KV4 (#10078).
Could you please help review this PR for MHA KV4? Thank you!
