     Shard,
 )
 from torch.distributed._tensor.dispatch import OpSchema, OutputSharding
-from torch.distributed._tensor.ops.common_rules import pointwise_rule
+from torch.distributed._tensor.ops.common_rules import einop_rule, pointwise_rule
 from torch.distributed._tensor.ops.utils import register_prop_rule
 
 
@@ -472,3 +472,93 @@ def place(vp: Placement, ip: Placement) -> Placement:
         ],
     )
     return result
+
+
+@register_prop_rule("aten.cat.default")
+def cat_rule(op_schema: OpSchema) -> OutputSharding:
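+    # sharding prop rule for aten.cat: inputs must not be sharded on the cat
+    # dim; the shardings on the remaining dims are merged via einop_rule below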
+    dim = 0  # aten.cat defaults to dim=0
+    tensor_list_specs = cast(List[DTensorSpec], op_schema.args_schema[0])
+    if len(op_schema.args_schema) > 1:
+        dim = cast(int, op_schema.args_schema[1])
+    # normalize a negative cat dim
+    if dim < 0:
+        dim += tensor_list_specs[0].ndim
+
+    # check whether any input is sharded on the cat dim
+    needs_reshard_on_cat_dim = False
+    for spec in tensor_list_specs:
+        # placements are indexed by mesh dim, so scan all of them for a Shard
+        # on the cat dim instead of indexing placements by the tensor dim
+        if any(isinstance(p, Shard) and p.dim == dim for p in spec.placements):
+            needs_reshard_on_cat_dim = True
+            spec.placements = unshard_tensor_dim(spec.placements, dim=dim)
+    if needs_reshard_on_cat_dim:
+        args_schema = (tensor_list_specs,) + op_schema.args_schema[1:]
+        suggested_schema = OpSchema(
+            func_schema=op_schema.func_schema,
+            args_schema=args_schema,
+            kwargs_schema=op_schema.kwargs_schema,
+        )
+        return OutputSharding(
+            None,
+            schema_suggestions=[suggested_schema],
+            failed_reason="Tensors being concatenated must not be sharded on the cat dim; resharding is needed.",
+        )
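+    # build an einop equation such as "ab,ab->ab" so that einop_rule can merge
+    # the shardings of all inputs (the cat dim is unsharded by this point)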
+    alphabet = "abcdefghijklmnopqrstuvwxyz"
+    einop_equation = ""
+    for spec in tensor_list_specs:
+        einop_equation += alphabet[:spec.ndim]
+        einop_equation += ","
+    einop_equation = einop_equation[:-1] + "->" + alphabet[:tensor_list_specs[0].ndim]
+    output_sharding = einop_rule(
+        einop_equation,
+        OpSchema(
+            func_schema=op_schema.func_schema,
+            args_schema=tuple(tensor_list_specs),
+            kwargs_schema={},
+        ),
+        linearity=False,
+    )
+
+    if output_sharding.output_spec is None:
+        if output_sharding.schema_suggestions is not None:
+            return _update_schema_suggestion_for_cat(
+                output_sharding,
+                op_schema,
+                dim,
+            )
+        else:
+            return OutputSharding(None)
+    # the cat output's size along `dim` is the sum of the inputs' sizes along `dim`
+    new_size = 0
+    for spec in tensor_list_specs:
+        new_size += spec.shape[dim]
+    assert isinstance(output_sharding.output_spec, DTensorSpec)
+    output_sharding.output_spec.shape = torch.Size(
+        tuple(output_sharding.output_spec.shape[:dim])
+        + (new_size,)
+        + tuple(output_sharding.output_spec.shape[dim + 1 :])
+    )
+    return output_sharding
+
+
+def _update_schema_suggestion_for_cat(
+    output_sharding: OutputSharding,
+    op_schema: OpSchema,
+    dim: int,
+) -> OutputSharding:
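+    # einop_rule already suggested resharding the inputs; additionally make
+    # sure that suggestion leaves no input sharded on the cat dim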
+    assert output_sharding.schema_suggestions is not None
+    suggestion_specs = output_sharding.schema_suggestions[0].args_spec
+
+    # check concat dim
+    for spec in suggestion_specs:
+        if any(isinstance(p, Shard) and p.dim == dim for p in spec.placements):
+            spec.placements = unshard_tensor_dim(spec.placements, dim=dim)
+    args_schema = (suggestion_specs,) + op_schema.args_schema[1:]
+
+    output_sharding.schema_suggestions = [
+        OpSchema(
+            func_schema=op_schema.func_schema,
+            args_schema=args_schema,
+            kwargs_schema=op_schema.kwargs_schema,
+        )
+    ]
+    return output_sharding