11 | 11 | Shard, |
12 | 12 | ) |
13 | 13 | from torch.distributed._tensor.dispatch import OpSchema, OutputSharding |
14 | | -from torch.distributed._tensor.ops.common_rules import pointwise_rule |
15 | | -from torch.distributed._tensor.ops.utils import register_prop_rule |
| 14 | +from torch.distributed._tensor.ops.common_rules import einop_rule, pointwise_rule |
| 15 | +from torch.distributed._tensor.ops.utils import register_prop_rule, normalize_dim |
16 | 16 |
17 | 17 |
18 | 18 | # NOTE: the default propagation rule should apply for |
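
The two imports added above are both used in the new cat rule below: einop_rule propagates shardings through an einsum-style notation, and normalize_dim wraps a possibly negative concatenation dim into the range [0, ndim). A minimal sketch of the normalize_dim behavior assumed here (a hypothetical re-implementation for illustration, not the library code):

```python
def normalize_dim(dim: int, ndim: int) -> int:
    # Wrap negative dims, e.g. dim=-1 with ndim=3 becomes 2.
    return dim if dim >= 0 else dim + ndim


assert normalize_dim(-1, 3) == 2
assert normalize_dim(1, 3) == 1
```
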
@@ -160,6 +160,13 @@ def unshard_tensor_dim( |
160 | 160 | ) |
161 | 161 |
162 | 162 |
| 163 | +def is_tensor_dim_sharded( |
| 164 | + spec: DTensorSpec, dim: int |
| 165 | +) -> bool: |
| 166 | + """Return True if tensor dim is sharded""" |
| 167 | + return (dim < spec.ndim) and spec.dim_map[dim] >= 0 |
| 168 | + |
| 169 | + |
163 | 170 | def _prop_all_but_dim( |
164 | 171 | op_schema: OpSchema, dim: int, out_shape: torch.Size |
165 | 172 | ) -> OutputSharding: |
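
The new is_tensor_dim_sharded helper relies on the DTensorSpec.dim_map convention: dim_map[i] is the mesh dimension that tensor dimension i is sharded over, or -1 if that dimension is not sharded. A self-contained sketch of the check using a hypothetical stand-in spec (a real DTensorSpec would need an initialized DeviceMesh):

```python
from dataclasses import dataclass
from typing import List


@dataclass
class FakeSpec:
    # Hypothetical stand-in: dim_map[i] >= 0 means tensor dim i is sharded
    # over mesh dim dim_map[i]; -1 means that dim is not sharded.
    dim_map: List[int]

    @property
    def ndim(self) -> int:
        return len(self.dim_map)


def is_tensor_dim_sharded(spec, dim: int) -> bool:
    return (dim < spec.ndim) and spec.dim_map[dim] >= 0


spec = FakeSpec(dim_map=[0, -1, -1])  # e.g. Shard(0): dim 0 sharded on mesh dim 0
assert is_tensor_dim_sharded(spec, dim=0)
assert not is_tensor_dim_sharded(spec, dim=1)
```
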
@@ -472,3 +479,124 @@ def place(vp: Placement, ip: Placement) -> Placement: |
472 | 479 | ], |
473 | 480 | ) |
474 | 481 | return result |
| 482 | + |
| 483 | + |
| 484 | +@register_prop_rule("aten.cat.default") |
| 485 | +def cat_rule(op_schema: OpSchema) -> OutputSharding: |
| 486 | + # the first arg is a list of input tensors' specs |
| 487 | + tensor_list_specs = cast(List[DTensorSpec], op_schema.args_schema[0]) |
| 488 | + # ndim will also be the result's ndim |
| 489 | + ndim = 1 |
| 490 | + for spec in tensor_list_specs: |
| 491 | + ndim = max(ndim, spec.ndim) |
| 492 | + |
| 493 | + dim = 0 # default dim = 0 |
| 494 | + if (len(op_schema.args_schema) > 1): |
| 495 | + dim = cast(int, op_schema.args_schema[1]) |
| 496 | + dim = normalize_dim(dim, ndim) |
| 497 | + |
| 498 | + # Unshard all input tensors on cat dim before running einop rule |
| 499 | + # to avoid _Partial in result. |
| 500 | + need_reshard = False |
| 501 | + tensor_list_specs_after = [] |
| 502 | + for spec in tensor_list_specs: |
| 503 | + if is_tensor_dim_sharded(spec, dim=dim): |
| 504 | + need_reshard = True |
| 505 | + tensor_list_specs_after.append( |
| 506 | + DTensorSpec( |
| 507 | + mesh=spec.mesh, |
| 508 | + placements=unshard_tensor_dim(spec.placements, dim=dim), |
| 509 | + shape=spec.shape, |
| 510 | + ndim=spec.ndim, |
| 511 | + ) |
| 512 | + ) |
| 513 | + else: |
| 514 | + tensor_list_specs_after.append(spec) |
| 515 | + tensor_list_specs = tensor_list_specs_after |
| 516 | + |
| 517 | + # TODO: the einop rule currently requires that every character |
| 518 | + # in the result notation also appear in the inputs, |
| 519 | + # so we temporarily write the cat notation as |
| 520 | + # "aij,bij->aij". Once that requirement is lifted, |
| 521 | + # we can switch to the more logical notation |
| 522 | + # "aij,bij->cij". |
| 523 | + alphabet = "abcdefghijklmnopqrstuvwxyz" |
| 524 | + einop_notation_list = [] |
| 525 | + |
| 526 | + l = len(tensor_list_specs) |
| 527 | + free_dim = alphabet[l:l + ndim - 1] |
| 528 | + for i, spec in enumerate(tensor_list_specs): |
| 529 | + if spec.ndim == ndim: |
| 530 | + # rewrite concat dim |
| 531 | + dim_word = free_dim[:dim] + alphabet[i] + free_dim[dim:] |
| 532 | + einop_notation_list.append(dim_word) |
| 533 | + else: |
| 534 | + einop_notation_list.append(alphabet[i]) |
| 535 | + |
| 536 | + cat_dim_char = alphabet[0] |
| 537 | + dim_word = free_dim[:dim] + cat_dim_char + free_dim[dim:] |
| 538 | + einop_equation = f"{','.join(einop_notation_list)}->{dim_word}" |
| 539 | + output_sharding = einop_rule( |
| 540 | + einop_equation, |
| 541 | + OpSchema( |
| 542 | + func_schema=op_schema.func_schema, |
| 543 | + args_schema=tuple(tensor_list_specs), |
| 544 | + kwargs_schema={}, |
| 545 | + ), |
| 546 | + linearity=False |
| 547 | + ) |
| 548 | + |
| 549 | + if ( |
| 550 | + (output_sharding.output_spec is not None) and |
| 551 | + need_reshard |
| 552 | + ): |
| 553 | + output_sharding.output_spec = None |
| 554 | + output_sharding.schema_suggestions = [ |
| 555 | + OpSchema( |
| 556 | + func_schema=op_schema.func_schema, |
| 557 | + args_schema=tuple(tensor_list_specs), |
| 558 | + kwargs_schema={}, |
| 559 | + ), |
| 560 | + ] |
| 561 | + |
| 562 | + if output_sharding.output_spec is None: |
| 563 | + if output_sharding.schema_suggestions is not None: |
| 564 | + # Convert the suggestion's flat spec args back into cat's (list of specs, dim) form |
| 565 | + return _update_schema_suggestion_for_cat( |
| 566 | + output_sharding, |
| 567 | + op_schema, |
| 568 | + ) |
| 569 | + else: |
| 570 | + return output_sharding |
| 571 | + |
| 572 | + # change output shape |
| 573 | + new_size = 0 |
| 574 | + for spec in tensor_list_specs: |
| 575 | + if dim < spec.ndim: |
| 576 | + new_size += spec.shape[dim] |
| 577 | + assert isinstance(output_sharding.output_spec, DTensorSpec) |
| 578 | + output_sharding.output_spec.shape = torch.Size( |
| 579 | + tuple(output_sharding.output_spec.shape[:dim]) |
| 580 | + + (new_size,) |
| 581 | + + tuple(output_sharding.output_spec.shape[dim + 1 :]) |
| 582 | + ) |
| 583 | + return output_sharding |
| 584 | + |
| 585 | + |
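
To make the notation construction in cat_rule concrete: for two rank-3 inputs concatenated on dim=1, l = 2, so free_dim = alphabet[2:4] = "cd" covers the two non-cat dimensions; each input gets its own letter for the cat dimension ("a" and "b"), and the output reuses alphabet[0], giving the equation "cad,cbd->cad". A small sketch that reproduces just this string-building step (assuming all inputs have the same rank and are already unsharded on the cat dim):

```python
alphabet = "abcdefghijklmnopqrstuvwxyz"


def cat_einop_equation(num_inputs: int, ndim: int, dim: int) -> str:
    # Letters for the non-cat dims are shared by all inputs and the output;
    # each input contributes its own letter for the cat dim.
    free_dim = alphabet[num_inputs:num_inputs + ndim - 1]
    inputs = [
        free_dim[:dim] + alphabet[i] + free_dim[dim:]
        for i in range(num_inputs)
    ]
    out = free_dim[:dim] + alphabet[0] + free_dim[dim:]
    return f"{','.join(inputs)}->{out}"


assert cat_einop_equation(2, 3, 1) == "cad,cbd->cad"
assert cat_einop_equation(2, 3, 0) == "acd,bcd->acd"
```
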
| 586 | +def _update_schema_suggestion_for_cat( |
| 587 | + output_sharding: OutputSharding, |
| 588 | + op_schema: OpSchema, |
| 589 | +) -> OutputSharding: |
| 590 | + assert output_sharding.schema_suggestions is not None |
| 591 | + suggestion_specs = output_sharding.schema_suggestions[0].args_spec |
| 592 | + |
| 593 | + args_schema = (suggestion_specs,) + op_schema.args_schema[1:] |
| 594 | + |
| 595 | + output_sharding.schema_suggestions = [ |
| 596 | + OpSchema( |
| 597 | + func_schema=op_schema.func_schema, |
| 598 | + args_schema=args_schema, |
| 599 | + kwargs_schema=op_schema.kwargs_schema, |
| 600 | + ) |
| 601 | + ] |
| 602 | + return output_sharding |
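
Two behaviors of cat_rule are worth spelling out. First, when any input was sharded on the cat dim, the computed output spec is deliberately dropped and the unsharded input specs are returned as a schema suggestion, so the caller reshards the inputs and re-runs propagation. Second, on success only the cat dimension of the output shape is recomputed: it is the sum of the inputs' sizes along dim, while all other dimensions come from the einop result. For example, concatenating shapes (4, 8) and (2, 8) on dim=0 yields (6, 8). A sketch of that shape fix-up in isolation (a hypothetical helper, not part of the patch):

```python
import torch


def cat_output_shape(shapes, dim: int) -> torch.Size:
    # Sum the sizes along the cat dim; the other dims follow the first input.
    new_size = sum(s[dim] for s in shapes if dim < len(s))
    base = shapes[0]
    return torch.Size(tuple(base[:dim]) + (new_size,) + tuple(base[dim + 1:]))


assert cat_output_shape([(4, 8), (2, 8)], dim=0) == torch.Size((6, 8))
assert cat_output_shape([(4, 8), (4, 3)], dim=1) == torch.Size((4, 11))
```
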