[https://nvbugs/5508536][fix] Revert #7041: Move stop_criteria to sample_async (#7041) #7796
Conversation
Force-pushed from c24b820 to c2c7245.
📝 Walkthrough

Refactors sampling to be beam-aware and centralizes stop-criteria handling inside the sampler.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant SR as ScheduledRequests
    participant TS as TorchSampler
    participant M as Model
    participant Req as LlmRequest
    rect rgb(245,248,252)
        note over TS: Initialization
        TS->>TS: create_store() -> Store(new_tokens)
    end
    SR->>M: run scheduled batch
    M-->>TS: model_outputs, num_context_logits_prefix_sum
    TS->>TS: sample_async(...)
    TS->>TS: _process_requests(..., new_tokens, ...)
    loop per request
        TS->>Req: process_draft_tokens(...)
        alt stop criteria met
            TS->>TS: _handle_stop_criteria(END_ID/LENGTH/STOP_WORDS)
            TS->>Req: finish_by(reason, beam=self.BEAM)
        else continue
            TS->>Req: add_token(new_token, beam=self.BEAM)
        end
    end
    TS-->>SR: SampleState
```
```mermaid
sequenceDiagram
    autonumber
    participant Old as Old (single-beam utils)
    participant New as New (beam-aware sampler)
    rect rgb(252,249,245)
        note left of Old: Previous flow
        Old->>Old: handle_stop_single_beam(request, new_token, max_seq_len)
        Old-->>Old: Uses BEAM_0, SINGLE_BEAM_WIDTH\nproduce_stop_words()
    end
    rect rgb(245,252,248)
        note right of New: Current flow
        New->>New: _handle_stop_criteria(request, token)
        New-->>New: Uses self.BEAM, self.MAX_BEAM_WIDTH\n_local stop checks (_meet_*)
    end
```
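To make the new flow concrete, here is a minimal, self-contained sketch of a centralized stop check in the spirit of `_handle_stop_criteria`. The check order (END_ID, then LENGTH, then STOP_WORDS) follows the diagram above; everything else is a simplified stand-in for the real `LlmRequest`/`TorchSampler` API, not actual project code.

```python
# Minimal sketch of beam-aware, centralized stop-criteria handling.
# All names below are simplified stand-ins for illustration only.
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import List, Optional


class FinishReason(Enum):
    END_ID = auto()
    LENGTH = auto()
    STOP_WORDS = auto()


@dataclass
class Request:
    end_id: int
    max_tokens: int
    stop_words: List[List[int]]
    tokens: List[int] = field(default_factory=list)
    finish_reason: Optional[FinishReason] = None

    def finish_by(self, reason: FinishReason, beam: int) -> None:
        self.finish_reason = reason


BEAM = 0  # single-beam sampler: the beam index is a constant


def handle_stop_criteria(request: Request, new_token: int) -> bool:
    """Append new_token and return True if the request just finished."""
    request.tokens.append(new_token)
    if new_token == request.end_id:
        request.finish_by(FinishReason.END_ID, BEAM)
    elif len(request.tokens) >= request.max_tokens:
        request.finish_by(FinishReason.LENGTH, BEAM)
    elif any(request.tokens[-len(w):] == w for w in request.stop_words if w):
        request.finish_by(FinishReason.STOP_WORDS, BEAM)
    return request.finish_reason is not None
```

Centralizing the check this way lets both TorchSampler and MTPSampler share one code path, which is what the review comments below confirm.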
Estimated code review effort
🎯 4 (Complex) | ⏱️ ~60–90 minutes

Possibly related PRs
Suggested reviewers
Pre-merge checks and finishing touches
❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
✨ Finishing touches
🧪 Generate unit tests
Tip
👮 Agentic pre-merge checks are now available in preview! Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.
Please see the documentation for more information. Example:

```yaml
reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).
```

Please share your feedback with us on this Discord post. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
/bot run
…-critera-to-sample-async
/bot run
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/speculative/mtp.py (1)
1-1: Add NVIDIA Apache-2.0 header (2025) at file top. Apply:
```diff
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
```

tensorrt_llm/_torch/pyexecutor/sampler.py (1)
327-341: Make tensorrt_llm/_torch/pyexecutor/sampler.py Python 3.8-compatible: replace match/case and PEP-585/604 type hints (critical)

Replace the match/case at line ~330 with an if/elif chain (use this replacement):
```python
def sample(strategy: Strategy,
           logits: torch.Tensor,
           generator: Optional[torch.Generator] = None):
    kind = strategy[0]
    if kind == "top_k":
        _, top_k = strategy
        return top_k_sampling_batch(logits, top_k, generator)
    elif kind == "top_p":
        _, top_p, temperature = strategy
        return top_p_sampling_batch(logits, top_p, temperature, generator)
    elif kind == "top_k_top_p":
        _, top_k, top_p, temperature = strategy
        return top_k_top_p_sampling_batch(logits, top_k, top_p, temperature,
                                          generator)
    else:  # "greedy"
        return greedy_search_sampling_batch(logits)
```

Also convert all PEP-585 and PEP-604 annotations in the file to typing equivalents:
- list[int] -> List[int]
- dict[str, T] -> Dict[str, T]
- tuple[int, ...] -> Tuple[int, ...]
- X | None -> Optional[X]
Add the required imports (e.g., from typing import List, Dict, Tuple, Optional).
Occurrences found (from grep): lines 60, 85, 124, 269, 330, 354, 581, 582, 608, 609, 659, 661, 775, 938 in tensorrt_llm/_torch/pyexecutor/sampler.py — update those locations.
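For illustration, the requested conversion might look like this (placeholder names, not actual code from the file):

```python
# Before (PEP 585/604 syntax, needs Python 3.9/3.10+):
#   def pick(tokens: list[int], table: dict[str, int] | None = None) -> tuple[int, ...]: ...

# After (typing-module equivalents, valid on Python 3.8):
from typing import Dict, List, Optional, Tuple


def pick(tokens: List[int], table: Optional[Dict[str, int]] = None) -> Tuple[int, ...]:
    return tuple(tokens)
```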
🧹 Nitpick comments (3)
tensorrt_llm/_torch/pyexecutor/llm_request.py (1)
427-461: convert_wordlist: tighten signature and trivially sanitize input.

The implementation matches the documented 2×N format. Make the param typed and guard non-int tokens.
Apply:
```diff
-def convert_wordlist(word_list) -> List[List[int]]:
+def convert_wordlist(word_list: List[List[int]]) -> List[List[int]]:
@@
-    for word_tokens in word_list:
+    for word_tokens in word_list:
+        # ignore empty entries; enforce ints
+        if not word_tokens:
+            continue
+        if not all(isinstance(t, int) for t in word_tokens):
+            raise TypeError("word_list must be List[List[int]]")
         tokens.extend(word_tokens)
```
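For context, the 2×N stop-words format concatenates all stop-word tokens into one row and records each word's exclusive end offset in a second row. A rough, self-contained equivalent of the conversion (illustrative only; the real convert_wordlist lives in llm_request.py and may differ, e.g. in padding):

```python
# Rough stand-in for the 2xN stop-words conversion discussed above.
from typing import List


def convert_wordlist_sketch(word_list: List[List[int]]) -> List[List[int]]:
    tokens: List[int] = []
    offsets: List[int] = []
    for word_tokens in word_list:
        if not word_tokens:
            continue
        tokens.extend(word_tokens)
        offsets.append(len(tokens))  # exclusive end offset of this word
    return [tokens, offsets]


# [[1, 2], [3]] -> [[1, 2, 3], [2, 3]]
print(convert_wordlist_sketch([[1, 2], [3]]))
```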
tensorrt_llm/_torch/speculative/mtp.py (1)

285-303: Guard against unset seq slots before index_copy_.

If any `py_seq_slot` is None, `index_copy_` will fail. Add a quick assert. Apply:
```diff
-    slots = torch.as_tensor([r.py_seq_slot for r in requests])
+    seq_slots_py = [r.py_seq_slot for r in requests]
+    assert all(s is not None for s in seq_slots_py), "py_seq_slot must be set for all requests"
+    slots = torch.as_tensor(seq_slots_py)
```

tensorrt_llm/_torch/pyexecutor/sampler.py (1)
515-527: Use `py_seq_slot` consistently.

Mixing `request.seq_slot` and `request.py_seq_slot` is error-prone. Other code paths rely on `py_seq_slot`. Apply:
```diff
-        new_tokens[i, request.seq_slot, self.BEAM] = new_token
+        new_tokens[i, request.py_seq_slot, self.BEAM] = new_token
@@
-        new_tokens[num_accepted, request.seq_slot, self.BEAM] = new_token
+        new_tokens[num_accepted, request.py_seq_slot, self.BEAM] = new_token
```
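The writes above index new_tokens as (token_step, seq_slot, beam); a toy example of why a consistent slot index matters (shapes are made up for illustration):

```python
# Toy illustration of (step, seq_slot, beam) indexing; dims are invented.
import torch

new_tokens = torch.zeros(4, 8, 1, dtype=torch.int64)

step, py_seq_slot, beam = 0, 3, 0
new_tokens[step, py_seq_slot, beam] = 42  # write this request's sampled token

# Reading back through a different slot attribute would silently yield 0,
# which is why mixing seq_slot and py_seq_slot is error-prone.
assert new_tokens[step, py_seq_slot, beam].item() == 42
```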
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (5)
- tensorrt_llm/_torch/pyexecutor/llm_request.py (2 hunks)
- tensorrt_llm/_torch/pyexecutor/sampler.py (11 hunks)
- tensorrt_llm/_torch/pyexecutor/sampler_utils.py (0 hunks)
- tensorrt_llm/_torch/speculative/mtp.py (3 hunks)
- tests/unittest/_torch/test_torch_sampler.py (0 hunks)
💤 Files with no reviewable changes (2)
- tensorrt_llm/_torch/pyexecutor/sampler_utils.py
- tests/unittest/_torch/test_torch_sampler.py
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Use only spaces, no tabs; indent with 4 spaces.
Files:
- tensorrt_llm/_torch/speculative/mtp.py
- tensorrt_llm/_torch/pyexecutor/llm_request.py
- tensorrt_llm/_torch/pyexecutor/sampler.py
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.
Files:
- tensorrt_llm/_torch/speculative/mtp.py
- tensorrt_llm/_torch/pyexecutor/llm_request.py
- tensorrt_llm/_torch/pyexecutor/sampler.py
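A tiny illustration of several of the naming rules above together (all names invented for the example):

```python
# Illustrative only: invented names demonstrating the conventions above.
G_REQUEST_COUNTER = 0  # global: upper SNAKE_CASE with 'G' prefix
MAX_RETRIES = 3        # constant: upper SNAKE_CASE


class BatchScheduler:  # class: PascalCase
    """Schedules batches of requests (Google-style docstring)."""

    def schedule_batch(self, batch_size: int) -> None:  # method: snake_case
        k_99th_percentile = 0.99  # name starting with a number gets 'k' prefix
        ...
```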
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).
Files:
- tensorrt_llm/_torch/speculative/mtp.py
- tensorrt_llm/_torch/pyexecutor/llm_request.py
- tensorrt_llm/_torch/pyexecutor/sampler.py
🧠 Learnings (1)
📚 Learning: 2025-08-13T16:20:37.987Z
Learnt from: dcampora
PR: NVIDIA/TensorRT-LLM#6867
File: tensorrt_llm/_torch/pyexecutor/sampler.py:67-72
Timestamp: 2025-08-13T16:20:37.987Z
Learning: In TensorRT-LLM sampler code, performance is prioritized over additional validation checks. The beam_width helper method intentionally returns the first request's beam_width without validating consistency across all requests to avoid performance overhead from iterating through the entire batch.
Applied to files:
tensorrt_llm/_torch/pyexecutor/sampler.py
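As a sketch of that trade-off (hypothetical shape; the real helper in sampler.py may differ):

```python
from typing import Iterable


def beam_width(scheduled_requests: Iterable) -> int:
    # Performance over validation: trust that the scheduler grouped
    # requests with a uniform beam width, so inspecting the first
    # request is enough; no full-batch consistency scan.
    for request in scheduled_requests:
        return request.sampling_config.beam_width
    return 0
```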
🧬 Code graph analysis (2)
tensorrt_llm/_torch/speculative/mtp.py (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (9)
- SampleState (39-45)
- SampleStateTensors (30-35)
- TorchSampler (358-751)
- add_token (342-351)
- int_tensor (354-355)
- Args (374-379)
- Store (366-368)
- create_store (370-371)
- _handle_stop_criteria (440-456)
tensorrt_llm/_torch/pyexecutor/sampler.py (3)
tensorrt_llm/_torch/speculative/mtp.py (4)
- Store (225-228)
- create_store (230-239)
- update_requests (249-274)
- sample_async (276-323)

tensorrt_llm/_torch/pyexecutor/llm_request.py (3)
- LlmRequest (282-424)
- finish_by (395-398)
- get_draft_token_length (560-571)

tensorrt_llm/_torch/pyexecutor/scheduler.py (2)
- ScheduledRequests (18-39)
- all_requests (38-39)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (7)
tensorrt_llm/_torch/speculative/mtp.py (4)
16-17: Import collapse to canonical API looks good.
230-239: Store creation shape logic: LGTM. Asserts draft_len alignment and allocates tensors on CUDA; matches NEW_TOKENS_SHAPE semantics.
256-262: Stop-criteria handling in MTPSampler.update_requests is correctly centralized. Using `self.BEAM` and `_handle_stop_criteria` keeps semantics aligned with TorchSampler.

Also applies to: 269-271
299-303: CPU host copies: LGTM. Non-blocking D2H transfers are appropriate here.
Also applies to: 310-313
tensorrt_llm/_torch/pyexecutor/sampler.py (2)
664-666: Beam constants usage: LGTM. Hard-coding single-beam via `MAX_BEAM_WIDTH = 1` is explicit and aligns with the assertion in `__init__`.

Also applies to: 699-701
749-751: Host log-probs writeback: LGTM. The unsqueeze and slice layout match `(max_num_sequences, MAX_BEAM_WIDTH, max_tokens)`.

tensorrt_llm/_torch/pyexecutor/llm_request.py (1)
3-3: Incorrect — repo targets Python 3.10+: do not remove PEP 604/585 typing.

Repository uses Python 3.10+ features; PEP 604 ('|') and PEP 585 ('list[...]', 'dict[...]', 'tuple[...]') appear across the codebase and are expected. Verified occurrences (from search):
- tensorrt_llm/parameter.py:255: def _get_weights(self, network) -> trt.Weights | Tensor | None:
- tensorrt_llm/lora_manager.py:242: lora_target_modules: list[str]
- tensorrt_llm/lora_manager.py:243: trtllm_modules_to_hf_modules: dict[str, str]
- tensorrt_llm/lora_manager.py:663: self, cpp_peft_cache_manager: tb_internal.batch_manager.PeftCacheManager | None = None
- tensorrt_llm/lora_manager.py:672: # _lora_uid_to_low_ranks: dict[str -> dict[int -> dict[str -> int]]]
- tensorrt_llm/lora_manager.py:685: # _lora_weights_pointers_list: dict[str -> dict[int -> dict[str -> [Tensor, Tensor]]]]
- tensorrt_llm/builder.py:783: managed_weights: dict[str, np.ndarray] = {},
- tensorrt_llm/builder.py:1027:def serialize_managed_weights(managed_weights: dict[str, np.ndarray],
- tensorrt_llm/functional.py:1185: as_dtype: trt.DataType | None = None,
- tensorrt_llm/functional.py:4498: cp_group: list[int] = None,
- tensorrt_llm/functional.py:4567: cp_group: list[int] = None
- tensorrt_llm/functional.py:6769: out_hidden_sizes: list[int],
- tensorrt_llm/functional.py:6770: lora_weights_pointers: list[Tensor],
- tensorrt_llm/functional.py:6772: host_context_lengths: Tensor | None = None) -> Tensor:
- tensorrt_llm/functional.py:6780: out_hidden_sizes : list[int]
Do not apply the suggested backport to Optional/Union/List/Dict unless the project policy changes to require Python 3.8 support.
Likely an incorrect or invalid review comment.
PR_Github #18944 [ run ] triggered by Bot
PR_Github #18945 [ run ] triggered by Bot
PR_Github #18944 [ run ] completed with state
PR_Github #18945 [ run ] completed with state
Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
…-critera-to-sample-async
/bot run
PR_Github #19009 [ run ] triggered by Bot
… format (NVIDIA#7796)

* Modified the create_input_processor function to accept a checkpoint_format parameter, defaulting to "HF".
* The function now conditionally attempts to load the model configuration based on the specified format.

Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
PR_Github #19009 [ run ] completed with state
…-critera-to-sample-async
/bot run --disable-fail-fast
PR_Github #19039 [ run ] triggered by Bot
PR_Github #19039 [ run ] completed with state
…to sample_async (NVIDIA#7041) (NVIDIA#7796)

Signed-off-by: Netanel Haber <nhaber@nvidia.com>
Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com>
Co-authored-by: Mike Iovine <miovine@nvidia.com>
This reverts commit 0fee8cd.
Summary by CodeRabbit
New Features
Refactor
Tests