Commit 1c36f62

[TRTLLM-7073][feat] Support torch compile for PP for Llama and DeepSeekV3

Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
1 parent 5792464 commit 1c36f62

File tree: 10 files changed, +88 -140 lines changed


tensorrt_llm/_torch/compilation/backend.py
Lines changed: 13 additions & 4 deletions

@@ -11,6 +11,7 @@
 
 import tensorrt_llm
 from tensorrt_llm import logger
+from tensorrt_llm.mapping import Mapping
 
 from .multi_stream.auto_multi_stream import multi_stream_schedule
 from .patterns.ar_residual_norm import register_ar_fusions
@@ -39,13 +40,16 @@ def __init__(
         enable_piecewise_cuda_graph: bool = False,
         capture_num_tokens: Optional[List[int]] = None,
         max_num_streams: int = 1,
+        mapping=None,
     ) -> None:
         super().__init__()
         self.elapsed_time = 0
         self.module_inference_event = []
         self.module_inference_time = 0
         self.call_count = 0
-        self.custom_passes = Backend.get_custom_pass(enable_userbuffers)
+        self.mapping = mapping
+        self.custom_passes = Backend.get_custom_pass(enable_userbuffers,
+                                                     mapping)
         self.rank = tensorrt_llm.mpi_rank()
         self.enable_inductor = enable_inductor
         self.capture_num_tokens = sorted(capture_num_tokens or [])
@@ -63,8 +67,7 @@ def __init__(
         self.match_count = []
 
     @classmethod
-    def get_custom_pass(cls, enable_userbuffers):
-        # TODO: add pp + tp support
+    def get_custom_pass(cls, enable_userbuffers, mapping: Mapping):
         world_size = tensorrt_llm.mpi_world_size()
         if not cls._custom_pass_instances:
             # Really naive pass manager here
@@ -75,7 +78,8 @@ def get_custom_pass(cls, enable_userbuffers):
                 os.environ["DISABLE_LAMPORT_REDUCE_NORM_FUSION"] = "1"
             ub_enabled = enable_userbuffers and tensorrt_llm.bindings.internal.userbuffers.ub_supported(
             )
-            register_ar_fusions(cls._custom_pass_instances, ub_enabled)
+            register_ar_fusions(cls._custom_pass_instances, mapping,
+                                ub_enabled)
         else:
             register_add_norm(cls._custom_pass_instances[0])
         return cls._custom_pass_instances
@@ -150,6 +154,11 @@ def __call__(self, gm: GraphModule,
                 assert isinstance(example_value, FakeTensor)
                 self.input_num_tokens = example_value.shape[0]
                 break
+            if node.name == "l_position_ids_":
+                example_value = node.meta["example_value"]
+                assert isinstance(example_value, FakeTensor)
+                self.input_num_tokens = example_value.shape[-1]
+                break
 
         if self.piecewise_cuda_graph:
             assert (
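
To illustrate the new plumbing, here is a minimal sketch of building the compile backend with an explicit parallel layout. The Mapping values describe an assumed 2-way TP x 2-way PP run and are not taken from this diff; only the `mapping=` keyword itself is what this commit adds.

```python
from tensorrt_llm.mapping import Mapping
from tensorrt_llm._torch.compilation.backend import Backend

# Assumed 2x2 layout: the fusion passes now see the real tp/pp split instead of
# treating the whole MPI world as tensor parallel.
mapping = Mapping(world_size=4, tp_size=2, pp_size=2, rank=0)
backend = Backend(mapping=mapping)  # forwarded to Backend.get_custom_pass()

# The object is then usable as a torch.compile backend (call site assumed):
# compiled = torch.compile(model, backend=backend, fullgraph=True)
```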

tensorrt_llm/_torch/compilation/multi_stream/auto_multi_stream.py
Lines changed: 3 additions & 1 deletion

@@ -209,7 +209,9 @@ def flatten_args(args):
             for inplace_arg in inplace_map[func].values():
                 # At this stage, all inplace op must be using kwargs for all params
                 assert inplace_arg in node.kwargs
-                latest_inplace_stat[node.kwargs[inplace_arg]] = vertex
+                args = flatten_args(node.kwargs[inplace_arg])
+                for arg in args:
+                    latest_inplace_stat[arg] = vertex
 
         for edge in in_edges.values():
             edge.out_edges.append(vertex)
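
The scheduler change above matters because an in-place kwarg can now be a list of graph nodes (the `tensors` argument of `trtllm::pp_recv`/`trtllm::pp_send`), and every element has to be recorded as the latest in-place write. A rough stand-in for the `flatten_args` helper named in the hunk header, assumed rather than copied from the file:

```python
def flatten_args(args):
    # Expand nested lists/tuples of graph nodes into a flat list; pass single
    # nodes through unchanged. This is an assumed equivalent, not the real code.
    if isinstance(args, (list, tuple)):
        flat = []
        for a in args:
            flat.extend(flatten_args(a))
        return flat
    return [args]
```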

tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py
Lines changed: 19 additions & 54 deletions

@@ -8,21 +8,14 @@
                                              PatternMatcherPass, fwd_only,
                                              register_replacement)
 
-import tensorrt_llm
-
 from ...distributed import AllReduceFusionOp, AllReduceStrategy
 
 aten = torch.ops.aten
 from tensorrt_llm.mapping import Mapping
 
 
-def register_ar_residual_norm(custom_pass: PatternMatcherPass):
-    # TODO: add pp + tp support
-    mapping = Mapping(
-        world_size=tensorrt_llm.mpi_world_size(),
-        tp_size=tensorrt_llm.mpi_world_size(),
-        rank=tensorrt_llm.mpi_rank(),
-    )
+def register_ar_residual_norm(custom_pass: PatternMatcherPass,
+                              mapping: Mapping):
     residual_key = KeywordArg("residual")
     trtllm_allreduce_default = CallFunction(
         torch.ops.trtllm.allreduce.default, KeywordArg("input"), None, None,
@@ -117,14 +110,8 @@ def check_non_ub_strategy(match, strategy_node) -> bool:
     return True
 
 
-def register_ar_residual_norm_out_fp8_quant(custom_pass: PatternMatcherPass):
-    # TODO: add pp + tp support
-    mapping = Mapping(
-        world_size=tensorrt_llm.mpi_world_size(),
-        tp_size=tensorrt_llm.mpi_world_size(),
-        rank=tensorrt_llm.mpi_rank(),
-    )
-
+def register_ar_residual_norm_out_fp8_quant(custom_pass: PatternMatcherPass,
+                                            mapping: Mapping):
     input_node = KeywordArg("input")
     strategy_node = KeywordArg("strategy")
     allreduce_default = CallFunction(torch.ops.trtllm.allreduce.default,
@@ -200,14 +187,8 @@ def extra_check(match: Match) -> bool:
     )
 
 
-def register_ar_residual_norm_fp8_quant(custom_pass: PatternMatcherPass):
-    # TODO: add pp + tp support
-    mapping = Mapping(
-        world_size=tensorrt_llm.mpi_world_size(),
-        tp_size=tensorrt_llm.mpi_world_size(),
-        rank=tensorrt_llm.mpi_rank(),
-    )
-
+def register_ar_residual_norm_fp8_quant(custom_pass: PatternMatcherPass,
+                                        mapping: Mapping):
     input_node = KeywordArg("input")
     strategy_node = KeywordArg("strategy")
     allreduce_default = CallFunction(torch.ops.trtllm.allreduce.default,
@@ -282,14 +263,8 @@ def extra_check(match: Match) -> bool:
     )
 
 
-def register_ar_residual_norm_out_fp4_quant(custom_pass: PatternMatcherPass):
-    # TODO: add pp + tp support
-    mapping = Mapping(
-        world_size=tensorrt_llm.mpi_world_size(),
-        tp_size=tensorrt_llm.mpi_world_size(),
-        rank=tensorrt_llm.mpi_rank(),
-    )
-
+def register_ar_residual_norm_out_fp4_quant(custom_pass: PatternMatcherPass,
+                                            mapping: Mapping):
     input_node = KeywordArg("input")
     strategy_node = KeywordArg("strategy")
     allreduce_default = CallFunction(torch.ops.trtllm.allreduce.default,
@@ -360,14 +335,8 @@ def extra_check(match: Match) -> bool:
     )
 
 
-def register_ar_residual_norm_fp4_quant(custom_pass: PatternMatcherPass):
-    # TODO: add pp + tp support
-    mapping = Mapping(
-        world_size=tensorrt_llm.mpi_world_size(),
-        tp_size=tensorrt_llm.mpi_world_size(),
-        rank=tensorrt_llm.mpi_rank(),
-    )
-
+def register_ar_residual_norm_fp4_quant(custom_pass: PatternMatcherPass,
+                                        mapping: Mapping):
     input_node = KeywordArg("input")
     strategy_node = KeywordArg("strategy")
     allreduce_default = CallFunction(torch.ops.trtllm.allreduce.default,
@@ -437,12 +406,8 @@ def extra_check(match: Match) -> bool:
     )
 
 
-def register_ub_patterns(custom_passes: List[PatternMatcherPass]):
-    mapping = Mapping(
-        world_size=tensorrt_llm.mpi_world_size(),
-        tp_size=tensorrt_llm.mpi_world_size(),
-        rank=tensorrt_llm.mpi_rank(),
-    )
+def register_ub_patterns(custom_passes: List[PatternMatcherPass],
+                         mapping: Mapping):
 
     def register_convert_supported_ar_to_ub(custom_pass: PatternMatcherPass):
         strategy = int(AllReduceStrategy.AUTO)
@@ -717,16 +682,16 @@ def target_finalize_pattern(
 
 
 def register_ar_fusions(custom_passes: List[PatternMatcherPass],
-                        enable_ub: bool):
-    register_ar_residual_norm(custom_passes[-1])
+                        mapping: Mapping, enable_ub: bool):
+    register_ar_residual_norm(custom_passes[-1], mapping)
 
     custom_passes.append(PatternMatcherPass())
-    register_ar_residual_norm_fp8_quant(custom_passes[-1])
-    register_ar_residual_norm_fp4_quant(custom_passes[-1])
+    register_ar_residual_norm_fp8_quant(custom_passes[-1], mapping)
+    register_ar_residual_norm_fp4_quant(custom_passes[-1], mapping)
     # AR-Residual-Norm-Out-Quant-X is not supported by Userbuffers kernel.
     if not enable_ub:
-        register_ar_residual_norm_out_fp8_quant(custom_passes[-1])
-        register_ar_residual_norm_out_fp4_quant(custom_passes[-1])
+        register_ar_residual_norm_out_fp8_quant(custom_passes[-1], mapping)
+        register_ar_residual_norm_out_fp4_quant(custom_passes[-1], mapping)
 
     if enable_ub:
-        register_ub_patterns(custom_passes)
+        register_ub_patterns(custom_passes, mapping)
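
The net effect of this file is that every pattern-registration helper now receives the caller's Mapping instead of fabricating a TP-only one from the MPI world size, which mislabeled pipeline ranks as tensor-parallel ranks. A sketch of the new call site; the 2x2 layout is illustrative and not taken from the diff:

```python
from torch._inductor.pattern_matcher import PatternMatcherPass

from tensorrt_llm.mapping import Mapping
from tensorrt_llm._torch.compilation.patterns.ar_residual_norm import register_ar_fusions

custom_passes = [PatternMatcherPass()]
mapping = Mapping(world_size=4, tp_size=2, pp_size=2, rank=0)  # assumed layout
register_ar_fusions(custom_passes, mapping, enable_ub=False)   # new signature
```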

tensorrt_llm/_torch/compilation/utils.py
Lines changed: 6 additions & 0 deletions

@@ -76,6 +76,12 @@ def inplace_info():
         },
         torch.ops.trtllm.logits_bitmask.default: {
            1: "logits"
+        },
+        torch.ops.trtllm.pp_recv.default: {
+            1: "tensors"
+        },
+        torch.ops.trtllm.pp_send.default: {
+            1: "tensors"
         }
     }
     return inplace_map
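
`inplace_info()` maps each mutating custom op to `{argument position: kwarg name}`; registering `pp_recv`/`pp_send` here is what lets the multi-stream scheduler (see the `auto_multi_stream.py` change above) treat their `tensors` list as written in place. A small usage sketch, assuming the trtllm ops have already been registered by importing the package:

```python
import torch
import tensorrt_llm  # assumption: importing the package registers torch.ops.trtllm.*
from tensorrt_llm._torch.compilation.utils import inplace_info

inplace_map = inplace_info()
# Which arguments does pp_send mutate?  ->  {1: "tensors"}
print(inplace_map[torch.ops.trtllm.pp_send.default])
```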

tensorrt_llm/_torch/distributed/communicator.py
Lines changed: 6 additions & 3 deletions

@@ -477,11 +477,14 @@ def init_pp_comm(mapping):
     _pp_comm = PPComm(mapping)
 
 
-def pp_recv(tensor):
+@torch.library.custom_op("trtllm::pp_recv", mutates_args=("tensors", ))
+def pp_recv(tensors: List[torch.Tensor]) -> None:
     """Receive tensors from previous pp rank."""
-    _pp_comm.recv(tensor)
+    for tensor in tensors:
+        _pp_comm.recv(tensor)
 
 
-def pp_send(tensor):
+@torch.library.custom_op("trtllm::pp_send", mutates_args=("tensors", ))
+def pp_send(tensors: List[torch.Tensor]) -> None:
     """Send tensors to next pp rank."""
     _pp_comm.send(tensor)
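
Exposing the point-to-point helpers as torch custom ops is what keeps them inside the Dynamo graph instead of forcing a graph break on an opaque Python call; `mutates_args` tells the compiler the listed tensors are written in place rather than returned. A generic, self-contained illustration of the same pattern (a hypothetical `demo::` namespace, not the commit's code):

```python
from typing import List

import torch


@torch.library.custom_op("demo::pp_send_like", mutates_args=("tensors", ))
def pp_send_like(tensors: List[torch.Tensor]) -> None:
    # Stand-in for a real point-to-point send: the op returns nothing and only
    # touches its inputs, which is exactly what mutates_args declares.
    for t in tensors:
        t.add_(0)  # placeholder side effect


@pp_send_like.register_fake
def _(tensors: List[torch.Tensor]) -> None:
    # Nothing to fabricate for tracing; the signature just has to match.
    return None


buffers = [torch.zeros(4), torch.zeros(4)]
torch.ops.demo.pp_send_like(buffers)  # callable like any other torch op
```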

tensorrt_llm/_torch/models/modeling_utils.py
Lines changed: 5 additions & 5 deletions

@@ -170,11 +170,12 @@ def forward_after_recv_fn(
         residual=...,
         **kwargs,
     ):
-        pp_recv(hidden_states)
         if residual is not ...:
             if residual is None:
                 residual = torch.empty_like(hidden_states)
-            pp_recv(residual)
+            pp_recv([hidden_states, residual])
+        else:
+            pp_recv([hidden_states])
         return forward_fn(
             position_ids,
             hidden_states,
@@ -207,11 +208,10 @@ def forward_before_send_fn(
         )
         if residual is not ...:
             hidden_states, residual = output
-            pp_send(hidden_states)
-            pp_send(residual)
+            pp_send([hidden_states, residual])
         else:
             hidden_states = output
-            pp_send(hidden_states)
+            pp_send([hidden_states])
         return output
 
     forward_before_send_fn.__wrapped_by_forward_before_send__ = True
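
With the batched helpers, each pipeline boundary contributes a single mutating `pp_send`/`pp_recv` node to the compiled graph instead of two. A condensed paraphrase of the send-side logic above, written as a standalone helper (the real wrapper keys on a `residual is not ...` sentinel; a tuple check stands in for it here):

```python
from typing import Callable, List

import torch


def send_boundary(output, pp_send: Callable[[List[torch.Tensor]], None]) -> None:
    # Decoders that return (hidden_states, residual) ship both tensors in one
    # batched call; otherwise only hidden_states is sent.
    if isinstance(output, tuple):
        hidden_states, residual = output
        pp_send([hidden_states, residual])
    else:
        pp_send([output])
```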

tensorrt_llm/_torch/pyexecutor/model_engine.py
Lines changed: 3 additions & 2 deletions

@@ -364,7 +364,8 @@ def __init__(
                 _torch_compile_piecewise_cuda_graph,
                 capture_num_tokens=self._piecewise_cuda_graph_num_tokens,
                 max_num_streams=pytorch_backend_config.
-                torch_compile_max_num_streams)
+                torch_compile_max_num_streams,
+                mapping=self.mapping)
             if isinstance(self.model, DecoderModelForCausalLM):
                 self.model.model = torch.compile(
                     self.model.model,
@@ -2496,7 +2497,7 @@ def _forward_step_mm_encoder_only(
         return {'mm_embeddings': mm_embeddings, 'logits': None}
 
     def _init_userbuffers(self, hidden_size):
-        if self.mapping.tp_size <= 1:
+        if self.mapping.tp_size <= 1 or self.mapping.pp_size > 1:
             return False
 
         # Disable UB for unsupported platforms
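
Besides threading `self.mapping` into the compile backend, the engine now skips userbuffers initialization whenever pipeline parallelism is active. The changed condition condenses to roughly this predicate (hypothetical helper name, same logic as the diff line):

```python
def userbuffers_allowed(mapping) -> bool:
    # Userbuffer-based allreduce fusion needs tensor parallelism and is now
    # turned off as soon as the run is pipeline parallel.
    return mapping.tp_size > 1 and mapping.pp_size == 1
```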

tests/integration/defs/accuracy/test_llm_api_pytorch.py
Lines changed: 0 additions & 17 deletions

@@ -113,11 +113,6 @@ def test_bfloat16(self, attn_backend, torch_compile):
                              ids=["tp4", "tp2pp2", "pp4"])
     def test_bfloat16_4gpus(self, tp_size, pp_size, attn_backend,
                             torch_compile):
-        if torch_compile and pp_size > 1:
-            pytest.skip(
-                "Pipeline parallel with torch.compile is not supported yet.\n"
-                "Issue: Unfusing flashinfer_fused_add_rmsnorm causes outputs to be "
-                "discarded at graph breaks.")
         torch_compile_config = TorchCompileConfig(
             enable_fullgraph=True,
             enable_piecewise_cuda_graph=True,
@@ -1187,8 +1182,6 @@ def test_bfloat16(self, mtp_nextn, attention_dp, cuda_graph,
     def test_bfloat16_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
                             attention_dp, cuda_graph, overlap_scheduler,
                             torch_compile):
-        if torch_compile and pp_size > 1:
-            pytest.skip("PP with torch.compile is not supported yet.")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
         torch_compile_config = TorchCompileConfig(
             enable_fullgraph=True,
@@ -1226,8 +1219,6 @@ def test_bfloat16_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
     @parametrize_with_ids("mtp", ["disable", "eagle", "vanilla"])
     def test_fp8_block_scales(self, mtp, fp8kv, attention_dp, cuda_graph,
                               overlap_scheduler, torch_compile):
-        if torch_compile and mtp != "disable":
-            pytest.skip("https://nvbugs/5252313")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
         torch_compile_config = TorchCompileConfig(
             enable_fullgraph=True,
@@ -1280,8 +1271,6 @@ def test_cute_dsl_fp8_block_scales(
         overlap_scheduler,
         torch_compile,
     ):
-        if torch_compile and attention_dp:
-            pytest.skip("https://nvbugs/5252559")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
         torch_compile_config = (TorchCompileConfig(
             enable_fullgraph=True,
@@ -1384,8 +1373,6 @@ def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
     def test_fp8_block_scales_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
                                     fp8kv, attention_dp, cuda_graph,
                                     overlap_scheduler, torch_compile):
-        if torch_compile and pp_size > 1:
-            pytest.skip("PP with torch.compile is not supported yet.")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
         torch_compile_config = TorchCompileConfig(
             enable_fullgraph=True,
@@ -1446,8 +1433,6 @@ def test_cute_dsl_fp8_block_scales_4gpus(
         overlap_scheduler,
         torch_compile,
     ):
-        if torch_compile and pp_size > 1:
-            pytest.skip("PP with torch.compile is not supported yet.")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
         torch_compile_config = (TorchCompileConfig(
             enable_fullgraph=True,
@@ -1669,8 +1654,6 @@ def test_nvfp4_batch_waiting(self, torch_compile, fp8kv, cuda_graph,
     def test_nvfp4_4gpus(self, fp8kv, attention_dp, cuda_graph,
                          overlap_scheduler, tp_size, pp_size, ep_size,
                          torch_compile, mtp_nextn, moe_backend):
-        if torch_compile and pp_size > 1:
-            pytest.skip("PP with torch.compile is not supported yet.")
         if moe_backend == "TRTLLM" and (get_sm_version() == 120
                                         or get_sm_version() == 121):
             pytest.skip(
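
With these skips removed, parametrizations such as `tp2pp2` and `pp4` now exercise torch.compile end to end. The configs visible in the tests are built roughly as below; only arguments that appear in this diff are used, the import path is assumed, and how the objects reach the `LLM` under test is omitted because it is not shown here:

```python
from tensorrt_llm.llmapi import KvCacheConfig, TorchCompileConfig  # import path assumed

torch_compile_config = TorchCompileConfig(
    enable_fullgraph=True,
    enable_piecewise_cuda_graph=True,  # flags taken from the test bodies above
)
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
# ...then passed to the LLM under test together with the tp_size/pp_size parametrization.
```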

tests/unittest/_torch/multi_gpu/test_ar_residual_norm.py
Lines changed: 3 additions & 1 deletion

@@ -66,7 +66,9 @@ def row_linear_residual_norm_fusion_forward(
         x: torch.Tensor, residual: torch.Tensor, hidden_size: int,
         dtype: torch.dtype, tensor_parallel_size: int,
         tensor_parallel_rank: int, weights: torch.Tensor, fused_add_norm: bool):
-    backend = Backend()
+    backend = Backend(mapping=Mapping(world_size=tensor_parallel_size,
+                                      tp_size=tensor_parallel_size,
+                                      rank=tensor_parallel_rank))
     x = x.cuda()
     residual = residual.cuda()
     norm_weight = torch.randn((hidden_size, ), dtype=dtype, device="cuda")
