|
8 | 8 | PatternMatcherPass, fwd_only, |
9 | 9 | register_replacement) |
10 | 10 |
|
11 | | -import tensorrt_llm |
12 | | - |
13 | 11 | from ...distributed import AllReduceFusionOp, AllReduceStrategy |
14 | 12 |
|
15 | 13 | aten = torch.ops.aten |
16 | 14 | from tensorrt_llm.mapping import Mapping |
17 | 15 |
|
18 | 16 |
|
19 | | -def register_ar_residual_norm(custom_pass: PatternMatcherPass): |
20 | | - # TODO: add pp + tp support |
21 | | - mapping = Mapping( |
22 | | - world_size=tensorrt_llm.mpi_world_size(), |
23 | | - tp_size=tensorrt_llm.mpi_world_size(), |
24 | | - rank=tensorrt_llm.mpi_rank(), |
25 | | - ) |
| 17 | +def register_ar_residual_norm(custom_pass: PatternMatcherPass, |
| 18 | + mapping: Mapping): |
26 | 19 | residual_key = KeywordArg("residual") |
27 | 20 | trtllm_allreduce_default = CallFunction( |
28 | 21 | torch.ops.trtllm.allreduce.default, KeywordArg("input"), None, None, |
@@ -117,14 +110,8 @@ def check_non_ub_strategy(match, strategy_node) -> bool: |
117 | 110 | return True |
118 | 111 |
|
119 | 112 |
|
120 | | -def register_ar_residual_norm_out_fp8_quant(custom_pass: PatternMatcherPass): |
121 | | - # TODO: add pp + tp support |
122 | | - mapping = Mapping( |
123 | | - world_size=tensorrt_llm.mpi_world_size(), |
124 | | - tp_size=tensorrt_llm.mpi_world_size(), |
125 | | - rank=tensorrt_llm.mpi_rank(), |
126 | | - ) |
127 | | - |
| 113 | +def register_ar_residual_norm_out_fp8_quant(custom_pass: PatternMatcherPass, |
| 114 | + mapping: Mapping): |
128 | 115 | input_node = KeywordArg("input") |
129 | 116 | strategy_node = KeywordArg("strategy") |
130 | 117 | allreduce_default = CallFunction(torch.ops.trtllm.allreduce.default, |
@@ -200,14 +187,8 @@ def extra_check(match: Match) -> bool: |
200 | 187 | ) |
201 | 188 |
|
202 | 189 |
|
203 | | -def register_ar_residual_norm_fp8_quant(custom_pass: PatternMatcherPass): |
204 | | - # TODO: add pp + tp support |
205 | | - mapping = Mapping( |
206 | | - world_size=tensorrt_llm.mpi_world_size(), |
207 | | - tp_size=tensorrt_llm.mpi_world_size(), |
208 | | - rank=tensorrt_llm.mpi_rank(), |
209 | | - ) |
210 | | - |
| 190 | +def register_ar_residual_norm_fp8_quant(custom_pass: PatternMatcherPass, |
| 191 | + mapping: Mapping): |
211 | 192 | input_node = KeywordArg("input") |
212 | 193 | strategy_node = KeywordArg("strategy") |
213 | 194 | allreduce_default = CallFunction(torch.ops.trtllm.allreduce.default, |
@@ -282,14 +263,8 @@ def extra_check(match: Match) -> bool: |
282 | 263 | ) |
283 | 264 |
|
284 | 265 |
|
285 | | -def register_ar_residual_norm_out_fp4_quant(custom_pass: PatternMatcherPass): |
286 | | - # TODO: add pp + tp support |
287 | | - mapping = Mapping( |
288 | | - world_size=tensorrt_llm.mpi_world_size(), |
289 | | - tp_size=tensorrt_llm.mpi_world_size(), |
290 | | - rank=tensorrt_llm.mpi_rank(), |
291 | | - ) |
292 | | - |
| 266 | +def register_ar_residual_norm_out_fp4_quant(custom_pass: PatternMatcherPass, |
| 267 | + mapping: Mapping): |
293 | 268 | input_node = KeywordArg("input") |
294 | 269 | strategy_node = KeywordArg("strategy") |
295 | 270 | allreduce_default = CallFunction(torch.ops.trtllm.allreduce.default, |
@@ -360,14 +335,8 @@ def extra_check(match: Match) -> bool: |
360 | 335 | ) |
361 | 336 |
|
362 | 337 |
|
363 | | -def register_ar_residual_norm_fp4_quant(custom_pass: PatternMatcherPass): |
364 | | - # TODO: add pp + tp support |
365 | | - mapping = Mapping( |
366 | | - world_size=tensorrt_llm.mpi_world_size(), |
367 | | - tp_size=tensorrt_llm.mpi_world_size(), |
368 | | - rank=tensorrt_llm.mpi_rank(), |
369 | | - ) |
370 | | - |
| 338 | +def register_ar_residual_norm_fp4_quant(custom_pass: PatternMatcherPass, |
| 339 | + mapping: Mapping): |
371 | 340 | input_node = KeywordArg("input") |
372 | 341 | strategy_node = KeywordArg("strategy") |
373 | 342 | allreduce_default = CallFunction(torch.ops.trtllm.allreduce.default, |
@@ -437,12 +406,8 @@ def extra_check(match: Match) -> bool: |
437 | 406 | ) |
438 | 407 |
|
439 | 408 |
|
440 | | -def register_ub_patterns(custom_passes: List[PatternMatcherPass]): |
441 | | - mapping = Mapping( |
442 | | - world_size=tensorrt_llm.mpi_world_size(), |
443 | | - tp_size=tensorrt_llm.mpi_world_size(), |
444 | | - rank=tensorrt_llm.mpi_rank(), |
445 | | - ) |
| 409 | +def register_ub_patterns(custom_passes: List[PatternMatcherPass], |
| 410 | + mapping: Mapping): |
446 | 411 |
|
447 | 412 | def register_convert_supported_ar_to_ub(custom_pass: PatternMatcherPass): |
448 | 413 | strategy = int(AllReduceStrategy.AUTO) |
@@ -717,16 +682,16 @@ def target_finalize_pattern( |
717 | 682 |
|
718 | 683 |
|
719 | 684 | def register_ar_fusions(custom_passes: List[PatternMatcherPass], |
720 | | - enable_ub: bool): |
721 | | - register_ar_residual_norm(custom_passes[-1]) |
| 685 | + mapping: Mapping, enable_ub: bool): |
| 686 | + register_ar_residual_norm(custom_passes[-1], mapping) |
722 | 687 |
|
723 | 688 | custom_passes.append(PatternMatcherPass()) |
724 | | - register_ar_residual_norm_fp8_quant(custom_passes[-1]) |
725 | | - register_ar_residual_norm_fp4_quant(custom_passes[-1]) |
| 689 | + register_ar_residual_norm_fp8_quant(custom_passes[-1], mapping) |
| 690 | + register_ar_residual_norm_fp4_quant(custom_passes[-1], mapping) |
726 | 691 | # AR-Residual-Norm-Out-Quant-X is not supported by Userbuffers kernel. |
727 | 692 | if not enable_ub: |
728 | | - register_ar_residual_norm_out_fp8_quant(custom_passes[-1]) |
729 | | - register_ar_residual_norm_out_fp4_quant(custom_passes[-1]) |
| 693 | + register_ar_residual_norm_out_fp8_quant(custom_passes[-1], mapping) |
| 694 | + register_ar_residual_norm_out_fp4_quant(custom_passes[-1], mapping) |
730 | 695 |
|
731 | 696 | if enable_ub: |
732 | | - register_ub_patterns(custom_passes) |
| 697 | + register_ub_patterns(custom_passes, mapping) |
0 commit comments