Conversation

@czhu-cohere (Contributor) commented Aug 19, 2025

Purpose

Add support in vLLM for a CUTLASS-based W4A8 kernel on Hopper; see CUTLASS example 55, which uses a LUT trick to bypass the int4 -> bf16 -> fp8 conversion in the GEMM mainloop. This improves compute-bound performance and allows W4A8 to approach peak FP8 throughput while still maintaining the fast decoding speed of W4A16.

The kernel performs the computation

out = (w_q.to(scale_type) * w_s) @ a * s_a * s_c

where

  • out is the output matrix with type bf16
  • w_q is the packed quantized weight matrix with type int4
  • scale_type is fp8 e4m3
  • w_s is the packed scales with type fp8 e4m3 and group size 128
  • a is the (dynamically quantized) activations with type fp8 e4m3
  • s_a is the per-token activation scales with type fp32
  • s_c is the per-channel scales with type fp32

and the per-token/per-channel scaling is done in the epilogue. Note that zero points, activation reordering, and smaller group sizes are not supported yet.
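
For reference, a minimal unfused PyTorch sketch of the same math (the shapes and the unpacked int4 representation here are simplifications for clarity; the real kernel consumes packed, pre-shuffled weights and scales):

    import torch

    def w4a8_reference(w_q, w_s, a, s_a, s_c, group_size=128):
        # w_q: [N, K] signed int4 stored as int8, w_s: [N, K // group_size] fp8 e4m3
        # a:   [M, K] fp8 e4m3, s_a: [M, 1] fp32 per-token, s_c: [N] fp32 per-channel
        # int4 -> fp8 e4m3 is exact for values in [-8, 7], so casting straight to fp32
        # matches the kernel's LUT-based int4 -> fp8 conversion.
        w = w_q.to(torch.float32)
        w = w * w_s.to(torch.float32).repeat_interleave(group_size, dim=1)
        out = a.to(torch.float32) @ w.t()          # [M, N]
        return (out * s_a * s_c).to(torch.bfloat16)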

There are additional requirements on the layout/encoding of the scales and weights, which are handled by two helper routines, cutlass_pack_scale_fp8 and cutlass_encode_and_reorder_int4b. The original weights are also expected to be encoded as signed int4, which is notably different from the commonly used int4b8 encoding (though the two can be converted losslessly; more on that in the Test Plan section).

Kernel

The main file is w4a8_mm_entry.cu which implements

  • prepack/weight shuffling routines for scales and weights
  • CUTLASS kernel templates/instantiations
  • dispatch logic based on problem shape

The heuristic used in mm_dispatch was distilled from a sweep over various tile/cluster shapes and problem shapes taken from open-source models such as Llama 8B/70B/405B. In aggregate, this heuristic achieves performance within ~1-2% of the best config for each problem shape tested.
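
For intuition, the dispatch boils down to choosing a tile/cluster configuration from the GEMM problem shape; the sketch below is purely illustrative, with made-up thresholds and config names (the real heuristic lives in mm_dispatch inside w4a8_mm_entry.cu and came from the tuning sweep):

    # Illustrative only: thresholds and config names are hypothetical.
    def select_config(m: int, n: int, k: int) -> str:
        if m <= 16:            # decode-like batches: small M tiles, latency-oriented
            return "small_m_config"
        if m <= 256:           # medium batches
            return "medium_m_config"
        return "large_m_config"  # large prefill batches: throughput-oriented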

The new registered torch ops are

cutlass_w4a8_mm
cutlass_pack_scale_fp8
cutlass_encode_and_reorder_int4b

along with their fake variants for cudagraph.
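
As a rough sketch of how these compose from Python (the torch.ops._C namespace and the argument order here are assumptions; w4a8_mm_entry.cu and the kernel test are authoritative):

    import torch

    def w4a8_forward(w_q, w_s, a_fp8, s_a, s_c):
        # Hypothetical call sequence; namespace and argument order are assumptions.
        # The two prepack ops would normally run once at weight-loading time.
        w_q_packed = torch.ops._C.cutlass_encode_and_reorder_int4b(w_q)
        w_s_packed = torch.ops._C.cutlass_pack_scale_fp8(w_s)
        return torch.ops._C.cutlass_w4a8_mm(a_fp8, w_q_packed, w_s_packed, s_a, s_c)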

vLLM Frontend

An example quantization config that triggers W4A8:

  "quantization_config": {
    "config_groups": {
      "group_0": {
        "input_activations": {
          "actorder": null,
          "block_structure": null,
          "dynamic": true,
          "group_size": null,
          "num_bits": 8,
          "observer": null,
          "observer_kwargs": {},
          "strategy": "token",
          "symmetric": true,
          "type": "float"
        },
        "output_activations": null,
        "targets": [
          "Linear"
        ],
        "weights": {
          "actorder": "weight",
          "block_structure": null,
          "dynamic": false,
          "group_size": 128,
          "num_bits": 4,
          "observer": "minmax",
          "observer_kwargs": {},
          "strategy": "group",
          "symmetric": true,
          "type": "int"
        }
      }
    },
    "format": "pack-quantized",
    "global_compression_ratio": null,
    "ignore": [
      "lm_head"
    ],
    "kv_cache_scheme": null,
    "quant_method": "compressed-tensors",
    "quantization_status": "compressed"
  },

Basically, the weights are pack-quantized (eight 4-bit values packed into each int32) with group size 128, and the activations are dynamically quantized to 8 bits (fp8 e4m3) per token. Both weight and activation quantization are symmetric.
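
For illustration, a sketch of what "pack-quantized" means here; the lowest-nibble-first ordering inside each int32 is an assumption, and compressed-tensors defines the authoritative layout:

    import torch

    def unpack_int4_from_int32(packed: torch.Tensor) -> torch.Tensor:
        """Unpack eight 4-bit fields per int32 into signed int4 values in [-8, 7]."""
        shifts = torch.arange(0, 32, 4, dtype=torch.int32, device=packed.device)
        nibbles = (packed.unsqueeze(-1) >> shifts) & 0xF            # [N, K/8, 8] if packed is [N, K/8]
        signed = torch.where(nibbles >= 8, nibbles - 16, nibbles)   # sign-extend 0..15 -> -8..7
        return signed.flatten(-2).to(torch.int8)                    # [N, K]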

_is_fp8_w4a8_sm90 checks that the config and device are compatible with W4A8 and, if so, returns the CompressedTensorsW4A8Fp8 scheme.

vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py
implements CutlassW4A8LinearKernel, which wraps the per-token activation quantization plus the w4a8 op and calls the pre-processing routines.
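
Conceptually, the apply path looks roughly like the sketch below. The layer attribute names are placeholders and the w4a8 op signature is assumed as above; scaled_fp8_quant is vLLM's existing dynamic fp8 quantization helper.

    import torch
    from vllm import _custom_ops as ops

    def apply_w4a8(layer, x: torch.Tensor) -> torch.Tensor:
        # 1) dynamic per-token fp8 quantization of the activations
        x_fp8, s_a = ops.scaled_fp8_quant(x, use_per_token_if_dynamic=True)
        # 2) W4A8 GEMM with fp8 group scales and per-token/per-channel epilogue scaling
        return torch.ops._C.cutlass_w4a8_mm(
            x_fp8, layer.weight_packed, layer.weight_scale_packed, s_a, layer.channel_scale)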

Test Plan

Kernel

pytest tests/kernels/quantization/test_cutlass_w4a8.py - tests kernel correctness and cudagraph capture. Note that we use fp8 with fast accumulation for the reference computation, as suggested in upstream CUTLASS.
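
(As a rough sketch of what such a reference can look like, using torch._scaled_mm with fast accumulation; the actual reference in the test may be structured differently.)

    import torch

    def fp8_fast_accum_ref(a_fp8: torch.Tensor, w_fp8_colmajor: torch.Tensor) -> torch.Tensor:
        # a_fp8: [M, K] row-major fp8 e4m3; w_fp8_colmajor: [K, N] column-major fp8 e4m3.
        # Tensor-wise unit scales here; per-token/per-channel scales can be applied afterward.
        one = torch.tensor(1.0, dtype=torch.float32, device=a_fp8.device)
        return torch._scaled_mm(a_fp8, w_fp8_colmajor, scale_a=one, scale_b=one,
                                out_dtype=torch.bfloat16, use_fast_accum=True)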

Perf benchmark against W4A16 (machete), FP16, and CUTLASS FP8:

python3 kernels/benchmark_machete.py --act-type float8_e4m3fn --group-scale-type float16 --out-type float16 --channel-scale-type float --token-scale-type float model_bench

E2E

baseline: W4A16 checkpoint of CohereLabs/c4ai-command-a-03-2025 (111B dense model)
Generate the W4A8 checkpoint by (a sketch follows this list):

  • repacking int4b8 as signed int4
  • casting the fp16 scales to fp8
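
A hedged sketch of that conversion, assuming int4b8 stores value + 8 in each nibble (packing/unpacking helpers are elided):

    import torch

    def int4b8_to_signed_int4(nibbles_b8: torch.Tensor) -> torch.Tensor:
        # nibbles_b8: unpacked 4-bit fields in [0, 15], encoding value + 8 (int4b8)
        return nibbles_b8.to(torch.int8) - 8       # lossless: back to [-8, 7]

    def fp16_scales_to_fp8(w_s_fp16: torch.Tensor) -> torch.Tensor:
        # cast the fp16 group scales to fp8 e4m3, which is what the kernel expects
        return w_s_fp16.to(torch.float8_e4m3fn)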

e2e perf: run the serving benchmark
e2e quality: run lm-eval on gsm8k as a sanity check

More detailed perf/quality evals are pending.

Test Result

pytest tests/kernels/quantization/test_cutlass_w4a8.py - pass

Kernel perf benchmarks (lower is better):

| shape                       | torch.matmul (fp16) | cutlass_scaled_mm (fp8) | machete w4a8 | cutlass w4a8 |
|-----------------------------|---------------------|-------------------------|--------------|--------------|
| MKN=(1x12288x14336), L=2    | 270.6               | 129                     | 88.3         | 94.5         |
| MKN=(1x12288x12288), L=2    | 209                 | 114.4                   | 76           | 82.3         |
| MKN=(1x12288x73728), L=1    | 591.8               | 298.2                   | 172.4        | 196.3        |
| MKN=(1x36864x12288), L=1    | 303.2               | 155.2                   | 98.2         | 104.2        |
| MKN=(32x12288x14336), L=2   | 240.4               | 141.7                   | 97.3         | 90.5         |
| MKN=(32x12288x12288), L=2   | 212                 | 126.9                   | 87.7         | 79.7         |
| MKN=(32x12288x73728), L=1   | 601.8               | 314.6                   | 173.7        | 200.4        |
| MKN=(32x36864x12288), L=1   | 302                 | 173.9                   | 96.3         | 105.3        |
| MKN=(64x12288x14336), L=2   | 241.1               | 134.3                   | 93.4         | 86.6         |
| MKN=(64x12288x12288), L=2   | 207                 | 117.5                   | 89.4         | 84.3         |
| MKN=(64x12288x73728), L=1   | 608.1               | 316.2                   | 207.5        | 213.4        |
| MKN=(64x36864x12288), L=1   | 307.3               | 160.2                   | 120.1        | 112.8        |
| MKN=(128x12288x14336), L=2  | 245.2               | 142.9                   | 122.2        | 106.5        |
| MKN=(128x12288x12288), L=2  | 212.5               | 131.4                   | 108.3        | 100.6        |
| MKN=(128x12288x73728), L=1  | 617.8               | 333.8                   | 246.5        | 267.3        |
| MKN=(128x36864x12288), L=1  | 320.3               | 179.7                   | 129.5        | 140.9        |
| MKN=(256x12288x14336), L=2  | 268.8               | 170.8                   | 214.1        | 147.1        |
| MKN=(256x12288x12288), L=2  | 219.3               | 159.3                   | 177.4        | 139.9        |
| MKN=(256x12288x73728), L=1  | 660.2               | 399.1                   | 497          | 377.7        |
| MKN=(256x36864x12288), L=1  | 321.4               | 228.2                   | 239.8        | 196.8        |
| MKN=(512x12288x14336), L=2  | 509.8               | 308.5                   | 382          | 285.4        |
| MKN=(512x12288x12288), L=2  | 417.8               | 240.3                   | 327.3        | 272.3        |
| MKN=(512x12288x73728), L=1  | 1216.9              | 709.2                   | 949.7        | 656.9        |
| MKN=(512x36864x12288), L=1  | 614.1               | 338                     | 468.4        | 386.8        |
| MKN=(1024x12288x14336), L=2 | 982.2               | 560.9                   | 741.9        | 568.2        |
| MKN=(1024x12288x12288), L=2 | 788.6               | 485.3                   | 635.9        | 448.4        |
| MKN=(1024x12288x73728), L=1 | 2381.5              | 1526.2                  | 1846.2       | 1354.2       |
| MKN=(1024x36864x12288), L=1 | 1183.9              | 722.5                   | 932.6        | 615          |
| MKN=(2048x12288x14336), L=2 | 1837.6              | 1150.3                  | 1495.7       | 1033.9       |
| MKN=(2048x12288x12288), L=2 | 1581.2              | 966.3                   | 1283.9       | 893.2        |
| MKN=(2048x12288x73728), L=1 | 4798.4              | 3196.7                  | 3763.7       | 2533.5       |
| MKN=(2048x36864x12288), L=1 | 2346.4              | 1419.3                  | 1861.6       | 1246.7       |
| MKN=(4096x12288x14336), L=2 | 3740.3              | 2501.6                  | 3009.7       | 2161         |
| MKN=(4096x12288x12288), L=2 | 3159.7              | 2098.4                  | 2561.5       | 1831.6       |
| MKN=(4096x12288x73728), L=1 | 9589.9              | 7706.7                  | 7638.7       | 5093.2       |
| MKN=(4096x36864x12288), L=1 | 4670.4              | 3712.6                  | 3767         | 2550.9       |

serving benchmark settings:

  • tp1, tp2
  • max(32, 2*concurrency) prompts
  • 10k in/1k out
  • max concurrency (1, 4, 8, 16, 32)
(Four screenshots of serving benchmark results.)

gsm8k

w4a16: 0.8483699773 (1119/1319)
w4a8: 0.8476118271 (1118/1319)

mmlu_pro

w4a16: 0.6907413564
w4a8: 0.6760305851

TODOs

  • more thorough quality evals
  • llm-compressor changes/recipe to generate signed int4 packed weights + fp8 scales
  • investigate if adding fp32 channel scales helps quality
  • fp16 output type
  • codegen for more tile/cluster shape and schedule configs
  • optimize speed/memory footprint for prepack routines


@github-actions bot commented:
👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Aug 19, 2025

mergify bot commented Aug 19, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @czhu-cohere.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added performance Performance-related issues needs-rebase labels Aug 19, 2025
@czhu-cohere czhu-cohere force-pushed the cutlass_w4a8 branch 3 times, most recently from 5c2e0d2 to 5a18e66 Compare August 19, 2025 20:28
@mergify mergify bot removed the needs-rebase label Aug 19, 2025
@czhu-cohere czhu-cohere changed the title [wip] support w4a8 on hopper [wip][kernel] Support W4A8 on Hopper Aug 19, 2025
Review thread on CMakeLists.txt (outdated)

@czhu-cohere (Contributor, Author):
The condition to build is the same as machete; I can also merge them together if that is preferred.

Collaborator:
I think having them separate for now is fine 👍 It keeps the CMakeLists more compartmentalized.

@czhu-cohere (Contributor, Author):
This change means we no longer ignore the activation config even if the format is not in the activation types. Are there tests I can refer to, or scenarios that rely on this behavior? I can further special-case if needed.

@czhu-cohere czhu-cohere changed the title [wip][kernel] Support W4A8 on Hopper [kernel] Support W4A8 on Hopper Aug 22, 2025
@LucasWilkinson (Collaborator) left a comment:

Amazing work! Thank you for the clean integration following existing abstractions; it's very much appreciated 😄

@LucasWilkinson LucasWilkinson enabled auto-merge (squash) August 23, 2025 17:28
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 23, 2025
@LucasWilkinson LucasWilkinson merged commit e76e233 into vllm-project:main Aug 24, 2025
82 checks passed
@josiahrohrer commented:

Hey, thanks for the PR! I just tried with the following scheme:
{
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": [
    128001,
    128008,
    128009
  ],
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 131072,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "quantization_config": {
    "config_groups": {
      "group_0": {
        "format": "int-quantized",
        "input_activations": {
          "actorder": null,
          "block_structure": null,
          "dynamic": true,
          "group_size": null,
          "num_bits": 8,
          "observer": null,
          "observer_kwargs": {},
          "strategy": "token",
          "symmetric": true,
          "type": "int"
        },
        "output_activations": null,
        "targets": [
          "Linear"
        ],
        "weights": {
          "actorder": null,
          "block_structure": null,
          "dynamic": false,
          "group_size": 128,
          "num_bits": 4,
          "observer": "minmax",
          "observer_kwargs": {},
          "strategy": "group",
          "symmetric": true,
          "type": "int"
        }
      }
    },
    "format": "int-quantized",
    "global_compression_ratio": null,
    "ignore": [
      "lm_head"
    ],
    "kv_cache_scheme": null,
    "quant_method": "compressed-tensors",
    "quantization_status": "compressed",
    "sparsity_config": {},
    "transform_config": {},
    "version": "0.11.0"
  },
  "rms_norm_eps": 1e-05,
  "rope_scaling": {
    "factor": 8.0,
    "high_freq_factor": 4.0,
    "low_freq_factor": 1.0,
    "original_max_position_embeddings": 8192,
    "rope_type": "llama3"
  },
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.55.2",
  "use_cache": true,
  "vocab_size": 128256
}

and got a KeyError:

(EngineCore_DP0 pid=30408) File "/teamspace/studios/this_studio/vllm/vllm/v1/worker/gpu_worker.py", line 213, in load_model
(EngineCore_DP0 pid=30408) self.model_runner.load_model(eep_scale_up=eep_scale_up)
(EngineCore_DP0 pid=30408) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=30408) File "/teamspace/studios/this_studio/vllm/vllm/v1/worker/gpu_model_runner.py", line 2211, in load_model
(EngineCore_DP0 pid=30408) self.model = model_loader.load_model(
(EngineCore_DP0 pid=30408) ~~~~~~~~~~~~~~~~~~~~~~~^
(EngineCore_DP0 pid=30408) vllm_config=self.vllm_config, model_config=self.model_config)
(EngineCore_DP0 pid=30408) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=30408) File "/teamspace/studios/this_studio/vllm/vllm/model_executor/model_loader/base_loader.py", line 49, in load_model
(EngineCore_DP0 pid=30408) self.load_weights(model, model_config)
(EngineCore_DP0 pid=30408) ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=30408) File "/teamspace/studios/this_studio/vllm/vllm/model_executor/model_loader/default_loader.py", line 263, in load_weights
(EngineCore_DP0 pid=30408) loaded_weights = model.load_weights(
(EngineCore_DP0 pid=30408) self.get_all_weights(model_config, model))
(EngineCore_DP0 pid=30408) File "/teamspace/studios/this_studio/vllm/vllm/model_executor/models/llama.py", line 607, in load_weights
(EngineCore_DP0 pid=30408) return loader.load_weights(
(EngineCore_DP0 pid=30408) ~~~~~~~~~~~~~~~~~~~^
(EngineCore_DP0 pid=30408) self.maybe_remap_mistral(name, loaded_weight)
(EngineCore_DP0 pid=30408) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=30408) for name, loaded_weight in weights)
(EngineCore_DP0 pid=30408) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_DP0 pid=30408) File "/teamspace/studios/this_studio/vllm/vllm/model_executor/models/utils.py", line 291, in load_weights
(EngineCore_DP0 pid=30408) autoloaded_weights = set(self._load_module("", self.module, weights))
(EngineCore_DP0 pid=30408) File "/teamspace/studios/this_studio/vllm/vllm/model_executor/models/utils.py", line 249, in _load_module
(EngineCore_DP0 pid=30408) yield from self._load_module(prefix,
(EngineCore_DP0 pid=30408) child_modules[child_prefix],
(EngineCore_DP0 pid=30408) child_weights)
(EngineCore_DP0 pid=30408) File "/teamspace/studios/this_studio/vllm/vllm/model_executor/models/utils.py", line 222, in _load_module
(EngineCore_DP0 pid=30408) loaded_params = module_load_weights(weights)
(EngineCore_DP0 pid=30408) File "/teamspace/studios/this_studio/vllm/vllm/model_executor/models/llama.py", line 467, in load_weights
(EngineCore_DP0 pid=30408) param = params_dict[name]
(EngineCore_DP0 pid=30408) ~~~~~~~~~~~^^^^^^
(EngineCore_DP0 pid=30408) KeyError: 'layers.0.mlp.down_proj.weight'

Any tips on what one would need to change to make it work for this w4a8 quantization scheme?
