7 | 7 |
8 | 8 | @dataclass |
9 | 9 | class BufferBlock: |
| 10 | + """A container for a buffer tensor and its state.""" |
10 | 11 | buffer: torch.Tensor = None |
11 | | - pin_memory: bool = False |
12 | | - |
13 | | - |
14 | | -# Intention to have this buffer is to reuse buffer tensors across graph and non-graph |
15 | | -# situation (across layer/round). |
16 | | -# When forward is under graph capturing, one stream is created and all tensors' memory |
17 | | -# is associated with this stream and be kept in a graph pool. Then, all buffer memory |
18 | | -# allocated during graph capture won't be released back to allocator/system. |
19 | | -# Then, in non-graph mode, additional buffers are allocated which give bigger pressure |
20 | | -# on memory consumption at runtime. |
21 | | -# Timeline example: |
22 | | -# [t0] start cudagraph capture |
23 | | -# [t1] A = torch.zeros(....) -> allocate buffer A and put into graph pool |
24 | | -# [t2] end cudagraph capture |
25 | | -# [t3] in non-graph forward |
26 | | -# [t4] A = torch.zeros(....) -> allocate buffer A in allocator but not use memory in cudagraph pool |
27 | | -# OOM may happen |
28 | | -# TODO: |
29 | | -# The final resolution to this problem shall be supported in pytorch that to allocate memory |
30 | | -# from a give pool, it's the graph pool here. |
31 | | -# It will be like |
32 | | -# try: |
33 | | -# with torch.cuda.use_mem_pool(graphpool): |
34 | | -# allocate_memory_here |
35 | | -# except exception as ex: |
36 | | -# allocate_memory_outside of graphpool |
37 | | -# Need some archeteture change: |
38 | | -# 1. a. set a thread local graphpool context object when cudagraphRunner start a fn |
39 | | -# b. check and get the thread local graphpool |
40 | | -# b. allocate memory |
41 | | -# 2. aggregate workspaces in the same OP to be a big one in graph pool |
42 | | -# allocate memory for the big workspace and slice them into small ones. |
43 | | -# However, in non-graph mode, allocate workspace one by one |
| 12 | + is_reserved: bool = False |
| 13 | + |
| 14 | + |
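The comment block removed above ends with a TODO about allocating straight from the graph's memory pool once PyTorch supports it. A minimal sketch of that idea, assuming the experimental `torch.cuda.MemPool` / `torch.cuda.use_mem_pool` API available in recent PyTorch builds (illustrative only, not part of this change):

```python
import torch

# Assumes a PyTorch build that exposes the experimental MemPool API.
graph_pool = torch.cuda.MemPool()

def alloc_from_pool(num_bytes: int) -> torch.Tensor:
    try:
        # Route the allocation through the shared pool so graph and
        # non-graph code can draw on the same physical memory.
        with torch.cuda.use_mem_pool(graph_pool):
            return torch.empty(num_bytes, device='cuda', dtype=torch.uint8)
    except RuntimeError:
        # Fall back to the default allocator.
        return torch.empty(num_bytes, device='cuda', dtype=torch.uint8)
```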
44 | 15 | class Buffers: |
| 16 | + """ |
| 17 | + Manages and reuses CUDA memory buffers to reduce allocation overhead, |
| 18 | + especially when interacting with CUDA graphs. |
| 19 | +
| 20 | +    This class maintains a pool of named buffers. A reserving request reuses |
| 21 | +    the smallest existing buffer that is large enough, preferring blocks that |
| 22 | +    are already reserved, and allocates a new one only when none fits. A |
| 23 | +    non-reserving request frees every non-reserved block under that name and |
| 24 | +    allocates fresh storage, so transient scratch memory does not accumulate |
| 25 | +    when the same operations run inside and outside of a CUDA graph context. |
| 26 | + """ |
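A minimal usage sketch of these semantics, assuming a CUDA device and the `get_memory_buffers()` accessor defined at the bottom of this file:

```python
import torch

buffers = get_memory_buffers()

# First request: the pool is empty, so a 1 MiB backing store is
# allocated and marked reserved.
a = buffers.get_buffer([1024, 512], torch.float16, "attn_workspace",
                       reserve_buffer=True)

# A smaller reserving request under the same name finds that block,
# which is large enough, and reuses its storage without allocating.
b = buffers.get_buffer([256, 512], torch.float16, "attn_workspace",
                       reserve_buffer=True)
assert a.data_ptr() == b.data_ptr()
```

A `reserve_buffer=False` request under the same name would leave the reserved block untouched and allocate fresh storage instead.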
45 | 27 |
46 | 28 | def __init__(self): |
47 | 29 | self.buffers: dict[str, list[BufferBlock]] = {} |
48 | 30 |
49 | | - def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype, |
50 | | - buffer_name: str, pin_memory: bool): |
| 31 | + @staticmethod |
| 32 | + def _view_as(buffer: torch.Tensor, target_shape: list[int], |
| 33 | + target_dtype: torch.dtype) -> torch.Tensor: |
| 34 | + """Safely creates a view of a raw byte buffer with the desired shape and dtype.""" |
| 35 | + # The buffer is stored as uint8, so its numel is its size in bytes. |
| 36 | + required_size_in_bytes = math.prod(target_shape) * target_dtype.itemsize |
| 37 | + if buffer.numel() < required_size_in_bytes: |
| 38 | + raise ValueError( |
| 39 | + "Buffer is too small for the requested shape and dtype.") |
51 | 40 |
52 | | - def select_buffer_with_more_elements( |
53 | | - pinned_buffer: Optional[torch.Tensor], |
54 | | - runtime_buffer: Optional[torch.Tensor] |
55 | | - ) -> tuple[Optional[torch.Tensor]]: |
56 | | - if pinned_buffer is None: |
57 | | - return runtime_buffer |
58 | | - if runtime_buffer is None: |
59 | | - return pinned_buffer |
| 41 | + # Slice the buffer to the exact required size, then view it with the correct type and shape. |
| 42 | + return buffer[:required_size_in_bytes].view(target_dtype).view( |
| 43 | + target_shape) |
60 | 44 |
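The reinterpretation `_view_as` performs can be reproduced with plain tensor ops; a standalone sketch with an arbitrary shape and dtype:

```python
import math

import torch

shape, dtype = [4, 8], torch.float32
raw = torch.zeros(1024, device='cuda', dtype=torch.uint8)  # raw byte pool

nbytes = math.prod(shape) * dtype.itemsize  # 4 * 8 * 4 = 128 bytes
view = raw[:nbytes].view(dtype).view(shape)

assert view.shape == torch.Size(shape)
assert view.data_ptr() == raw.data_ptr()  # a view, not a copy
```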
61 | | - return runtime_buffer if runtime_buffer.buffer.numel( |
62 | | - ) > pinned_buffer.buffer.numel() else pinned_buffer |
63 | | - |
64 | | - def view_to(buffer: torch.Tensor, dtype: torch.dtype, |
65 | | - tensor_shape: list[int]) -> torch.Tensor: |
66 | | - return buffer[0:math.prod(tensor_shape) * |
67 | | - dtype.itemsize].view(dtype).view(tensor_shape) |
| 45 | + def get_buffer(self, tensor_shape: list[int], dtype: torch.dtype, |
| 46 | +                   buffer_name: str, reserve_buffer: bool) -> torch.Tensor: |
68 | 47 |
69 | 48 | # all buffers are allocated with 1 byte element size |
70 | | - element_size = dtype.itemsize |
71 | | - required_memory_size = math.prod(tensor_shape) * element_size |
72 | | - candidate_buffers = self.buffers.get(buffer_name, []) |
73 | | - pinned_buffer = None |
74 | | - free_buffer = None |
75 | | - for buffer in candidate_buffers: |
76 | | - buffer_size = buffer.buffer.numel() |
77 | | - if buffer_size >= required_memory_size: |
78 | | - if buffer.pin_memory: |
79 | | - pinned_buffer = buffer |
80 | | - else: |
81 | | - free_buffer = buffer |
82 | | - |
83 | | - if free_buffer is not None and pinned_buffer is not None: |
84 | | - break |
85 | | - |
86 | | - if pin_memory: |
87 | | - if pinned_buffer is not None: |
88 | | - return view_to(pinned_buffer.buffer, dtype, tensor_shape) |
89 | | - elif free_buffer is not None: |
90 | | - free_buffer.pin_memory = True |
91 | | - return view_to(free_buffer.buffer, dtype, tensor_shape) |
92 | | - |
93 | | - if buffer_name in self.buffers: |
94 | | - candidate_buffers = self.buffers.get(buffer_name, []) |
95 | | - for buffer in list(candidate_buffers): |
96 | | - if not buffer.pin_memory: |
97 | | - # Need to call del BufferBlock.buffer, otherwise memory isn't |
98 | | - # released and OOM may happen. |
99 | | - del buffer.buffer |
100 | | - candidate_buffers.remove(buffer) |
101 | | - |
102 | | - new_buffer = torch.zeros((required_memory_size, ), |
103 | | - device='cuda', |
104 | | - dtype=torch.uint8) |
105 | | - self.buffers.setdefault(buffer_name, []).append( |
106 | | - BufferBlock(buffer=new_buffer, pin_memory=pin_memory)) |
107 | | - return view_to(new_buffer, dtype, tensor_shape) |
| 49 | + required_memory_size = math.prod(tensor_shape) * dtype.itemsize |
| 50 | + candidate_blocks = self.buffers.get(buffer_name, []) |
| 51 | + |
| 52 | + # Find the best-fit available buffer. |
| 53 | + best_fit_block: Optional[BufferBlock] = None |
| 54 | + smallest_sufficient_size = float('inf') |
| 55 | + for block in candidate_blocks: |
| 56 | + # Skip buffers that are too small. |
| 57 | + if block.buffer.numel() < required_memory_size: |
| 58 | + continue |
| 59 | + |
| 60 | + # Find the smallest buffer that is still large enough (best-fit). |
| 61 | + if block.buffer.numel() < smallest_sufficient_size: |
| 62 | +                # Never displace a reserved best-fit with a non-reserved block. |
| 63 | + if best_fit_block is not None and best_fit_block.is_reserved and not block.is_reserved: |
| 64 | + continue |
| 65 | + |
| 66 | + best_fit_block = block |
| 67 | + smallest_sufficient_size = block.buffer.numel() |
| 68 | + |
| 69 | + if reserve_buffer and best_fit_block is not None: |
| 70 | +            # A suitable block exists and the caller wants it reserved: reuse it. |
| 71 | + best_fit_block.is_reserved = True |
| 72 | + return self._view_as(best_fit_block.buffer, tensor_shape, dtype) |
| 73 | + |
| 74 | + for block in list(candidate_blocks): |
| 75 | + if not block.is_reserved: |
| 76 | +                # Drop the tensor reference eagerly; otherwise the memory may |
| 77 | +                # not return to the allocator in time and an OOM can follow. |
| 78 | + del block.buffer |
| 79 | + candidate_blocks.remove(block) |
| 80 | + |
| 81 | +        # Nothing reusable (or a non-reserving request): allocate anew. |
| 82 | +        # The new buffer is created as uint8 to represent raw bytes. |
| 83 | + new_buffer_tensor = torch.zeros((required_memory_size, ), |
| 84 | + device='cuda', |
| 85 | + dtype=torch.uint8) |
| 86 | + new_block = BufferBlock(buffer=new_buffer_tensor, |
| 87 | + is_reserved=reserve_buffer) |
| 88 | + |
| 89 | + # Add the new buffer to the pool for this name. |
| 90 | + self.buffers.setdefault(buffer_name, []).append(new_block) |
| 91 | + return self._view_as(new_block.buffer, tensor_shape, dtype) |
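For contrast with the reserving path above, two consecutive non-reserving requests under one name never share storage: each call frees the previous non-reserved block before allocating, so at most one transient block per name stays in the pool. A short sketch, again assuming a CUDA device:

```python
import torch

buffers = get_memory_buffers()

t1 = buffers.get_buffer([1024], torch.float32, "scratch",
                        reserve_buffer=False)
t2 = buffers.get_buffer([2048], torch.float32, "scratch",
                        reserve_buffer=False)

# t1's backing block was dropped from the pool before t2's was
# allocated, so only the newest block remains under "scratch".
assert len(buffers.buffers["scratch"]) == 1
```

The CUDA memory itself only returns to the allocator once no views of the dropped block (such as `t1`) remain alive.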
108 | 92 |
109 | 93 |
110 | 94 | _buffer = Buffers() |
111 | 95 |
112 | 96 |
113 | | -def get_memory_buffer(): |
| 97 | +def get_memory_buffers(): |
114 | 98 | global _buffer |
115 | 99 | return _buffer |