
Conversation

@yubofredwang (Contributor) commented on Dec 19, 2025

Motivation

See: #14624

In the previous implementation, we did not handle the page_table correctly, which led to an illegal memory access (IMA): we did not add the first last_partial_page before the decode-step pages when calculating the expand_metadata's page table. This PR fixes it.

Modifications

Why we can't simply do:

max_pages_expand = decode_length + 2 * self.page_size - 1
page_base = (cache_loc[:, 0] // self.page_size).to(torch.int32)
page_offsets = torch.arange(
      max_pages_expand, device=page_base.device, dtype=torch.int32
)

like in PR #15107, because the pages assigned in out_cache_loc are not necessarily a contiguous chunk. For example, the last extended page may be page 1 while the first page assigned for the topk decode tokens is page 4. We need to compute the actual page from each allocated location.
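
As a small, self-contained illustration (hypothetical values, simplified to a single request, so a plain out_cache_loc stands in for the batched cache_loc above), the contiguous-page assumption produces the wrong pages as soon as the allocator hands out a non-adjacent page:

import torch

page_size = 4
# The last extended token lives in page 0 (location 3); the decode-step slots
# were allocated from page 10 (locations 40, 41, 42).
out_cache_loc = torch.tensor([3, 40, 41, 42], dtype=torch.int32)

# Contiguous assumption: take the first location's page and count upward.
page_base = out_cache_loc[0] // page_size                     # page 0
naive_pages = page_base + torch.arange(2, dtype=torch.int32)  # tensor([0, 1])

# Pages actually touched by the allocated locations.
actual_pages = torch.unique(out_cache_loc // page_size)       # tensor([0, 10])

print(naive_pages.tolist(), actual_pages.tolist())            # [0, 1] vs. [0, 10]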

We need to do // page_size for each location and keep the unique results. We have the following cases to handle:

  • last_page_len != 0, duplicate. Say page_size = 4 and last_page_len = 3, so locations [0, 1, 2] are already occupied. With decode length = 5 and out_cache_loc = [3, 4, 5, 6, 7, 8], the expected pages are [0, 1, 2] and the selected indices are [0, 1, 5].
  • last_page_len = 0, no duplicate. Say page_size = 4, decode length = 5, out_cache_loc = [4, 5, 6, 7, 8, 9]; the expected pages are [1, 2].
  • Pages are not consecutive. Say page_size = 4, decode length = 3, out_cache_loc = [3, 40, 41, 42]; the expected pages are [0, 10].
  • Decode length < page_size. Say page_size = 4, decode length = 2, last_page_len = 3, out_cache_loc = [3, 4, 5]; the expected pages are [0, 1].

The solution in this PR: allocate a zeros tensor of the same size as out_cache_loc, move the unique pages (after // page_size) to the front, and rely on cache_seq_lens to limit how far into a page the kv cache is accessed.
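
A hedged sketch of this approach in PyTorch is shown below. The function name expand_page_table_sketch and the single-row test inputs are illustrative only (not the PR's actual code), and it assumes that duplicate pages appear consecutively within a row, which holds for the cases above:

import torch

def expand_page_table_sketch(cache_loc: torch.Tensor, page_size: int) -> torch.Tensor:
    # cache_loc: integer tensor of shape [bs * topk, num_locs] holding kv-cache slot ids.
    pages = (cache_loc // page_size).to(torch.int32)
    # Mark the first entry of every consecutive run of equal page indices.
    first = torch.ones_like(pages, dtype=torch.bool)
    first[:, 1:] = pages[:, 1:] != pages[:, :-1]
    # Destination column for each entry; duplicates map to the same column and
    # carry the same value, so the scatter below is well defined.
    positions = torch.cumsum(first.to(torch.int64), dim=1) - 1
    # Zeros tensor of the same size as cache_loc, unique pages moved to the front;
    # cache_seq_lens is what bounds how many of these columns are actually read.
    page_table = torch.zeros_like(pages)
    page_table.scatter_(1, positions, pages)
    return page_table

# The four cases above, each as a single-row batch (page_size = 4).
print(expand_page_table_sketch(torch.tensor([[3, 4, 5, 6, 7, 8]]), 4))  # [[0, 1, 2, 0, 0, 0]]
print(expand_page_table_sketch(torch.tensor([[4, 5, 6, 7, 8, 9]]), 4))  # [[1, 2, 0, 0, 0, 0]]
print(expand_page_table_sketch(torch.tensor([[3, 40, 41, 42]]), 4))     # [[0, 10, 0, 0]]
print(expand_page_table_sketch(torch.tensor([[3, 4, 5]]), 4))           # [[0, 1, 0]]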

Accuracy Tests

Added detailed unit tests.

Benchmarking and Profiling

bench_sglang_eagle.py
#questions: 80, Throughput: 1410.20 token/s, Acceptance length: 2.67

@gemini-code-assist commented

Summary of Changes

Hello @yubofredwang, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical Illegal Memory Access (IMA) bug that manifested under specific conditions during speculative decoding, particularly when dealing with page-sized memory allocations. The core of the fix involves a significant refactor of the page table generation logic within the attention backend, ensuring that memory pages are correctly identified and mapped, even in non-contiguous scenarios. This change enhances the stability and correctness of the speculative decoding process, preventing crashes and improving overall system reliability.

Highlights

  • Illegal Memory Access (IMA) Fix: Resolved an IMA bug occurring during speculative decoding when fa3 + spec + topk + page_size > 1. The issue stemmed from incorrect handling of the page table, specifically not accounting for the first last_partial_page before decode step pages.
  • Page Table Calculation Logic Refactor: Introduced a new vectorized approach for calculating the expand_metadata's page table. The previous logic, which relied on strided_indices_expand and direct subtraction, has been replaced with a more robust method that identifies unique pages from cache_loc using torch.unique_consecutive-like operations.
  • Enhanced Test Coverage: Added detailed unit tests for the new page table calculation logic, including scenarios with multiple batches and varying decode_length values. A new test case with a large page_size (256) was introduced to specifically target and expose potential IMA bugs.
  • Code Simplification and Consistency: Refactored the update_draft_decode_set_expand_metadata_with_page_size function, renaming it to draft_decode_set_expand_metadata and streamlining its parameters. Minor updates were also made to benchmark scripts for consistency (e.g., changing 'prompt' to 'turns').

@yubofredwang (Author) commented:

/gemini review

@gemini-code-assist (bot) left a comment

Code Review

This pull request addresses an Illegal Memory Access issue that occurs under specific speculative decoding configurations. The fix involves a significant refactoring of the page table metadata generation for the expanded draft tokens, moving to a more robust and efficient vectorized implementation. The new logic correctly handles non-consecutive page allocations. The unit tests have been extended to cover more complex scenarios, including multi-batch generation and larger page sizes, which is great for preventing regressions.

However, I've identified a potential critical issue in the CUDA graph path. It appears that a pre-allocated tensor for the page table is not sized correctly, which could lead to an out-of-bounds memory access during graph replay. Please see my specific comment for details.

I am having trouble creating individual review comments, so my feedback is included below.

python/sglang/srt/layers/attention/flashattention_backend.py (1386-1391)

critical

The size of the page_table for draft_decode_metadata_topk_expand seems incorrect for the CUDA graph path. It's allocated with size decode_length + 1, where decode_length is self.speculative_step_id + 1.

However, during CUDA graph replay, the draft_decode_set_expand_metadata function is called with a cache_loc tensor of shape (..., self.speculative_num_steps). The function processes the full cache_loc tensor to compute unique page indices. The resulting positions tensor for the scatter_ operation can have indices up to self.speculative_num_steps - 1.

If self.speculative_num_steps > decode_length + 1 (which is true for earlier draft steps), the scatter_ operation will attempt to write out of bounds, leading to an Illegal Memory Access.

The non-graph path correctly allocates a tensor of size self.speculative_num_steps for this purpose (in init_forward_metadata). The CUDA graph path should be consistent.

I suggest changing the allocation to use self.speculative_num_steps.

"page_table": torch.zeros(
    max_bs * self.topk,
    self.speculative_num_steps,
    dtype=torch.int32,
    device=self.device,
),

@gemini-code-assist (bot) left a comment

Code Review

This pull request addresses an illegal memory access issue that occurs under specific speculative decoding configurations with paged attention. The core of the fix involves a significant refactoring of the page table calculation logic for the 'expand' metadata, replacing a complex and buggy implementation with a more robust vectorized approach to handle non-consecutive memory pages. The changes are well-supported by new, comprehensive unit tests that cover various edge cases, including multi-batch scenarios, and a new integration test with a large page size to prevent regressions.

While the new logic is a clear improvement, I've identified a critical issue in the CUDA graph path where a buffer (page_table) is allocated with an incorrect size. This could lead to out-of-bounds memory access when running with CUDA graphs. My review includes a suggestion to fix this.

Comment on lines 1407 to 1412

  "page_table": torch.zeros(
      max_bs * self.topk,
-     decode_length,
+     decode_length + 1,  # Additional page for last partial page
      dtype=torch.int32,
      device=self.device,
  ),
@gemini-code-assist (bot) commented:

critical

The size of page_table for the CUDA graph path seems incorrect. It is initialized with decode_length + 1 columns, where decode_length is self.speculative_step_id + 1. However, the draft_decode_set_expand_metadata function processes cache_loc which has self.speculative_num_steps columns. The number of unique pages can be up to self.speculative_num_steps, so the page_table needs to be large enough to hold them.

The non-CUDA graph path correctly allocates self.speculative_num_steps columns. To fix this potential out-of-bounds error during the scatter_ operation in draft_decode_set_expand_metadata, the page_table for the CUDA graph should also be allocated with self.speculative_num_steps columns.

Suggested change

- "page_table": torch.zeros(
-     max_bs * self.topk,
-     decode_length + 1,  # Additional page for last partial page
-     dtype=torch.int32,
-     device=self.device,
- ),
+ "page_table": torch.zeros(
+     max_bs * self.topk,
+     self.speculative_num_steps,
+     dtype=torch.int32,
+     device=self.device,
+ ),

@yubofredwang (Author) replied:

"The number of unique pages can be up to self.speculative_num_steps" This assumption is wrong. We are doing page_size > 1. So the number of unique pages can only be less.

"choices": {
"index": 0,
"prompt": [answers[i][0], answers[i][1]],
"turns": [answers[i][0], answers[i][1]],
@yubofredwang (Author) commented:

Reverting the previous wrong change. https://github.com/lm-sys/FastChat/blob/main/fastchat/llm_judge/data/mt_bench/question.jsonl is the right place to download from, not the HF copy.
