Skip to content

Conversation

@peaceh-nv
Copy link
Collaborator

@peaceh-nv peaceh-nv commented Aug 14, 2025

  1. The bmm1_scale is the scale from softmax[scale(Q*K)] and its value is 1 / (q_scaling * sqrt(qk_head_dim)) for any dtype
  2. For FP8 fmha/mla, bmm1 has Q*K quant scale, which is q_scale_quant_orig * kv_scale_quant_orig, and bmm1_scale = q_scale_quant_orig * kv_scale_quant_orig * (the scale mentioned in 1)
  3. bmm2 scale is for FP8 fmha/mla, its value is o_scale_orig_quant * kv_scale_quant_orig

Summary by CodeRabbit

  • Bug Fixes
    • Corrected FP8/MLA attention scaling in context path to improve numeric stability and consistency.
    • Populates additional quant/dequant scaling values used for Q/K/O to reduce sporadic deviations in affected workloads.
    • No changes to public APIs or expected performance.

Description

Test Coverage

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Aug 14, 2025

📝 Walkthrough

Walkthrough

Recomputes and assigns FP8 BMM1 scaling and additional FP8 quant/dequant scales in the MLA-enabled attention path on the host, and extends the MLA kernel to accept/produce FP8 bmm scales and additional quant/dequant scale parameters; host invocations updated to pass the new parameters. No public API signature changes.

Changes

Cohort / File(s) Summary of Changes
AttentionOp MLA FP8 scale & MLA params
cpp/tensorrt_llm/common/attentionOp.cpp
Removed use of decoder_params.fmhaHostBmm1Scale; compute host_bmm1_scale = 1 / (mQScaling * sqrt((float)(mMLAParams.qk_nope_head_dim + mMLAParams.qk_rope_head_dim))) and assign to params.mla_param->host_bmm1_scale. Populate quant_scale_q = kv_scale_orig_quant and quant_scale_kv = kv_scale_orig_quant; keep quant_scale_o unchanged. dequant_scale_q and dequant_scale_kv continue to use kv_scale_quant_orig. No fmha launch control-flow changes aside from these param updates.
MLA kernels: signature and FP8 precompute
cpp/tensorrt_llm/kernels/mlaKernels.cu
Extended kernel signature applyMLARopeAndAssignQKVKernelOptContext to accept float* bmm1_scale, float* bmm2_scale, float const* quant_scale_o, float const* quant_scale_kv, float const* dequant_scale_q, float const* dequant_scale_kv, float host_bmm1_scale. Added FP8 precompute (executed once per kernel launch) that computes bmm1_scale and bmm2_scale from the supplied quant/dequant values and host_bmm1_scale, exposing results back to host via bmm1_scale/bmm2_scale. Updated all host invocation sites to pass the new parameters. Core kernel control flow and templates otherwise unchanged.

Sequence Diagram(s)

sequenceDiagram
    participant Caller
    participant AttentionOp
    participant CUDA as Kernel
    Note over AttentionOp: Host prepares MLA params
    Caller->>AttentionOp: enqueueContext(...)
    AttentionOp->>AttentionOp: compute host_bmm1_scale\nset quant/dequant scale fields
    AttentionOp->>CUDA: launch applyMLARopeAndAssignQKVKernelOptContext(..., bmm1_scale, bmm2_scale, quant_scale_o, quant_scale_kv, dequant_scale_q, dequant_scale_kv, host_bmm1_scale)
    CUDA->>CUDA: (first thread) if FP8: compute bmm1_scale, bmm2_scale
    CUDA-->>AttentionOp: kernel returns (bmm1_scale/bmm2_scale available on device)
    AttentionOp-->>Caller: return status
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~35 minutes

Possibly related PRs

Suggested labels

Community want to contribute

Suggested reviewers

  • lucifer1004
  • bobboli
  • jinyangyuan-nvidia
  • litaotju
  • syuoni

Tip

🔌 Remote MCP (Model Context Protocol) integration is now available!

Pro plan users can now connect to remote MCP servers from the Integrations page. Connect with popular remote MCPs such as Notion and Linear to add more context to your reviews and chats.

✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

CodeRabbit Commands (Invoked using PR/Issue comments)

Type @coderabbitai help to get the list of available commands.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai or @coderabbitai title anywhere in the PR title to generate the title automatically.

Status, Documentation and Community

  • Visit our Status Page to check the current availability of CodeRabbit.
  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🔭 Outside diff range comments (1)
cpp/tensorrt_llm/common/attentionOp.cpp (1)

768-770: Workspace under-allocation for FP8 context MLA BMM scales

In getWorkspaceSizeForContext(), the BMM scale buffers are only reserved when mFP8ContextFMHA is true. The new MLA code path writes two floats to fmha_bmm1_scale_ptr (and uses fmha_bmm2_scale_ptr). If mFP8ContextMLA is enabled without mFP8ContextFMHA, the workspace will be under-sized, leading to out-of-bounds writes.

Update the allocations to include mFP8ContextMLA:

// Before:
size_t const fmha_bmm1_scale_size = mFP8ContextFMHA ? sizeof(float) * 2 : 0;
size_t const fmha_bmm2_scale_size = mFP8ContextFMHA ? sizeof(float) : 0;

// After:
size_t const fmha_bmm1_scale_size = (mFP8ContextFMHA || mFP8ContextMLA) ? sizeof(float) * 2 : 0;
size_t const fmha_bmm2_scale_size = (mFP8ContextFMHA || mFP8ContextMLA) ? sizeof(float) : 0;

Follow-up: Please verify other workspace size helpers do not miss similar FP8ContextMLA cases. I can help run a repo-wide scan if needed.

🧹 Nitpick comments (1)
cpp/tensorrt_llm/common/attentionOp.cpp (1)

1777-1782: Prefer sanitized memcpy and remove TODO-like comment

  • Use tensorrt_llm::common::cudaMemcpyAsyncSanitized or at least TLLM_CUDA_CHECK for consistency and immediate error surfacing.
  • The “Need to figure out the correct bmm1 scales” comment should be removed once the computation is used (see previous suggestion).
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4200fa4 and 29172af.

📒 Files selected for processing (1)
  • cpp/tensorrt_llm/common/attentionOp.cpp (1 hunks)
🧰 Additional context used
📓 Path-based instructions (4)
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}: In C++, close namespaces with a comment naming the namespace (e.g., } // namespace foo)
Prefer const/constexpr variables over #define for constants
Declare variables const if not modified after initialization
Use Allman brace style in C++
C++ filenames use lowerCamelCase and must be case-insensitively unique within a build target
C++ type names use UpperCamelCase
Local variables, methods, and namespaces use lowerCamelCase
Global non-static variables not in anonymous namespace use gPrefix lowerCamelCase (e.g., gExample)
Static globals or globals in anonymous namespaces use sPrefix lowerCamelCase
Locally visible static variables start with 's' (e.g., static std::once_flag sFlag;)
Member variables use mPrefix lowerCamelCase; public members may omit but are encouraged to use 'm'
Constants (enums, global/static/function-scope magic numbers) use kPREFIXED_UPPER_SNAKE (e.g., kDIGIT_NUM)
If macros are unavoidable, use UPPER_SNAKE_CASE (prefer constants over #define)
Constructor parameter that conflicts with a public member name gets trailing underscore (foo_)
Literal suffixes should be uppercase (e.g., 1234L not 1234l)
C++: use spaces only; indent 4 spaces
Run clang-format (LLVM style) before submitting; wrap lines at 120 characters
If formatting must be bypassed, use // clang-format off/on around the section
Prefer smart pointers; use unique_ptr for sole ownership, shared_ptr for shared; weak_ptr only in exceptional cases
Do not use deprecated pre-C++11 smart pointers
Use C++ style comments; avoid C comments except special inline cases; prefer // single-line
Capitalize and punctuate full-sentence comments
Follow Doxygen rules: use //! for comments and //!< for members in C++
Disable code with #if/#endif and mnemonic conditions; avoid commented-out code; avoid dead code
Do not throw exceptions across library boundaries
Use least-forceful casts; avoid removing const/volatile; avoid C-style and functional casts (except constructors); p...

Files:

  • cpp/tensorrt_llm/common/attentionOp.cpp
**/*.{cpp,cxx,cc,cu}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.{cpp,cxx,cc,cu}: Avoid literal values except for 0, nullptr, true, false; use named constexpr for other literals
Place semicolon of empty for/while loop on a new line
Always use brace-delimited bodies for switch/while/do-for/if/else
Use inline C comments in argument lists when parameter meaning is unclear (e.g., /* checkForErrors = */ false)
Do not use assignment in subexpressions (e.g., if (x = y) ... is forbidden)
Switch on enums should enumerate all values and omit default to catch new values at compile time
Structure switch statements; prohibit fallthrough except between empty cases; each case ends with break or throw; return at end of case not allowed; put break inside braces for compound case
Prefer anonymous namespaces over static for internal linkage of functions
Every defined function must be called at least once (no unused methods)

Files:

  • cpp/tensorrt_llm/common/attentionOp.cpp
**/*.{h,hpp,hxx,hh,cuh,cpp,cxx,cc,cu}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

Parameter names must be consistent between declarations and definitions

Files:

  • cpp/tensorrt_llm/common/attentionOp.cpp
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • cpp/tensorrt_llm/common/attentionOp.cpp

@peaceh-nv peaceh-nv force-pushed the fix-deepseek-lite-acc branch from 29172af to bff138f Compare August 14, 2025 05:41
… context MLA

Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
@peaceh-nv peaceh-nv force-pushed the fix-deepseek-lite-acc branch from bff138f to d19422e Compare August 14, 2025 07:48
@peaceh-nv peaceh-nv changed the title [https://nvbugspro.nvidia.com/bug/5451373][fix]: Fix the accuracy whe… [https://nvbugs/5451373][fix] : Fix the accuracy issue when using FP8 context MLA Aug 14, 2025
@peaceh-nv peaceh-nv requested a review from PerkzZheng August 14, 2025 07:50
@peaceh-nv
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #15260 [ run ] triggered by Bot

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (2)
cpp/tensorrt_llm/kernels/mlaKernels.cu (2)

235-260: Const-correctness and constant naming per guidelines

  • Make local temporaries const.
  • Rename kLog2e to kLOG2E to follow kPREFIXED_UPPER_SNAKE style for constants (coding_guidelines).

Apply:

-        if (cache_type == KvCacheDataType::FP8)
+        if (cache_type == KvCacheDataType::FP8)
         {
-            float dequant_scale_q_val = dequant_scale_q ? dequant_scale_q[0] : 1.f;
-            float dequant_scale_kv_val = dequant_scale_kv ? dequant_scale_kv[0] : 1.f;
-            float quant_scale_o_val = quant_scale_o ? quant_scale_o[0] : 1.f;
+            float const dequant_scale_q_val = dequant_scale_q ? dequant_scale_q[0] : 1.f;
+            float const dequant_scale_kv_val = dequant_scale_kv ? dequant_scale_kv[0] : 1.f;
+            float const quant_scale_o_val = quant_scale_o ? quant_scale_o[0] : 1.f;
             if (bmm1_scale)
             {
                 // The scale prepared for log2 optimization.
-                constexpr float kLog2e = 1.4426950408889634074f;
+                constexpr float kLOG2E = 1.4426950408889634074f;
                 // The scale after fmha bmm1.
-                float bmm1_scale_val = dequant_scale_q_val * dequant_scale_kv_val * host_bmm1_scale;
+                float const bmm1_scale_val = dequant_scale_q_val * dequant_scale_kv_val * host_bmm1_scale;
                 bmm1_scale[0] = bmm1_scale_val;
-                bmm1_scale[1] = bmm1_scale_val * kLog2e;
+                bmm1_scale[1] = bmm1_scale_val * kLOG2E;
             }
             if (bmm2_scale)
             {
                 // The scale after fmha bmm2.
-                bmm2_scale[0] = quant_scale_o_val * dequant_scale_kv_val;
+                bmm2_scale[0] = quant_scale_o_val * dequant_scale_kv_val;
             }
         }

1-15: Update copyright year

Per coding guidelines, prepend NVIDIA copyright header (current year). The file shows 2019-2023; please update to include 2025 (e.g., “2019-2025”).

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bff138f and d19422e.

📒 Files selected for processing (2)
  • cpp/tensorrt_llm/common/attentionOp.cpp (1 hunks)
  • cpp/tensorrt_llm/kernels/mlaKernels.cu (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • cpp/tensorrt_llm/common/attentionOp.cpp
🧰 Additional context used
📓 Path-based instructions (4)
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}: In C++, close namespaces with a comment naming the namespace (e.g., } // namespace foo)
Prefer const/constexpr variables over #define for constants
Declare variables const if not modified after initialization
Use Allman brace style in C++
C++ filenames use lowerCamelCase and must be case-insensitively unique within a build target
C++ type names use UpperCamelCase
Local variables, methods, and namespaces use lowerCamelCase
Global non-static variables not in anonymous namespace use gPrefix lowerCamelCase (e.g., gExample)
Static globals or globals in anonymous namespaces use sPrefix lowerCamelCase
Locally visible static variables start with 's' (e.g., static std::once_flag sFlag;)
Member variables use mPrefix lowerCamelCase; public members may omit but are encouraged to use 'm'
Constants (enums, global/static/function-scope magic numbers) use kPREFIXED_UPPER_SNAKE (e.g., kDIGIT_NUM)
If macros are unavoidable, use UPPER_SNAKE_CASE (prefer constants over #define)
Constructor parameter that conflicts with a public member name gets trailing underscore (foo_)
Literal suffixes should be uppercase (e.g., 1234L not 1234l)
C++: use spaces only; indent 4 spaces
Run clang-format (LLVM style) before submitting; wrap lines at 120 characters
If formatting must be bypassed, use // clang-format off/on around the section
Prefer smart pointers; use unique_ptr for sole ownership, shared_ptr for shared; weak_ptr only in exceptional cases
Do not use deprecated pre-C++11 smart pointers
Use C++ style comments; avoid C comments except special inline cases; prefer // single-line
Capitalize and punctuate full-sentence comments
Follow Doxygen rules: use //! for comments and //!< for members in C++
Disable code with #if/#endif and mnemonic conditions; avoid commented-out code; avoid dead code
Do not throw exceptions across library boundaries
Use least-forceful casts; avoid removing const/volatile; avoid C-style and functional casts (except constructors); p...

Files:

  • cpp/tensorrt_llm/kernels/mlaKernels.cu
**/*.{cpp,cxx,cc,cu}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.{cpp,cxx,cc,cu}: Avoid literal values except for 0, nullptr, true, false; use named constexpr for other literals
Place semicolon of empty for/while loop on a new line
Always use brace-delimited bodies for switch/while/do-for/if/else
Use inline C comments in argument lists when parameter meaning is unclear (e.g., /* checkForErrors = */ false)
Do not use assignment in subexpressions (e.g., if (x = y) ... is forbidden)
Switch on enums should enumerate all values and omit default to catch new values at compile time
Structure switch statements; prohibit fallthrough except between empty cases; each case ends with break or throw; return at end of case not allowed; put break inside braces for compound case
Prefer anonymous namespaces over static for internal linkage of functions
Every defined function must be called at least once (no unused methods)

Files:

  • cpp/tensorrt_llm/kernels/mlaKernels.cu
**/*.{h,hpp,hxx,hh,cuh,cpp,cxx,cc,cu}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

Parameter names must be consistent between declarations and definitions

Files:

  • cpp/tensorrt_llm/kernels/mlaKernels.cu
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • cpp/tensorrt_llm/kernels/mlaKernels.cu
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (2)
cpp/tensorrt_llm/kernels/mlaKernels.cu (2)

949-953: Host invocation updated correctly

The new arguments order matches the kernel signature and types. No issues spotted with grid/block dimensions or parameter wiring.


949-953: Verified — kernel call sites updated and bmm scale buffers sized correctly

  • Summary: I checked the repo — the kernel is called only with the new signature and bmm1_scale/bmm2_scale workspace sizes allocate 2 and 1 floats respectively; host_bmm1_scale is present.

  • Evidence (key locations):

    • Kernel signature: cpp/tensorrt_llm/kernels/mlaKernels.cu:207–213 (expects float* bmm1_scale, float* bmm2_scale, float host_bmm1_scale).
    • Kernel call sites: cpp/tensorrt_llm/kernels/mlaKernels.cu:949 and cpp/tensorrt_llm/kernels/mlaKernels.cu:1024 (both pass params.bmm1_scale, params.bmm2_scale, params.host_bmm1_scale).
    • bmm1_scale writes two values (indices 0 and 1): cpp/tensorrt_llm/kernels/mlaKernels.cu:249–251 and 430–432.
    • Workspaces allocate correct sizes: cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.h:276–281 (bmm1_scale_size = sizeof(float) * 2; bmm2_scale_size = sizeof(float)); attentionOp.cpp also sets/assigns mla_bmm1_scale_ptr and mla_bmm2_scale_ptr (see the workspace/pointer assignments around cpp/tensorrt_llm/common/attentionOp.cpp:957–972 and usage at 1081–1084).

Conclusion: No lingering call sites or sizing issues found — no changes required.

@peaceh-nv
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #15357 [ run ] triggered by Bot

@peaceh-nv
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #15362 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #15357 [ run ] completed with state ABORTED

@tensorrt-cicd
Copy link
Collaborator

PR_Github #15362 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #11585 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@kaiyux kaiyux merged commit 1c1d5d2 into NVIDIA:main Aug 15, 2025
6 of 7 checks passed
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Aug 17, 2025
… context MLA (NVIDIA#6881)

Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Aug 17, 2025
… context MLA (NVIDIA#6881)

Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Aug 17, 2025
… context MLA (NVIDIA#6881)

Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Aug 17, 2025
… context MLA (NVIDIA#6881)

Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Aug 18, 2025
… context MLA (NVIDIA#6881)

Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Aug 18, 2025
… context MLA (NVIDIA#6881)

Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Aug 18, 2025
… context MLA (NVIDIA#6881)

Signed-off-by: peaceh <103117813+peaceh-nv@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants