Skip to content

Conversation

@dcaox
Copy link
Collaborator

@dcaox dcaox commented Sep 25, 2025

Summary by CodeRabbit

  • New Features
    • Added support for returning top-k per-token log probabilities via a configurable logprobs parameter.
    • Removed prior limitation on PyTorch backend, allowing logprobs values greater than 1.
  • Performance
    • Reduced overhead by avoiding unnecessary host-device transfers when log probabilities are not requested.
  • Tests
    • Expanded validations to ensure per-step logprob counts, ordering, and rank consistency.
    • Added new streaming test case to cover higher logprobs configurations.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@dcaox dcaox requested review from a team as code owners September 25, 2025 04:52
@dcaox dcaox requested review from Naveassaf and hchings September 25, 2025 04:52
@dcaox
Copy link
Collaborator Author

dcaox commented Sep 25, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19889 [ run ] triggered by Bot

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Sep 25, 2025

📝 Walkthrough

Walkthrough

Adds num_logprobs propagation from task/request layers into executor and LlmRequest, updates PyTorch sampler to compute per-token logprobs via top-k over log-softmaxed logits, removes a PyTorch-only validation limiting logprobs > 1, wires worker/scaffolding to pass logprobs, and extends tests to validate ordering/ranks.

Changes

Cohort / File(s) Summary
Request propagation (num_logprobs)
tensorrt_llm/_torch/pyexecutor/llm_request.py
LlmRequest gains num_logprobs parameter and stores py_num_logprobs; executor_request_to_llm_request forwards it (default 0).
Sampler logprobs derivation
tensorrt_llm/_torch/pyexecutor/sampler.py
Switches to compute logprobs via F.log_softmax + top-k; constructs Logprob entries from top-k values/indices; replaces log_probs_host tensor with boolean flag; updates multiple internal methods’ signatures and control flow.
Backend request wiring
tensorrt_llm/executor/base_worker.py
Forwards sampling_params.logprobs to executor_request.py_num_logprobs.
Backend validation update
tensorrt_llm/llmapi/llm.py
Removes PyTorch-specific restriction that disallowed logprobs > 1.
Worker scaffolding
tensorrt_llm/scaffolding/worker.py
Adds logprobs=task.num_logprobs to SamplingParams in TRTLLMWorker.convert_task_params.
Tests: API behavior
tests/unittest/llmapi/test_llm.py
Extends harness to validate per-step logprobs count equals request, non-increasing order, and consecutive ranks starting at 1.
Tests: PyTorch backend
tests/unittest/llmapi/test_llm_pytorch.py
Adds test case tuple (2, 3, False, False) for streaming logprobs.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Client
  participant LLMAPI as LLM API
  participant Worker as BaseWorker
  participant ExecReq as ExecutorRequest
  participant LlmReq as LlmRequest (Py)
  participant Sampler

  Client->>LLMAPI: GenerationRequest(sampling_params.logprobs = N)
  LLMAPI->>Worker: create request (no PyTorch logprobs cap)
  Worker->>ExecReq: set py_num_logprobs = N
  ExecReq->>LlmReq: construct(num_logprobs = py_num_logprobs)
  note over LlmReq: Stores py_num_logprobs

  LlmReq->>Sampler: sample(step logits, num_logprobs = N)
  rect rgba(200, 235, 255, 0.2)
    note right of Sampler: New/changed
    Sampler->>Sampler: log_softmax(logits)
    Sampler->>Sampler: top-k over log-probs (k = N)
    Sampler->>Sampler: build Logprob entries (rank=1..N)
  end
  Sampler-->>Client: tokens + per-step top-k logprobs
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Suggested reviewers

  • Superjomn
  • Naveassaf
  • mikeiovine

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
Check name Status Explanation Resolution
Description Check ⚠️ Warning The PR description has only template placeholders and lacks any filled-in summary, detailed description, or test coverage information, leaving required sections empty and providing no context for the changes or how they are validated. Replace the placeholder comments with a concise summary of the changes and rationale, fill out the Description and Test Coverage sections with implementation details and relevant tests, and complete the PR checklist to demonstrate readiness for review.
Docstring Coverage ⚠️ Warning Docstring coverage is 8.33% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (1 passed)
Check name Status Explanation
Title Check ✅ Passed The title clearly indicates the addition of top-k log probability support in the Torch backend and aligns with the required [None][feat] template, concisely describing the primary change without extraneous detail, making the main feature clear to reviewers.
✨ Finishing touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Tip

👮 Agentic pre-merge checks are now available in preview!

Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.

  • Built-in checks – Quickly apply ready-made checks to enforce title conventions, require pull request descriptions that follow templates, validate linked issues for compliance, and more.
  • Custom agentic checks – Define your own rules using CodeRabbit’s advanced agentic capabilities to enforce organization-specific policies and workflows. For example, you can instruct CodeRabbit’s agent to verify that API documentation is updated whenever API schema files are modified in a PR. Note: Upto 5 custom checks are currently allowed during the preview period. Pricing for this feature will be announced in a few weeks.

Please see the documentation for more information.

Example:

reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).

Please share your feedback with us on this Discord post.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🧹 Nitpick comments (3)
tests/unittest/llmapi/test_llm.py (1)

1865-1876: Align size expectation with earlier check; remove k+1 inconsistency.

Above (Line 1861), the test allows len(...) in {k, k+1}, but here it enforces exactly k. Pick one behavior to avoid flaky failures. If the sampler now guarantees exactly top‑k, assert k consistently in both places.

Apply this change outside the current hunk (near Line 1861) to make both checks consistent:

# Replace:
assert logprobs_result and len(logprobs_result[0].keys()) in {logprobs, logprobs + 1}
# With:
assert logprobs_result and len(logprobs_result[0].keys()) == logprobs
tensorrt_llm/_torch/pyexecutor/llm_request.py (1)

314-315: Constructor surface looks good; consider documenting num_logprobs.

Signature extension with num_logprobs: int = 0 is fine. Add a brief docstring entry so downstream readers know it controls top‑k logprobs size for generation.

tensorrt_llm/scaffolding/worker.py (1)

183-185: Propagate prompt_logprobs as well (if available on task).

You’re wiring logprobs for generation. If GenerationTask also carries prompt‑time logprobs (mirroring SamplingParams(prompt_logprobs=...)), pass it through to support prompt logprobs in this worker path.

For example:

         sampling_params = SamplingParams(
             max_tokens=task.max_tokens,
             temperature=task.temperature,
             top_p=task.top_p,
             top_k=task.top_k,
             return_context_logits=task.return_context_logits,
-            logprobs=task.num_logprobs)
+            logprobs=task.num_logprobs,
+            prompt_logprobs=getattr(task, "prompt_logprobs", None))
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bb60671 and b5fc716.

📒 Files selected for processing (7)
  • tensorrt_llm/_torch/pyexecutor/llm_request.py (3 hunks)
  • tensorrt_llm/_torch/pyexecutor/sampler.py (7 hunks)
  • tensorrt_llm/executor/base_worker.py (1 hunks)
  • tensorrt_llm/llmapi/llm.py (0 hunks)
  • tensorrt_llm/scaffolding/worker.py (1 hunks)
  • tests/unittest/llmapi/test_llm.py (1 hunks)
  • tests/unittest/llmapi/test_llm_pytorch.py (1 hunks)
💤 Files with no reviewable changes (1)
  • tensorrt_llm/llmapi/llm.py
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/executor/base_worker.py
  • tensorrt_llm/_torch/pyexecutor/llm_request.py
  • tensorrt_llm/scaffolding/worker.py
  • tests/unittest/llmapi/test_llm.py
  • tests/unittest/llmapi/test_llm_pytorch.py
  • tensorrt_llm/_torch/pyexecutor/sampler.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/executor/base_worker.py
  • tensorrt_llm/_torch/pyexecutor/llm_request.py
  • tensorrt_llm/scaffolding/worker.py
  • tests/unittest/llmapi/test_llm.py
  • tests/unittest/llmapi/test_llm_pytorch.py
  • tensorrt_llm/_torch/pyexecutor/sampler.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/executor/base_worker.py
  • tensorrt_llm/_torch/pyexecutor/llm_request.py
  • tensorrt_llm/scaffolding/worker.py
  • tests/unittest/llmapi/test_llm.py
  • tests/unittest/llmapi/test_llm_pytorch.py
  • tensorrt_llm/_torch/pyexecutor/sampler.py
🧠 Learnings (1)
📚 Learning: 2025-08-28T10:25:22.370Z
Learnt from: ixlmar
PR: NVIDIA/TensorRT-LLM#7294
File: tensorrt_llm/_torch/pyexecutor/sampler.py:887-891
Timestamp: 2025-08-28T10:25:22.370Z
Learning: In tensorrt_llm/_torch/pyexecutor/sampler.py, the draft_probs and target_probs tensors have shapes [1, steps] not [steps, vocab_size] as might be expected, making the .squeeze(0) operations appropriate for removing the batch dimension of size 1.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/sampler.py
🧬 Code graph analysis (3)
tensorrt_llm/executor/base_worker.py (1)
tensorrt_llm/scaffolding/task.py (1)
  • logprobs (99-100)
tensorrt_llm/scaffolding/worker.py (2)
tests/unittest/llmapi/test_llm.py (6)
  • task (481-488)
  • task (528-533)
  • task (1881-1890)
  • task (2000-2013)
  • task (2400-2401)
  • task (2491-2510)
tensorrt_llm/scaffolding/task.py (1)
  • logprobs (99-100)
tensorrt_llm/_torch/pyexecutor/sampler.py (2)
tensorrt_llm/executor/result.py (1)
  • Logprob (37-40)
tensorrt_llm/_torch/pyexecutor/scheduler.py (1)
  • all_requests (38-39)
🪛 Ruff (0.13.1)
tensorrt_llm/_torch/pyexecutor/sampler.py

1276-1276: Unused method argument: log_probs_host

(ARG002)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (2)
tests/unittest/llmapi/test_llm_pytorch.py (1)

918-918: Good addition to cover top-k>1 in streaming path.

This parametrization exercises the new top‑3 generation logprobs in the PyTorch backend under streaming. Looks consistent with the PR intent.

tensorrt_llm/_torch/pyexecutor/llm_request.py (1)

359-360: LGTM: carry num_logprobs on the Python side.

Storing py_num_logprobs on the request mirrors other py_* fields and will propagate to child requests via the existing copy path.

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19889 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #14968 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@dcaox dcaox requested review from QiJune and Superjomn September 26, 2025 01:45
Signed-off-by: Dong Cao <docao@nvidia.com>
@dcaox dcaox force-pushed the docao/support_topk_logprobs_torch_backend_v2 branch from b5fc716 to adbceca Compare September 26, 2025 01:47
@dcaox
Copy link
Collaborator Author

dcaox commented Sep 26, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20016 [ run ] triggered by Bot

Copy link
Collaborator

@hchings hchings left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

Signed-off-by: Cao Dong <87467313+dcaox@users.noreply.github.com>
@dcaox
Copy link
Collaborator Author

dcaox commented Sep 26, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20052 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20016 [ run ] completed with state ABORTED
LLM/main/L0_MergeRequest_PR #15076 (Blue Ocean) completed with status: ABORTED

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20052 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #15105 completed with status: 'FAILURE'

@dcaox
Copy link
Collaborator Author

dcaox commented Sep 27, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20128 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20128 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #15171 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@dcaox
Copy link
Collaborator Author

dcaox commented Sep 29, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20260 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20260 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #15277 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@juney-nvidia juney-nvidia merged commit 62010c0 into NVIDIA:main Sep 30, 2025
5 checks passed
faradawn pushed a commit to faradawn/TensorRT-LLM that referenced this pull request Oct 2, 2025
Signed-off-by: Cao Dong <87467313+dcaox@users.noreply.github.com>
Signed-off-by: Faradawn Yang <faradawny@gmail.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Nov 1, 2025
Signed-off-by: Cao Dong <87467313+dcaox@users.noreply.github.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Nov 3, 2025
Signed-off-by: Cao Dong <87467313+dcaox@users.noreply.github.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Nov 3, 2025
Signed-off-by: Cao Dong <87467313+dcaox@users.noreply.github.com>
dominicshanshan pushed a commit to dominicshanshan/TensorRT-LLM that referenced this pull request Nov 3, 2025
Signed-off-by: Cao Dong <87467313+dcaox@users.noreply.github.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants