Skip to content

Conversation

@stnie
Copy link
Collaborator

@stnie stnie commented Sep 15, 2025

Summary by CodeRabbit

  • Bug Fixes

    • Corrects loading of weight scales in tensor-parallel setups for fused QKV in quantized linear layers, preventing shard misalignment and ensuring consistent outputs across ranks.
  • Performance and Stability

    • Improves reliability and consistency during initialization and inference for tensor-parallel models using quantized QKV layers.

Description

Using int8 weight only quanization could not load weights when using TP > 1, as loading the weight scales did not get the relevant information regarding TP size/rank/etc.

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@stnie
Copy link
Collaborator Author

stnie commented Sep 15, 2025

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18632 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18632 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13990 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@stnie stnie marked this pull request as ready for review September 16, 2025 09:05
@stnie stnie requested a review from a team as a code owner September 16, 2025 09:05
@stnie stnie requested review from Funatiq and hlu1 September 16, 2025 09:05
@stnie stnie force-pushed the develop/quantization/load_int8wo_tp branch from 5fcec50 to ae57700 Compare September 16, 2025 09:06
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Sep 16, 2025

📝 Walkthrough

Walkthrough

The call to load weight scales in the fused QKV path of WeightOnlyQuantLinearMethod was updated to pass explicit tensor-parallel parameters (tp_size, tp_rank, tp_mode), making the loading shard-aware per the module’s TP configuration. Other surrounding logic, including q/k/v concatenation and assignment to module.weight_scale, remains unchanged.

Changes

Cohort / File(s) Summary
TP-aware weight scale loading
tensorrt_llm/_torch/modules/linear.py
Modified the call to self.load_weight_scales(...) in the fused QKV loading path to pass tp_size=module.tp_size, tp_rank=module.tp_rank, and tp_mode=module.tp_mode. Concatenation logic and module.weight_scale assignment unchanged.

Sequence Diagram(s)

sequenceDiagram
    participant M as Module (WeightOnlyQuantLinearMethod)
    participant L as load_weight_scales

    M->>L: load_weight_scales(weights, tp_size=module.tp_size, tp_rank=module.tp_rank, tp_mode=module.tp_mode)
    note over L: Return TP-shard-aware weight_scales
    L-->>M: weight_scales
    M->>M: Assign to module.weight_scale after fused QKV processing
Loading

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Title Check ✅ Passed The title "[None][fix] Add TP information in weight scale loading in WeightOnlyQuantLinearMethod" is concise, directly describes the primary change (making weight-scale loading tensor-parallel aware) and matches the changes described in the provided summary, so it communicates the main purpose to reviewers.
Description Check ✅ Passed The PR description follows the repository template and clearly states the problem and intended fix (weight-only int8 quantization failing to load scales under TP>1 because TP metadata was not passed), but the Test Coverage section is empty and does not list the tests that will validate the change even though CI results are referenced in the PR objectives.
✨ Finishing touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Tip

👮 Agentic pre-merge checks are now available in preview!

Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.

  • Built-in checks – Quickly apply ready-made checks to enforce title conventions, require pull request descriptions that follow templates, validate linked issues for compliance, and more.
  • Custom agentic checks – Define your own rules using CodeRabbit’s advanced agentic capabilities to enforce organization-specific policies and workflows. For example, you can instruct CodeRabbit’s agent to verify that API documentation is updated whenever API schema files are modified in a PR. Note: Upto 5 custom checks are currently allowed during the preview period. Pricing for this feature will be announced in a few weeks.

Please see the documentation for more information.

Example:

reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).

Please share your feedback with us on this Discord post.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/modules/linear.py (1)

1447-1447: Mirror the TP fix in W4A16_AWQ fused-QKV path.

W4A16_AWQ_LinearMethod.load_weights_fused_qkv_linear calls self.load_weight_scales(weights) without TP args, likely causing incorrect scale sharding under TP>1 (same root cause you just fixed for INT8 WO).

Apply:

-        weight_scales = self.load_weight_scales(weights)
+        weight_scales = self.load_weight_scales(
+            weights,
+            tp_size=module.tp_size,
+            tp_rank=module.tp_rank,
+            tp_mode=module.tp_mode,
+        )
🧹 Nitpick comments (1)
tensorrt_llm/_torch/modules/linear.py (1)

1249-1274: Optional: avoid hard-coding device in load_weight_scales.

WeightOnlyQuantLinearMethod.load_weight_scales uses torch.device("cuda"). Prefer aligning with the parameter device (e.g., module.weight_scale.device) to prevent cross-device copies in nonstandard setups. This would require threading a device arg through the call.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between c6ab207 and ae57700.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/modules/linear.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/_torch/modules/linear.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/modules/linear.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/_torch/modules/linear.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/modules/linear.py (2)
tensorrt_llm/_torch/distributed/communicator.py (2)
  • tp_size (46-47)
  • tp_rank (54-55)
tensorrt_llm/mapping.py (1)
  • tp_rank (338-339)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check

@stnie stnie force-pushed the develop/quantization/load_int8wo_tp branch from ae57700 to 80bf557 Compare September 17, 2025 08:31
@stnie
Copy link
Collaborator Author

stnie commented Sep 17, 2025

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18947 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18947 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #14201 completed with status: 'FAILURE'

@stnie stnie force-pushed the develop/quantization/load_int8wo_tp branch from 80bf557 to ab7e838 Compare September 17, 2025 09:44
@stnie
Copy link
Collaborator Author

stnie commented Sep 17, 2025

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18969 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18969 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #14220 completed with status: 'FAILURE'

…d to include tensor parallel parameters

Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
@stnie stnie force-pushed the develop/quantization/load_int8wo_tp branch from ab7e838 to cef8e6f Compare September 17, 2025 15:55
@stnie
Copy link
Collaborator Author

stnie commented Sep 17, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19025 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19025 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #14265 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@stnie stnie merged commit a55251b into NVIDIA:main Sep 18, 2025
5 checks passed
Wong4j pushed a commit to Wong4j/TensorRT-LLM that referenced this pull request Sep 20, 2025
…uantLinearMethod (NVIDIA#7732)

Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
MrGeva pushed a commit to nv-auto-deploy/TensorRT-LLM that referenced this pull request Sep 21, 2025
…uantLinearMethod (NVIDIA#7732)

Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
@stnie stnie deleted the develop/quantization/load_int8wo_tp branch December 18, 2025 07:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants