Commit bd30b29

Address comments
Signed-off-by: Hui Gao <huig@nvidia.com>
1 parent 9f712e1 commit bd30b29

4 files changed: +101 additions, -118 deletions

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 21 additions & 22 deletions
@@ -621,7 +621,7 @@ def _post_init_with_buffers(self, buffers) -> None:
         capture_graph = torch.cuda.is_current_stream_capturing()

         def get_empty(tensor_shape: list[int], dtype: torch.dtype,
-                      cache_name: str, pin_memory: bool) -> torch.Tensor:
+                      cache_name: str) -> torch.Tensor:
             """
             Finds a compatible, reusable buffer from a cache or creates a new one.

@@ -637,35 +637,33 @@ def get_empty(tensor_shape: list[int], dtype: torch.dtype,
                 tensor_shape: The required shape.
                 dtype: The required dtype.
                 cache_name: The key for the specific list of buffers to search in.
-                pin_memory: This buffer block shall be kept in buffer pool if provided
             Returns:
                 An existing compatible buffer or a newly created one.
             """
             if buffers is None:
                 return torch.zeros(tensor_shape, device='cuda', dtype=dtype)

             return buffers.get_buffer(tensor_shape, dtype, cache_name,
-                                      pin_memory)
+                                      capture_graph)

-        def get_empty_like(like_tensor: torch.Tensor, cache_name: str,
-                           pin_memory: bool) -> torch.Tensor:
+        def get_empty_like(like_tensor: torch.Tensor,
+                           cache_name: str) -> torch.Tensor:
             return get_empty(like_tensor.shape,
                              cache_name=cache_name,
-                             dtype=like_tensor.dtype,
-                             pin_memory=pin_memory)
+                             dtype=like_tensor.dtype)

-        self.prompt_lens_cuda = get_empty((self.max_num_sequences, ),
-                                          cache_name="prompt_lens_cuda",
-                                          dtype=torch.int,
-                                          pin_memory=capture_graph)
+        self.prompt_lens_cuda = get_empty(
+            (self.max_num_sequences, ),
+            cache_name="prompt_lens_cuda",
+            dtype=torch.int,
+        )
         self.prompt_lens_cpu = torch.empty_like(
             self.prompt_lens_cuda,
             device='cpu',
             pin_memory=True,
         )
         self.kv_lens_cuda = get_empty_like(self.prompt_lens_cuda,
-                                           cache_name="kv_lens_cuda",
-                                           pin_memory=capture_graph)
+                                           cache_name="kv_lens_cuda")
         self.kv_lens = torch.empty_like(self.kv_lens_cuda,
                                         device='cpu',
                                         pin_memory=True)
@@ -687,7 +685,7 @@ def get_empty_like(like_tensor: torch.Tensor, cache_name: str,
             ],
             cache_name="kv_cache_block_offsets",
             dtype=torch.int32,
-            pin_memory=capture_graph)
+        )
         self.host_kv_cache_block_offsets = torch.empty_like(
             self.kv_cache_block_offsets,
             device='cpu',
@@ -703,22 +701,22 @@ def get_empty_like(like_tensor: torch.Tensor, cache_name: str,
             ],
             cache_name="block_ids_per_seq",
             dtype=torch.int32,
-            pin_memory=capture_graph)
+        )
         self.kv_block_ids_per_seq = get_empty(
             [
                 self.kv_cache_manager.max_batch_size,
                 self.kv_cache_manager.max_blocks_per_seq
             ],
             cache_name="kv_block_ids_per_seq",
             dtype=torch.int32,
-            pin_memory=capture_graph)
+        )
         if self.enable_context_mla_with_cached_kv:
             # for kv cache reuse/chunked context in MLA
             self.ctx_cached_token_indptr = get_empty(
                 (self.max_num_requests + 1, ),
                 cache_name="ctx_cached_token_indptr",
                 dtype=torch.int64,
-                pin_memory=capture_graph)
+            )
             self.host_ctx_cached_token_indptr = torch.zeros_like(
                 self.ctx_cached_token_indptr,
                 device='cpu',
@@ -728,17 +726,18 @@ def get_empty_like(like_tensor: torch.Tensor, cache_name: str,
                 (self.max_num_requests + 1, ),
                 cache_name="ctx_uncached_token_indptr",
                 dtype=torch.int64,
-                pin_memory=capture_graph)
+            )
             self.host_ctx_uncached_token_indptr = torch.zeros_like(
                 self.ctx_uncached_token_indptr,
                 device='cpu',
                 pin_memory=True,
             )
             # context full seqlens include cached tokens and uncached tokens
-            self.ctx_kv_indptr = get_empty((self.max_num_requests + 1, ),
-                                           cache_name="ctx_kv_indptr",
-                                           dtype=torch.int64,
-                                           pin_memory=capture_graph)
+            self.ctx_kv_indptr = get_empty(
+                (self.max_num_requests + 1, ),
+                cache_name="ctx_kv_indptr",
+                dtype=torch.int64,
+            )
             self.host_ctx_kv_indptr = torch.zeros_like(
                 self.ctx_kv_indptr,
                 device='cpu',
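
After this change the closure-based helpers pick up `capture_graph` from the enclosing scope instead of threading a `pin_memory` flag through every call site. A minimal sketch of the pattern outside the class (the free function and default size are illustrative only; it assumes a CUDA device and a pool object exposing the `get_buffer` signature shown in memory_buffer_utils.py below):

```python
import torch


def post_init_with_buffers(buffers, max_num_sequences: int = 8):
    # Whether a CUDA graph is currently being captured decides if reused
    # buffers should be reserved in the pool.
    capture_graph = torch.cuda.is_current_stream_capturing()

    def get_empty(tensor_shape, dtype, cache_name):
        # No pool provided: fall back to a plain allocation.
        if buffers is None:
            return torch.zeros(tensor_shape, device='cuda', dtype=dtype)
        # The pool decides reuse vs. new allocation; capture_graph is passed
        # through internally rather than as a per-call pin_memory argument.
        return buffers.get_buffer(tensor_shape, dtype, cache_name, capture_graph)

    def get_empty_like(like_tensor, cache_name):
        return get_empty(like_tensor.shape,
                         cache_name=cache_name,
                         dtype=like_tensor.dtype)

    prompt_lens_cuda = get_empty((max_num_sequences, ),
                                 cache_name="prompt_lens_cuda",
                                 dtype=torch.int)
    kv_lens_cuda = get_empty_like(prompt_lens_cuda, cache_name="kv_lens_cuda")
    return prompt_lens_cuda, kv_lens_cuda
```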

tensorrt_llm/_torch/memory_buffer_utils.py

Lines changed: 73 additions & 89 deletions
@@ -7,109 +7,93 @@

 @dataclass
 class BufferBlock:
+    """A container for a buffer tensor and its state."""
     buffer: torch.Tensor = None
-    pin_memory: bool = False
-
-
-# Intention to have this buffer is to reuse buffer tensors across graph and non-graph
-# situation (across layer/round).
-# When forward is under graph capturing, one stream is created and all tensors' memory
-# is associated with this stream and be kept in a graph pool. Then, all buffer memory
-# allocated during graph capture won't be released back to allocator/system.
-# Then, in non-graph mode, additional buffers are allocated which give bigger pressure
-# on memory consumption at runtime.
-# Timeline example:
-# [t0] start cudagraph capture
-# [t1] A = torch.zeros(....) -> allocate buffer A and put into graph pool
-# [t2] end cudagraph capture
-# [t3] in non-graph forward
-# [t4] A = torch.zeros(....) -> allocate buffer A in allocator but not use memory in cudagraph pool
-#      OOM may happen
-# TODO:
-# The final resolution to this problem shall be supported in pytorch that to allocate memory
-# from a give pool, it's the graph pool here.
-# It will be like
-#    try:
-#        with torch.cuda.use_mem_pool(graphpool):
-#            allocate_memory_here
-#    except exception as ex:
-#        allocate_memory_outside of graphpool
-# Need some archeteture change:
-# 1. a. set a thread local graphpool context object when cudagraphRunner start a fn
-#    b. check and get the thread local graphpool
-#    b. allocate memory
-# 2. aggregate workspaces in the same OP to be a big one in graph pool
-#    allocate memory for the big workspace and slice them into small ones.
-#    However, in non-graph mode, allocate workspace one by one
+    is_reserved: bool = False
+
+
 class Buffers:
+    """
+    Manages and reuses CUDA memory buffers to reduce allocation overhead,
+    especially when interacting with CUDA graphs.
+
+    This class maintains a pool of named buffers. When a buffer is requested,
+    it tries to find an existing, available buffer that is large enough.
+    If none is found, a new one is allocated and added to the pool. This helps
+    avoid repeated allocations, which can be slow and cause memory fragmentation,
+    particularly when the same operations are run inside and outside of a
+    CUDA graph context.
+    """

     def __init__(self):
         self.buffers: dict[str, list[BufferBlock]] = {}

-    def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype,
-                   buffer_name: str, pin_memory: bool):
+    @staticmethod
+    def _view_as(buffer: torch.Tensor, target_shape: list[int],
+                 target_dtype: torch.dtype) -> torch.Tensor:
+        """Safely creates a view of a raw byte buffer with the desired shape and dtype."""
+        # The buffer is stored as uint8, so its numel is its size in bytes.
+        required_size_in_bytes = math.prod(target_shape) * target_dtype.itemsize
+        if buffer.numel() < required_size_in_bytes:
+            raise ValueError(
+                "Buffer is too small for the requested shape and dtype.")

-        def select_buffer_with_more_elements(
-                pinned_buffer: Optional[torch.Tensor],
-                runtime_buffer: Optional[torch.Tensor]
-        ) -> tuple[Optional[torch.Tensor]]:
-            if pinned_buffer is None:
-                return runtime_buffer
-            if runtime_buffer is None:
-                return pinned_buffer
+        # Slice the buffer to the exact required size, then view it with the correct type and shape.
+        return buffer[:required_size_in_bytes].view(target_dtype).view(
+            target_shape)

-            return runtime_buffer if runtime_buffer.buffer.numel(
-            ) > pinned_buffer.buffer.numel() else pinned_buffer
-
-        def view_to(buffer: torch.Tensor, dtype: torch.dtype,
-                    tensor_shape: list[int]) -> torch.Tensor:
-            return buffer[0:math.prod(tensor_shape) *
-                          dtype.itemsize].view(dtype).view(tensor_shape)
+    def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype,
+                   buffer_name: str, reserve_buffer: bool):

         # all buffers are allocated with 1 byte element size
-        element_size = dtype.itemsize
-        required_memory_size = math.prod(tensor_shape) * element_size
-        candidate_buffers = self.buffers.get(buffer_name, [])
-        pinned_buffer = None
-        free_buffer = None
-        for buffer in candidate_buffers:
-            buffer_size = buffer.buffer.numel()
-            if buffer_size >= required_memory_size:
-                if buffer.pin_memory:
-                    pinned_buffer = buffer
-                else:
-                    free_buffer = buffer
-
-                if free_buffer is not None and pinned_buffer is not None:
-                    break
-
-        if pin_memory:
-            if pinned_buffer is not None:
-                return view_to(pinned_buffer.buffer, dtype, tensor_shape)
-            elif free_buffer is not None:
-                free_buffer.pin_memory = True
-                return view_to(free_buffer.buffer, dtype, tensor_shape)
-
-        if buffer_name in self.buffers:
-            candidate_buffers = self.buffers.get(buffer_name, [])
-            for buffer in list(candidate_buffers):
-                if not buffer.pin_memory:
-                    # Need to call del BufferBlock.buffer, otherwise memory isn't
-                    # released and OOM may happen.
-                    del buffer.buffer
-                    candidate_buffers.remove(buffer)
-
-        new_buffer = torch.zeros((required_memory_size, ),
-                                 device='cuda',
-                                 dtype=torch.uint8)
-        self.buffers.setdefault(buffer_name, []).append(
-            BufferBlock(buffer=new_buffer, pin_memory=pin_memory))
-        return view_to(new_buffer, dtype, tensor_shape)
+        required_memory_size = math.prod(tensor_shape) * dtype.itemsize
+        candidate_blocks = self.buffers.get(buffer_name, [])
+
+        # Find the best-fit available buffer.
+        best_fit_block: Optional[BufferBlock] = None
+        smallest_sufficient_size = float('inf')
+        for block in candidate_blocks:
+            # Skip buffers that are too small.
+            if block.buffer.numel() < required_memory_size:
+                continue
+
+            # Find the smallest buffer that is still large enough (best-fit).
+            if block.buffer.numel() < smallest_sufficient_size:
+                # Use reserved block if find one.
+                if best_fit_block is not None and best_fit_block.is_reserved and not block.is_reserved:
+                    continue
+
+                best_fit_block = block
+                smallest_sufficient_size = block.buffer.numel()
+
+        if reserve_buffer and best_fit_block is not None:
+            # A suitable buffer was found, so reuse it.
+            best_fit_block.is_reserved = True
+            return self._view_as(best_fit_block.buffer, tensor_shape, dtype)
+
+        for block in list(candidate_blocks):
+            if not block.is_reserved:
+                # Need to call del BufferBlock.buffer, otherwise memory isn't
+                # released and OOM may happen.
+                del block.buffer
+                candidate_blocks.remove(block)
+
+        # No suitable buffer was found, so allocate a new one.
+        # The new buffer is created with uint8 to represent raw bytes.
+        new_buffer_tensor = torch.zeros((required_memory_size, ),
+                                        device='cuda',
+                                        dtype=torch.uint8)
+        new_block = BufferBlock(buffer=new_buffer_tensor,
+                                is_reserved=reserve_buffer)
+
+        # Add the new buffer to the pool for this name.
+        self.buffers.setdefault(buffer_name, []).append(new_block)
+        return self._view_as(new_block.buffer, tensor_shape, dtype)


 _buffer = Buffers()


-def get_memory_buffer():
+def get_memory_buffers():
     global _buffer
     return _buffer
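
The reworked pool replaces the `pin_memory` flag with an explicit `reserve_buffer` argument and a best-fit search over raw uint8 blocks: when `reserve_buffer` is True a large-enough block is reused and marked reserved; when it is False, unreserved blocks under that name are released and a fresh block is allocated, while reserved (graph-owned) blocks stay in the pool. A short usage sketch against the renamed accessor (the shape and buffer name are illustrative; it assumes a CUDA device):

```python
import torch
from tensorrt_llm._torch.memory_buffer_utils import get_memory_buffers

pool = get_memory_buffers()

# Request a typed view of a pooled raw-byte block. Passing the capture
# state as reserve_buffer keeps graph-captured workspaces reserved so
# their addresses remain valid across graph replays.
capturing = torch.cuda.is_current_stream_capturing()
workspace = pool.get_buffer((1024, 512),
                            dtype=torch.bfloat16,
                            buffer_name='workspace_0',
                            reserve_buffer=capturing)

# The result is a bf16 view into a uint8 block of at least
# 1024 * 512 * 2 bytes.
assert workspace.shape == (1024, 512) and workspace.dtype == torch.bfloat16
```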

tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

Lines changed: 5 additions & 5 deletions
@@ -9,7 +9,7 @@
 from tensorrt_llm._utils import nvtx_range

 from ...distributed import allgather
-from ...memory_buffer_utils import get_memory_buffer
+from ...memory_buffer_utils import get_memory_buffers
 from ...model_config import ModelConfig
 from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor
 from .fused_moe_cutlass import CutlassFusedMoE
@@ -367,7 +367,7 @@ class DeepGemmFusedMoE(CutlassFusedMoE):
     """

     # To reuse pytorch memory segments allocated during graph capture.
-    buffers = get_memory_buffer()
+    buffers = get_memory_buffers()

     def __init__(
         self,
@@ -425,12 +425,12 @@ def get_workspace(self, m_max: int, group_size: int):
             (num_experts * m_max * fp8_dim, ),
             dtype=torch.float8_e4m3fn,
             buffer_name='workspace_0',
-            pin_memory=capture_graph)
+            reserve_buffer=capture_graph)
         workspace_1 = DeepGemmFusedMoE.buffers.get_buffer(
             (num_experts * m_max * max(intermediate_size * 2, hidden_size), ),
             dtype=torch.bfloat16,
             buffer_name='workspace_1',
-            pin_memory=capture_graph)
+            reserve_buffer=capture_graph)

         # create workspace for scaling factors
         m_padded = fp8_utils.align(m_max, 4)
@@ -441,7 +441,7 @@ def get_workspace(self, m_max: int, group_size: int):
             (num_experts * (scale_k_padded // 4) * m_padded, ),
             dtype=torch.int32,
             buffer_name='workspace_sf',
-            pin_memory=capture_graph)
+            reserve_buffer=capture_graph)

         workspace = {
             "workspace_0": workspace_0,

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 2 additions & 2 deletions
@@ -7,7 +7,7 @@

 from ...inputs.multimodal import MultimodalParams
 from ..expert_statistic import ExpertStatistic
-from ..memory_buffer_utils import get_memory_buffer
+from ..memory_buffer_utils import get_memory_buffers
 from ..modules.multi_stream_utils import with_multi_stream
 from ..speculative.eagle3 import Eagle3ResourceManager
 from ..utils import make_weak_ref, piecewise_cuda_graph
@@ -54,7 +54,7 @@ def __init__(self, engine: "PyTorchModelEngine"):
         self.shared_static_tensors: Dict[str, torch.Tensor] = {}
         if self.enabled:
             self._create_shared_static_tensors()
-        self.cuda_graph_meta_buffers = get_memory_buffer()
+        self.cuda_graph_meta_buffers = get_memory_buffers()

     def _create_shared_static_tensors(self):
         """Allocates static tensors sized for the largest possible batch."""
