
Commit f14dda8

Update docs/source/torch/features/torch_compile_and_piecewise_cuda_graph.md
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Signed-off-by: Yi Zhang <187001205+yizhang-nv@users.noreply.github.com>
1 parent b44135d commit f14dda8

File tree

1 file changed: +38 −11 lines changed


docs/source/torch/features/torch_compile_and_piecewise_cuda_graph.md

Lines changed: 38 additions & 11 deletions
@@ -4,24 +4,45 @@ In this guide, we show how to enable torch.compile and Piecewise CUDA Graph in T
 
 Piecewise CUDA Graph is a technique that runs cudagraph-unsupported components (primarily attention) in eager mode while capturing and replaying the supported parts with CUDA Graph to reduce context-phase launch overhead. We implement this on top of torch.compile because partitioning a model between CUDA Graph and eager execution—and managing graphs in pure eager mode—is cumbersome.
 
+## Table of Contents
+
+- [Torch Compile & Piecewise CUDA Graph](#torch-compile--piecewise-cuda-graph)
+- [Table of Contents](#table-of-contents)
+- [Usage](#usage)
+- [Tips for Piecewise CUDA Graph](#tips-for-piecewise-cuda-graph)
+- [Piecewise CUDA Graph & Generation Only CUDA Graph](#piecewise-cuda-graph--generation-only-cuda-graph)
+- [Piecewise CUDA Graph Padding](#piecewise-cuda-graph-padding)
+- [Performance Tuning](#performance-tuning)
+- [Known Issue](#known-issue)
+- [Development Guide](#development-guide)
+- [TensorRT LLM Custom Backend](#tensorrt-llm-custom-backend)
+- [Torch IR Optimization](#torch-ir-optimization)
+- [ATen IR Optimization](#aten-ir-optimization)
+- [Operation Fusion](#operation-fusion)
+- [Re-inplace Optimization](#re-inplace-optimization)
+- [Auto Multi-stream](#auto-multi-stream)
+- [Background Knowledge](#background-knowledge)
+- [Custom Op](#custom-op)
+- [Current Status](#current-status)
+- [Piecewise CUDA Graph](#piecewise-cuda-graph)
+- [Common Trace Failure](#common-trace-failure)
+- [Graph Break](#graph-break)
+- [Recompilation](#recompilation)
+
 ## Usage

 To enable torch.compile and Piecewise CUDA Graph, add the following configuration to `extra_config.yml`. Typically, `extra_config.yml` is passed by adding the launch argument `--extra_llm_api_options extra_config.yml` to `trtllm-serve` or `trtllm-bench`.

 ```yaml
 ... # Other extra config
 torch_compile_config:
-  capture_num_tokens: '${capture_num_tokens}' # List of num tokens to capture
+  capture_num_tokens: '${capture_num_tokens}' # List of num tokens to capture. e.g., [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, ..., 3072]
   enable_userbuffers: false
   enable_piecewise_cuda_graph: true
 ```

 ## Tips for Piecewise CUDA Graph

-### Piecewise CUDA Graph Padding
-
-Unlike a generation-only CUDA Graph setup, Piecewise CUDA Graph enable padding by default. The token number is padded up to the next captured token number, since token number in the context phase vary widely and it’s impractical to capture graphs for every possible number of tokens.
-
 ### Piecewise CUDA Graph & Generation Only CUDA Graph

 The Piecewise CUDA Graph only handles context-only and mixed context+generation iterations, while the generation-only CUDA Graph only handles pure generation iterations. Users need to specify the number of tokens to capture for each type of CUDA Graph separately in the extra config. Currently, the default value for `capture_num_tokens` is `[2**i for i in range(8)] + [i for i in range(256, 3073, 256)]`. However, this configuration should be tuned based on specific hardware, model, and parallel strategy. For guidance on tuning these values, see the [Performance Tuning](#performance-tuning) section below.
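
For reference, the default expression quoted in the hunk above expands to the following token counts (a quick illustrative snippet; the authoritative default lives in TensorRT LLM's torch.compile configuration code):

```python
# Expansion of the documented default for capture_num_tokens:
#   [2**i for i in range(8)] + [i for i in range(256, 3073, 256)]
default_capture_num_tokens = [2**i for i in range(8)] + list(range(256, 3073, 256))
print(default_capture_num_tokens)
# [1, 2, 4, 8, 16, 32, 64, 128,
#  256, 512, 768, 1024, 1280, 1536, 1792, 2048, 2304, 2560, 2816, 3072]
```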
@@ -37,9 +58,15 @@ torch_compile_config:
   enable_piecewise_cuda_graph: true
 ```

+### Piecewise CUDA Graph Padding
+
+Padding means that, at runtime, the token count is padded to the next captured token count. Unlike generation-only CUDA Graph, padding is mandatory for Piecewise CUDA Graph because context-phase token counts vary widely, making it impractical to capture graphs for every possible length.
+
 ### Performance Tuning

-The optimal token counts to capture vary by system hardware configuration and model. Piecewise CUDA Graph primarily benefit host-bound iterations in the context phase. To find the optimal number of tokens to capture, we recommend tuning capture_num_tokens manually.
+The Piecewise CUDA Graph uses a token-count–based capture strategy: it captures a CUDA graph for each user-specified token count and, at runtime, selects and replays the graph that matches the iteration’s token count (or the captured graph the token count can be padded up to) in a single forward pass.
+
+Piecewise CUDA Graphs primarily benefit host-bound iterations in the context phase. Within a single iteration, larger token counts reduce exposure to host-side overhead. However, capturing a broader set of token counts increases GPU memory usage and can reduce achievable concurrency. We recommend manually tuning `capture_num_tokens` to balance latency, memory footprint, and concurrency for your workload.
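
As an illustration of the capture/padding behavior described in the added lines above, here is a minimal sketch of bucket selection (not TensorRT LLM's actual runtime code; the fallback when the token count exceeds every captured size is an assumption):

```python
import bisect

def select_capture_bucket(num_tokens: int, capture_num_tokens: list[int]) -> int | None:
    """Pick the smallest captured token count >= num_tokens.

    The runtime pads the iteration's token count up to the returned bucket and
    replays the CUDA graph captured for that size. Returning None to signal a
    fallback to non-graph execution is an assumption for this sketch.
    """
    buckets = sorted(capture_num_tokens)
    i = bisect.bisect_left(buckets, num_tokens)
    return buckets[i] if i < len(buckets) else None
```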

 Guidelines for `capture_num_tokens`:

@@ -103,9 +130,9 @@ All fusions are located in `tensorrt_llm/_torch/compilation/patterns` and implem
 - Lists are flattened, turning elements into separate input arguments, making it impossible to match the original operation.
 2. Trace-driven pitfalls: Because it’s trace-based, the generated source patterns may not meet our needs and can introduce additional issues as we expand pattern coverage.

-We mainly do the operation fusion for allreduce & rms norm.
+We mainly perform operation fusion for AllReduce & RMSNorm.

-1. AllReduce related fusion: Fuse the following operations into one allreduce op.
+1. AllReduce related fusion: Fuse the following operations into one AllReduce op.
 + AllReduce + Residual + RMSNorm
 + AllReduce + Residual + RMSNorm + FP8 Quantization
 + AllReduce + Residual + RMSNorm + FP4 Quantization
@@ -216,7 +243,7 @@ We implement Piecewise CUDA Graph execution on top of torch.compile: non-captura
 In the current design, we assume the attention block is the only non-capturable component. To maintain stable input pointers across segment boundaries, we convert attention to an in-place variant. Instead of allocating its own output, attention writes results into a tensor preallocated by the preceding CUDA Graph segment. This guarantees that each segment’s inputs are allocated by CUDA Graph and therefore stable for that segment’s capture.

 <p align="center">
-<img src="../../media/piecewise_runner.png" alt="Piecewise Runner" width=35% height=35% />
+<img src="../../media/piecewise_runner.svg" alt="Piecewise Runner" width=35% height=35% />
 </p>
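
To make the in-place design described above concrete, here is a minimal sketch using a hypothetical `demo::attn_inplace` op (not TensorRT LLM's actual attention kernel): the output tensor is preallocated by the CUDA-graph-captured segment, and the eager attention only fills it, so the pointer the next segment captures stays stable across replays.

```python
import torch

# Hypothetical in-place attention wrapper, for illustration only.
# `out` is preallocated by the preceding CUDA Graph segment; the eager kernel
# writes into it instead of allocating its own output.
@torch.library.custom_op("demo::attn_inplace", mutates_args=("out",))
def attn_inplace(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor) -> None:
    out.copy_(torch.nn.functional.scaled_dot_product_attention(q, k, v))

# Conceptually, inside a captured segment:
#   out = torch.empty_like(q)                  # allocated during graph capture
#   torch.ops.demo.attn_inplace(q, k, v, out)  # runs eagerly between segments
```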

 Notes:
@@ -227,7 +254,7 @@ Notes:
 ### Common Trace Failure

 1. Custom op fake kernel: For every custom op, developers must implement a correct fake kernel. **Make sure to update the corresponding fake kernel when the custom op is changed.** (See the sketch below.)
-2. Dynamic Iteration Number Loop: This is technically not a trace failure, but it will introduce long time tracing that generally not acceptable. When torch.compile tries to convert PyTorch modeling code to Fx graph, it will try to unroll the loop. For a loop that has a large and dynamic loop number with large loop body, the tracing process will take a long time to do the unrolling.
+2. Dynamic Iteration Number Loop: This is technically not a trace failure, but it introduces very long tracing times, which are generally not acceptable. When torch.compile converts PyTorch modeling code to an FX graph, it tries to unroll the loop; for a loop with a large, dynamic iteration count and a large body, the unrolling takes a long time.
     1. If the IO of the loop can easily be written in a custom op format, try to replace the loop with a custom op.
     2. If the loop count is unchanged for the entire lifetime of the inference service, it is fine to leave the loop as is (e.g., the model decoder layer loop).

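The fake-kernel rule in item 1 above, sketched with a hypothetical `demo::fused_rmsnorm` op (illustrative only, not a real TensorRT LLM kernel): the fake implementation only propagates output shape, dtype, and device, and it must be kept in sync whenever the real kernel changes.

```python
import torch

# Hypothetical custom op used only to illustrate the fake-kernel requirement.
@torch.library.custom_op("demo::fused_rmsnorm", mutates_args=())
def fused_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight

@fused_rmsnorm.register_fake
def _(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Tracing never runs the real kernel; it only needs output metadata.
    # If the real kernel's output shape or dtype changes, update this too.
    return torch.empty_like(x)
```
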
@@ -308,7 +335,7 @@ Notes:

 3. Next power of two: Previously, we used `bit_length()` to implement the next power of 2 function. However, it caused a recompile for every int value, so the code has been rewritten to be torch.compile-friendly.

-```
+```python
 def next_positive_power_of_2(x: int) -> int:
     if x < 1:
         return 1
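
The diff context cuts off before the rewritten body. As an illustration only, here is a bit-fold formulation that avoids `bit_length()` and matches the signature above (a sketch, not necessarily the exact code in TensorRT LLM):

```python
def next_positive_power_of_2(x: int) -> int:
    if x < 1:
        return 1
    # Fold the highest set bit into every lower position, then add 1.
    # Pure integer arithmetic, with no bit_length() call.
    x -= 1
    x |= x >> 1
    x |= x >> 2
    x |= x >> 4
    x |= x >> 8
    x |= x >> 16
    x |= x >> 32
    return x + 1
```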
