# Torch Compile & Piecewise CUDA Graph

In this guide, we show how to enable torch.compile and Piecewise CUDA Graph in TensorRT-LLM.

Piecewise CUDA Graph is a technique that runs cudagraph-unsupported components (primarily attention) in eager mode while capturing and replaying the supported parts with CUDA Graph to reduce context-phase launch overhead. We implement this on top of torch.compile because partitioning a model between CUDA Graph and eager execution—and managing graphs in pure eager mode—is cumbersome.
## Table of Contents
- [Torch Compile & Piecewise CUDA Graph](#torch-compile--piecewise-cuda-graph)
  - [Table of Contents](#table-of-contents)
  - [Usage](#usage)
  - [Tips for Piecewise CUDA Graph](#tips-for-piecewise-cuda-graph)
    - [Piecewise CUDA Graph & Generation Only CUDA Graph](#piecewise-cuda-graph--generation-only-cuda-graph)
    - [Piecewise CUDA Graph Padding](#piecewise-cuda-graph-padding)

## Usage

To enable torch.compile and Piecewise CUDA Graph, add the following configuration to `extra_config.yml`. Typically, `extra_config.yml` is passed to `trtllm-serve` or `trtllm-bench` via the launch argument `--extra_llm_api_options extra_config.yml`.

```yaml
... # Other extra config
torch_compile_config:
  capture_num_tokens: '${capture_num_tokens}' # List of num tokens to capture, e.g., [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, ..., 3072]
  enable_userbuffers: false
  enable_piecewise_cuda_graph: true
```
## Tips for Piecewise CUDA Graph
### Piecewise CUDA Graph & Generation Only CUDA Graph
The Piecewise CUDA Graph only handles context-only and mixed context+generation iterations, while the generation-only CUDA Graph only handles pure generation iterations. Users need to specify the number of tokens to capture for each type of CUDA Graph separately in the extra config. Currently, the default value for `capture_num_tokens` is `[2**i for i in range(8)] + [i for i in range(256, 3073, 256)]`. However, this configuration should be tuned based on specific hardware, model, and parallel strategy. For guidance on tuning these values, see the [Performance Tuning](#performance-tuning) section below.

```yaml
torch_compile_config:
  ...
  enable_piecewise_cuda_graph: true
```
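
For reference, the default `capture_num_tokens` value quoted above expands to the list below. This is plain Python shown only for illustration, not TensorRT-LLM code.

```python
# Expand the documented default: small power-of-two buckets for short inputs,
# plus multiples of 256 up to 3072 for larger context chunks.
default_capture_num_tokens = [2**i for i in range(8)] + [i for i in range(256, 3073, 256)]

print(default_capture_num_tokens)
# [1, 2, 4, 8, 16, 32, 64, 128,
#  256, 512, 768, 1024, 1280, 1536, 1792, 2048, 2304, 2560, 2816, 3072]
```
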
### Piecewise CUDA Graph Padding
Padding means that, at runtime, the token count is padded to the next captured token count. Unlike generation-only CUDA Graph, padding is mandatory for Piecewise CUDA Graph because context-phase token counts vary widely, making it impractical to capture graphs for every possible length.
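
The rounding itself is simple. The sketch below is illustrative only (the helper name `pad_to_captured` is made up, not a TensorRT-LLM API), and what happens beyond the largest captured size is left to the runtime.

```python
import bisect

def pad_to_captured(num_tokens: int, capture_num_tokens: list[int]) -> int:
    """Hypothetical helper: round a token count up to the next captured size
    so that a pre-captured piecewise graph can be replayed."""
    sizes = sorted(capture_num_tokens)
    idx = bisect.bisect_left(sizes, num_tokens)
    if idx == len(sizes):
        # Larger than anything captured; fallback behavior is runtime-specific.
        return num_tokens
    return sizes[idx]

# Example with the default capture list: 300 tokens are padded up to 512.
assert pad_to_captured(300, [2**i for i in range(8)] + list(range(256, 3073, 256))) == 512
```
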
### Performance Tuning
Piecewise CUDA Graph uses a token-count-based capture strategy: it captures a CUDA graph for each user-specified token count and, at runtime, selects and replays the graph matching the iteration's token count (padding up to the next captured count when there is no exact match) in a single forward pass.
Piecewise CUDA Graphs primarily benefit host-bound iterations in the context phase. Within a single iteration, larger token counts reduce exposure to host-side overhead. However, capturing a broader set of token counts increases GPU memory usage and can reduce achievable concurrency. We recommend manually tuning `capture_num_tokens` to balance latency, memory footprint, and concurrency for your workload.
Guidelines for `capture_num_tokens`:

All fusions are located in `tensorrt_llm/_torch/compilation/patterns`.

- Lists are flattened, turning elements into separate input arguments, making it impossible to match the original operation.
2. Trace-driven pitfalls: Because it’s trace-based, the generated source patterns may not meet our needs and can introduce additional issues as we expand pattern coverage.

We mainly perform operation fusion for AllReduce & RMSNorm.

1. AllReduce related fusion: Fuse the following operations into one AllReduce op.

We implement Piecewise CUDA Graph execution on top of torch.compile: non-capturable components run in eager mode, while the capturable segments around them are captured and replayed as CUDA Graphs.

In the current design, we assume the attention block is the only non-capturable component. To maintain stable input pointers across segment boundaries, we convert attention to an in-place variant. Instead of allocating its own output, attention writes results into a tensor preallocated by the preceding CUDA Graph segment. This guarantees that each segment’s inputs are allocated by CUDA Graph and therefore stable for that segment’s capture.
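
A minimal sketch of this in-place convention (illustrative only; the function below and its signature are not the actual TensorRT-LLM attention op):

```python
import torch
import torch.nn.functional as F

def attention_inplace(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                      out: torch.Tensor) -> None:
    # `out` is preallocated by the CUDA Graph segment captured *before* attention,
    # so the segment captured *after* attention always sees the same input pointer.
    # Attention itself runs eagerly and only fills the buffer.
    out.copy_(F.scaled_dot_product_attention(q, k, v))
```
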
1. Custom op fake kernel: For every custom op, developers must implement a correct fake kernel. **Make sure to update the corresponding fake kernel when the custom op is changed.** A minimal registration sketch is shown after this list.
2. Dynamic Iteration Number Loop: This is technically not a trace failure, but it introduces long tracing times that are generally not acceptable. When torch.compile converts PyTorch modeling code to an FX graph, it tries to unroll loops. For a loop with a large, dynamic iteration count and a large body, the unrolling makes tracing take a long time.
    1. If the IO of the loop can easily be written in a custom op format, try to replace the loop with a custom op.
    2. If the loop count is unchanged for the whole lifetime of the inference service, it is OK to leave the loop as is (e.g., the model decoder layer loop).
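
Below is a minimal sketch of registering a fake kernel with the `torch.library` custom-op API available in recent PyTorch releases. The op namespace and name (`demo::scale`) are invented for illustration; TensorRT-LLM's real custom ops differ.

```python
import torch

# A toy custom op; "demo::scale" is a made-up namespace/name for this sketch.
@torch.library.custom_op("demo::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

# The fake kernel lets torch.compile trace the op without running the real kernel.
# It must describe the output's shape/dtype/device exactly, and must be updated
# whenever the real op's output metadata changes.
@scale.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    return torch.empty_like(x)
```
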

Notes:

3. Next power of two: Previously, we used `bit_length()` to implement the next-power-of-two function. However, it caused a recompile for every distinct int value, so the code was rewritten to be torch.compile-friendly.
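
For illustration, one way to compute the next power of two without `bit_length()` is the classic bit-twiddling formulation below. This is a sketch of the idea, not necessarily the exact rewrite used in TensorRT-LLM.

```python
def next_power_of_two(n: int) -> int:
    """Smallest power of two >= n (for n >= 1), using only shifts, ors, and adds
    rather than `bit_length()`."""
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n |= n >> 32  # covers values up to 2**64
    return n + 1

assert [next_power_of_two(v) for v in (1, 3, 130, 3072)] == [1, 4, 256, 4096]
```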