Conversation

@whx-sjtu (Contributor) commented Aug 21, 2025

Purpose

MLA (Multi-Head Latent Attention) is a complex layer with significant optimization potential, and different hardware backends may require distinct optimization approaches—such as operator fusion, multi-streaming, etc.

The current abstraction in vLLM struggles to accommodate scenarios that require integrating coarse-grained fused operators from different platforms. For example, on the Ascend platform, we need to integrate a fused operator that combines rms_norm, rope, and concat_and_cache_mla. However, in the current community implementation, rms_norm and rope are called from DeepSeek's modeling code, so without modifying the model itself it is impossible to plug in such a custom fused operator.
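For concreteness, the three steps the fused operator would combine look roughly like the toy, self-contained sketch below; names, shapes, and signatures here are simplified stand-ins and do not match the real vLLM kernels.

import torch

def rms_norm(x, weight, eps=1e-6):
    # simplified RMSNorm
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

def rope(x, positions, base=10000.0):
    # minimal rotary embedding over the last (even) dimension
    d = x.shape[-1]
    inv_freq = 1.0 / base ** (torch.arange(0, d, 2, dtype=torch.float32) / d)
    angles = positions[:, None].float() * inv_freq
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

def concat_and_cache(kv_c_normed, k_pe, kv_cache, slot_mapping):
    # concatenate the normalized latent KV with the rotary part and write it
    # into the given KV-cache slots
    kv_cache[slot_mapping] = torch.cat([kv_c_normed, k_pe], dim=-1)

# Unfused path: what the modeling code effectively issues today as three
# separate ops; on Ascend we would like to emit a single fused kernel instead.
T, kv_lora_rank, rope_dim = 4, 8, 4
kv_c = torch.randn(T, kv_lora_rank)
k_pe = torch.randn(T, rope_dim)
positions = torch.arange(T)
kv_cache = torch.zeros(16, kv_lora_rank + rope_dim)
slot_mapping = torch.arange(T)

kv_c_normed = rms_norm(kv_c, torch.ones(kv_lora_rank))       # 1. rms_norm
k_pe = rope(k_pe, positions)                                 # 2. rope
concat_and_cache(kv_c_normed, k_pe, kv_cache, slot_mapping)  # 3. concat_and_cache_mla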

Therefore, we hope to enable MLA as a Custom Op to support customizable extensions for various platforms.

Additionally, considering that other models might also benefit from MLA optimizations—for instance, we recently observed community demand for integrating MLA into models like Qwen—decoupling the MLA layer from modeling would better serve such use cases and facilitate broader adoption.

Test Plan

No need to add new tests.

Test Result

All existing tests should pass.

(Optional) Documentation Update


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

@github-actions (bot) commented

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify bot added the deepseek (Related to DeepSeek models) label Aug 21, 2025
@gemini-code-assist (Contributor) commented
Warning

Gemini encountered an error creating the review. You can try again by commenting /gemini review.

@whx-sjtu (Contributor, Author) commented

/gemini review

@gemini-code-assist bot left a comment

Code Review

This pull request refactors Multi-Head Latent Attention (MLA) into a CustomOp to improve extensibility for different hardware platforms. This is a positive change for modularity. However, the current implementation has some critical issues in the new MultiHeadLatentAttention class. Specifically, it is missing initialization for key attributes, which will cause runtime errors. Additionally, it uses the total number of heads instead of the local (per-rank) number of heads in tensor-parallel contexts, which will lead to shape mismatches. I've provided detailed comments and suggestions to address these issues.
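For context, the tensor-parallel issue the review points at is the usual vLLM pattern of sizing per-rank tensors by the local head count; a minimal sketch of that pattern (the attribute and function names below are generic, not necessarily the PR's):

from vllm.distributed import get_tensor_model_parallel_world_size

def local_head_count(num_heads: int) -> int:
    # Each TP rank only holds num_heads // tp_size heads, so per-rank shapes
    # (q/kv projections, attention output) must use the local count.
    tp_size = get_tensor_model_parallel_world_size()
    assert num_heads % tp_size == 0
    return num_heads // tp_size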

@whx-sjtu (Contributor, Author) commented

/gemini review

@gemini-code-assist bot left a comment

Code Review

This pull request refactors the Multi-Head Latent Attention (MLA) implementation into a CustomOp to better support different hardware backends and custom fused operators. A new file vllm/model_executor/layers/mla.py is introduced, which contains the MultiHeadLatentAttention custom op and an MLAModules dataclass to group related modules. The DeepseekV2MLAAttention class in deepseek_v2.py is updated to use this new custom op, simplifying its implementation. The changes are well-structured and align with the goal of improving modularity and extensibility. My main feedback is to add assertions in the new MultiHeadLatentAttention class to ensure that required modules are provided, which will improve robustness and prevent potential runtime errors.
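Roughly, the grouping the review describes could look like the sketch below; the field names are assumptions based on DeepSeek's attention sub-modules, not a copy of the actual vllm/model_executor/layers/mla.py, and the real class is a CustomOp rather than the plain nn.Module used here to keep the sketch minimal.

from dataclasses import dataclass
import torch
from torch import nn

@dataclass
class MLAModules:
    # Sub-modules the model hands over to the MLA layer, so a platform backend
    # can reorder or fuse them without touching the modeling code.
    kv_a_layernorm: nn.Module
    rotary_emb: nn.Module
    q_proj: nn.Module
    kv_a_proj_with_mqa: nn.Module
    kv_b_proj: nn.Module
    o_proj: nn.Module

class MultiHeadLatentAttention(nn.Module):
    def __init__(self, hidden_size: int, mla_modules: MLAModules) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.mla_modules = mla_modules

    def forward(self, positions: torch.Tensor,
                hidden_states: torch.Tensor) -> torch.Tensor:
        # Default torch path: project -> normalize -> rope -> attend -> o_proj.
        raise NotImplementedError("sketch only")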

@whx-sjtu whx-sjtu force-pushed the mla_refactor branch 2 times, most recently from 742db74 to 08296ea on August 21, 2025 13:33
@ProExpertProg (Collaborator) left a comment

I'm not sure if this is the best approach. The CustomOp abstraction is meant as a simple abstraction to dispatch between torch implementations and custom/CUDA/platform kernels. For something as complex as MLA we should not use this.

Could you describe the fusion you want to do in more detail? Could this fusion in vllm-ascend instead be performed using a custom torch.compile pass?

EDIT: If you want to extract the MLA code out of the model code, I think that's a great idea. But the custom op abstraction seems like the wrong tool for the job.

@whx-sjtu (Contributor, Author) commented

> I'm not sure if this is the best approach. The CustomOp abstraction is meant as a simple abstraction to dispatch between torch implementations and custom/CUDA/platform kernels. For something as complex as MLA we should not use this.
>
> Could you describe the fusion you want to do in more detail? Could this fusion in vllm-ascend instead be performed using a custom torch.compile pass?
>
> EDIT: If you want to extract the MLA code out of the model code, I think that's a great idea. But the custom op abstraction seems like the wrong tool for the job.

Thank you for your suggestion. To help clarify, I've created a diagram to demonstrate an example of one of the fused operators we intend to integrate:
[diagram: the fused operator combining rms_norm, rope, and concat_and_cache_mla]

@ProExpertProg (Collaborator) commented

Thank you for the picture. Honestly, this is a common issue with attention: a lot of ops can be hidden from torch.compile. There are two possible approaches:

  1. Add custom passes that "unfuse" these ops out of attention. The passes would look similar to attention+fp8 quant fusion. This would suffer from a conflict with piecewise compilation, but we're working to address that in [RFC]: Address piecewise graph splitting and attention fusion incompatibility #23261.
  2. Add an extra forward layer to AttentionImpl where, instead of calling the custom op directly, forward_outer is called on the impl. forward_outer just calls the custom op by default, but backends can specialize it with other ops preceding/succeeding the custom op, so that those ops are visible to torch.compile.

Both approaches would require modifying the unified_attention op; I'm not sure which is better, or whether either is feasible at all.

Before we go down this rabbit hole, is Inductor supported on vllm-ascend? What about cuda graphs?

@whx-sjtu (Contributor, Author) commented Aug 22, 2025

> Before we go down this rabbit hole, is Inductor supported on vllm-ascend? What about cuda graphs?

Thanks for your attention. vLLM-Ascend does not currently support Inductor. We have enabled a graph mode whose functionality is very similar to CUDA Graph. Fusion passes in vLLM-Ascend therefore cannot go through the Inductor-based mechanism vLLM uses; so far our only practice is fusing kernels through torch.fx.subgraph_rewriter, and I am not sure whether it offers the full capabilities of Inductor's pattern_matcher.
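For readers unfamiliar with it, torch.fx.subgraph_rewriter.replace_pattern swaps a traced pattern for a replacement subgraph; here is a minimal, self-contained add+relu toy example (not the actual Ascend fusion):

import torch
import torch.fx as fx
from torch.fx import subgraph_rewriter

class Toy(torch.nn.Module):
    def forward(self, x, y):
        return torch.relu(x + y)  # add followed by relu: the pattern to fuse

def pattern(x, y):
    return torch.relu(x + y)

def replacement(x, y):
    # stand-in for a fused kernel; on Ascend this would be a single custom op
    return torch.clamp_min(x + y, 0.0)

gm = fx.symbolic_trace(Toy())
subgraph_rewriter.replace_pattern(gm, pattern, replacement)
print(gm.graph)  # the add+relu nodes have been rewritten to the replacement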

@whx-sjtu (Contributor, Author) commented Aug 22, 2025

> 2. Add an extra forward layer to AttentionImpl where, instead of calling the custom op directly, forward_outer is called on the impl. forward_outer just calls the custom op by default, but backends can specialize it with other ops preceding/succeeding the custom op, so that those ops are visible to torch.compile.

To make sure I correctly understand the second approach you mentioned, let me briefly describe the implementation.
First, we add a forward_outer abstract method to AttentionImpl:

@abstractmethod
def forward_outer(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    raise NotImplementedError

Then we should add a static method to the CustomOp class to check whether a given op name has already been registered:

@staticmethod
def has_registered(op_name):
    """Check if the class has a registered name."""
    return op_name in {**CustomOp.op_registry, **CustomOp.op_registry_oot}

Finally, in the abstracted MLA layer, we check whether the mla_outer op has been registered. If it has, we call the forward_outer method; otherwise we fall back to the default MLA forward path:

def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    if CustomOp.has_registered("mla_outer"):
        return self.mla_attn.impl.forward_outer(positions, hidden_states)
    # otherwise fall back to the default MLA forward implementation

Does this align with what you had in mind? @ProExpertProg

@ProExpertProg (Collaborator) commented

Not exactly, but it would still require Inductor fusions anyway.

Does vllm-ascend support (or use) Dynamo at all? What about AotDispatcher? Because we could do passes without Inductor if AotDispatcher is used (so normalized and functional IR is produced).

How are the CUDAGraph-equivalent graphs handled?

@whx-sjtu (Contributor, Author) commented

> Not exactly, but it would still require Inductor fusions anyway.
>
> Does vllm-ascend support (or use) Dynamo at all? What about AotDispatcher? Because we could do passes without Inductor if AotDispatcher is used (so normalized and functional IR is produced).
>
> How are the CUDAGraph-equivalent graphs handled?

vllm-ascend supports and uses Dynamo. We implement our own piecewise backend to receive and process the FX graph. Our graph mode also works through a capture-and-replay mechanism, just like CUDA Graph.
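For context on what such a piecewise backend receives: a Dynamo backend is just a callable taking the captured fx.GraphModule, as in this minimal, generic example (not vllm-ascend's actual backend):

import torch

def npu_style_backend(gm: torch.fx.GraphModule, example_inputs):
    # This is where FX rewrites (e.g. kernel fusion) and graph capture/replay
    # preparation would happen in a real hardware backend.
    gm.graph.print_tabular()
    return gm.forward  # return a callable; here we simply run the graph eagerly

@torch.compile(backend=npu_style_backend)
def f(x):
    return torch.relu(x + 1)

f(torch.randn(4))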

@LucasWilkinson (Collaborator) commented

After chatting with @ProExpertProg, I think we agree that making layers pluggable by different HW backends without needing torch.compile is desirable (especially MLA, for the reasons you outlined); however, CustomOp might not be the right mechanism for this (per @ProExpertProg's comments).

I will be honest, I don't know that much about how the HW plugin infrastructure works, so maybe @youkaichao can weigh in here.

@ProExpertProg (Collaborator) commented Aug 22, 2025

Yeah, to elaborate, I think a custom pass mechanism to perform fusion would be good in vllm-ascend, because plugging layers and fusing manually inside them will quickly become unmaintainable due to duplicated logic. It would also suffer for the same reasons that led vLLM to use custom passes, and cross-layer fusions are still going to be difficult.

However, I understand that might be too large an undertaking for this case. I still believe CustomOp is not the right abstraction because it interferes with vLLM's custom-op enablement mechanism in a slightly ugly way. I think we should come up with a lightweight PluggableLayer abstraction that's transparent in the vLLM case but can be replaced OOT. But if @youkaichao thinks we should just use CustomOp, I'm okay with that.

@whx-sjtu (Contributor, Author) commented

Thank you for your input! I completely agree that we need a more lightweight PluggableLayer to support custom model layers—this would significantly enhance the extensibility of the vLLM project through plugins. However, to implement PluggableLayer effectively, we need to conduct more detailed design work. We will subsequently propose an RFC to describe and further discuss this direction. For the current MLA-related issues, I believe we can temporarily address them using the Custom Op approach.

@Yikun (Member) commented Aug 23, 2025

> I will be honest, I don't know that much about how the HW plugin infrastructure works

@LucasWilkinson Let me give a brief walkthrough of how the hardware plugin works:

0. Register platform: vLLM serve starts, loads the plugin, and registers the current platform.

1. Register worker: vLLM calls current_platform.check_and_update_config to configure the worker class.

current_platform.check_and_update_config(self)

2. Init worker: vLLM initializes the worker, which instantiates the platform's worker class:

worker_class = resolve_obj_by_qualname(

3. Register custom ops: the hardware plugin's worker (such as vllm-ascend's) calls register_ascend_customop to register its custom ops.

After all of the above steps, the ops are replaced by the hardware plugin's own ops. The current custom-ops mechanism is actually quite flexible and supports registration at the forward, layer, and op levels.
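To make step 3 concrete: a platform implementation of a layer-level custom op generally subclasses CustomOp and overrides the out-of-tree hook, roughly as sketched below. The class name and method bodies are placeholders, and the exact binding performed by register_ascend_customop is not shown here.

import torch
from vllm.model_executor.custom_op import CustomOp

class AscendMultiHeadLatentAttention(CustomOp):
    # Sketch only: bodies are placeholders.

    def forward_native(self, positions: torch.Tensor,
                       hidden_states: torch.Tensor) -> torch.Tensor:
        # Portable torch reference path.
        raise NotImplementedError

    def forward_oot(self, positions: torch.Tensor,
                    hidden_states: torch.Tensor) -> torch.Tensor:
        # CustomOp dispatches out-of-tree platforms here; this is where a fused
        # rms_norm + rope + concat_and_cache_mla kernel would be invoked.
        raise NotImplementedError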


I think PluggableLayer is effectively equivalent to layer-level custom ops plus a stable layer interface. This is worth describing in a detailed separate RFC (it relates to #22082 and our vLLM 1.0 goal).

So, from the hardware-plugin perspective, I think custom ops are an acceptable way to go; that was also the original intention of #19164.

@jgong5 commented Aug 23, 2025

> Yeah, to elaborate, I think a custom pass mechanism to perform fusion would be good in vllm-ascend, because plugging layers and fusing manually inside them will quickly become unmaintainable due to duplicated logic. It would also suffer for the same reasons that led vLLM to use custom passes, and cross-layer fusions are still going to be difficult.
>
> However, I understand that might be too large an undertaking for this case. I still believe CustomOp is not the right abstraction because it interferes with vLLM's custom-op enablement mechanism in a slightly ugly way. I think we should come up with a lightweight PluggableLayer abstraction that's transparent in the vLLM case but can be replaced OOT. But if @youkaichao thinks we should just use CustomOp, I'm okay with that.

I guess both paths (the frontend pluggable-layer abstraction and the backend graph fusion) can be useful. The former is more direct and simpler, but we have to maintain compatibility of the layer semantics, so the supported layers had better be stable and well-defined, e.g., MLA and fused MoE. The latter is more flexible and general, but it depends on the torch.compile graph mode, with extra compilation overhead as well as the complexity of graph rewriting.

On the vllm-ascend side, we are trying to enable some graph-fusion passes by extending the vLLM core. However, the compilation backend currently appears to be bound to the "inductor" backend, which not every HW backend supports. In vllm-ascend the Inductor backend is not yet enabled, so we are adding direct FX-based graph fusion. Today we have to monkey-patch to plug in our own passes (see the rmsnorm+quant fusion PR from
@ganyi1996ppo: vllm-project/vllm-ascend#2389). We would also like to reuse the existing backend-agnostic FX passes from vLLM core, e.g., sequence-parallel fusion, but those are all "inductor" passes.

In summary, it would be great if

  1. A HW backend can extend vLLM to support a custom compiler backend that implements CompilerInterface.
  2. A HW backend with a custom compiler backend can reuse the existing backend-agnostic FX passes.

Would love to get your inputs. @ProExpertProg

PS: For "2", perhaps we can manually apply the AotDispatch passes and invoke those vllm "inductor" passes from inside the custom compiler backend. Not sure if that would work well though.

@LucasWilkinson (Collaborator) commented

@Yikun Thanks for the description! I think the main downside of using CustomOp as the plugin mechanism for layers is that it only lets you hook into the forward call. I suspect that for layers we would quickly run into cases where HW plugin developers want to do some work ahead of time instead of exclusively on the hot path (e.g., re-laying out weights during weight loading, or precomputing some kind of metadata). My guess is that it would be better to have some kind of layer abstraction that is separate from CustomOp for this reason.
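To illustrate the point, a hypothetical PluggableLayer interface (purely illustrative; nothing like this exists in vLLM today) would expose ahead-of-time hooks in addition to forward:

from abc import ABC, abstractmethod
import torch

class PluggableLayer(ABC):
    # Hypothetical interface, not vLLM API.

    @abstractmethod
    def forward(self, positions: torch.Tensor,
                hidden_states: torch.Tensor) -> torch.Tensor:
        """Hot-path computation (the only hook CustomOp exposes today)."""

    def process_weights_after_loading(self) -> None:
        """Ahead-of-time hook: re-lay out / repack weights for the backend."""

    def build_backend_metadata(self) -> None:
        """Ahead-of-time hook: precompute backend-specific metadata."""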

@LucasWilkinson (Collaborator) left a comment

I am not up to date on the intentions behind CustomOp or the HW plugins, so I will defer to the experts for final approval (cc @ProExpertProg @youkaichao). I left a more MLA-related comment (happy to review the MLA-related portions, although there's not much change there, haha).

@whx-sjtu whx-sjtu force-pushed the mla_refactor branch 3 times, most recently from addc5f9 to c30506d on September 1, 2025 07:14
@ProExpertProg (Collaborator) commented

@whx-sjtu is this done?

@whx-sjtu (Contributor, Author) commented Sep 3, 2025

> @whx-sjtu is this done?

Yes, I think we can run CI and merge this PR.

@whx-sjtu whx-sjtu force-pushed the mla_refactor branch 2 times, most recently from 84aa004 to 3f6e479 on September 3, 2025 03:43
@Yikun (Member) commented Sep 3, 2025

I reviewed this again, and I also think it's ready to go.

@Yikun Yikun enabled auto-merge (squash) September 3, 2025 08:43
@Yikun Yikun added the ready ONLY add when PR is ready to merge/full CI is needed label Sep 3, 2025
auto-merge was automatically disabled September 3, 2025 09:06

Head branch was pushed to by a user without write access

@whx-sjtu whx-sjtu force-pushed the mla_refactor branch 2 times, most recently from 5d61664 to 7f45a92 on September 3, 2025 09:28
@ProExpertProg ProExpertProg enabled auto-merge (squash) September 3, 2025 21:10
auto-merge was automatically disabled September 4, 2025 01:44

Head branch was pushed to by a user without write access

@Yikun (Member) commented Sep 4, 2025

buildkite/ci/pr/distributed-tests-2-gpus failed due to an unrelated issue. @whx-sjtu, please don't re-push again; I will trigger the auto-merge check.

If there is still a CI issue here, we can force-merge this.

@Yikun Yikun enabled auto-merge (squash) September 4, 2025 07:12
@vllm-bot vllm-bot merged commit 3efb9f4 into vllm-project:main Sep 4, 2025
41 of 43 checks passed
eicherseiji pushed a commit to eicherseiji/vllm that referenced this pull request Sep 9, 2025
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025