[DTensor] implement dist_cat as a sharding prop rule #92677
Changes from all commits: 815f56c, 0e9ffe3, 41145df, cb506a5, 3ea6b8c
@@ -11,8 +11,8 @@
     Shard,
 )
 from torch.distributed._tensor.dispatch import OpSchema, OutputSharding
-from torch.distributed._tensor.ops.common_rules import pointwise_rule
-from torch.distributed._tensor.ops.utils import register_prop_rule
+from torch.distributed._tensor.ops.common_rules import einop_rule, pointwise_rule
+from torch.distributed._tensor.ops.utils import register_prop_rule, normalize_dim


 # NOTE: the default propagation rule should apply for
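The newly imported `normalize_dim` is used below to turn a possibly negative `dim` argument into an index in `[0, ndim)`. A minimal sketch of the expected behavior, assuming the usual PyTorch dim-wrapping semantics (the real helper lives in `torch.distributed._tensor.ops.utils`; this is an illustration, not that implementation):

```python
def normalize_dim(dim: int, ndim: int) -> int:
    # Wrap a negative dim the same way torch.cat does: -1 -> ndim - 1.
    return dim + ndim if dim < 0 else dim

assert normalize_dim(-1, 3) == 2
assert normalize_dim(1, 3) == 1
```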
@@ -160,6 +160,13 @@ def unshard_tensor_dim(
     )


+def is_tensor_dim_sharded(
+    spec: DTensorSpec, dim: int
+) -> bool:
+    """Return True if tensor dim is sharded"""
+    return (dim < spec.ndim) and spec.dim_map[dim] >= 0
+
+
 def _prop_all_but_dim(
     op_schema: OpSchema, dim: int, out_shape: torch.Size
 ) -> OutputSharding:
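As a quick illustration of the `dim_map` convention this helper relies on (assuming `dim_map[i]` holds the mesh dimension that tensor dim `i` is sharded over, or `-1` when that dim is replicated), here is a self-contained sketch with a stand-in spec rather than the real `DTensorSpec`:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class FakeSpec:
    # Stand-in for DTensorSpec; only the fields is_tensor_dim_sharded reads.
    ndim: int
    dim_map: List[int]  # mesh dim per tensor dim, or -1 if replicated


def is_tensor_dim_sharded(spec: FakeSpec, dim: int) -> bool:
    return (dim < spec.ndim) and spec.dim_map[dim] >= 0


spec = FakeSpec(ndim=2, dim_map=[0, -1])  # dim 0 sharded on mesh dim 0
assert is_tensor_dim_sharded(spec, dim=0)       # sharded
assert not is_tensor_dim_sharded(spec, dim=1)   # replicated
```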
@@ -472,3 +479,124 @@ def place(vp: Placement, ip: Placement) -> Placement:
         ],
     )
     return result
+
+
+@register_prop_rule("aten.cat.default")
+def cat_rule(op_schema: OpSchema) -> OutputSharding:
+    # the first arg is a list of input tensors' specs
+    tensor_list_specs = cast(List[DTensorSpec], op_schema.args_schema[0])
+    # ndim will also be the result's ndim
+    ndim = 1
+    for spec in tensor_list_specs:
+        ndim = max(ndim, spec.ndim)
+
+    dim = 0  # default dim = 0
+    if len(op_schema.args_schema) > 1:
+        dim = cast(int, op_schema.args_schema[1])
+    dim = normalize_dim(dim, ndim)
+
+    # Unshard all input tensors on the cat dim before running the einop rule
+    # to avoid _Partial in the result.
+    need_reshard = False
+    tensor_list_specs_after = []
+    for spec in tensor_list_specs:
+        if is_tensor_dim_sharded(spec, dim=dim):
+            need_reshard = True
+            tensor_list_specs_after.append(
+                DTensorSpec(
+                    mesh=spec.mesh,
+                    placements=unshard_tensor_dim(spec.placements, dim=dim),
+                    shape=spec.shape,
+                    ndim=spec.ndim,
+                )
+            )
+        else:
+            tensor_list_specs_after.append(spec)
+    tensor_list_specs = tensor_list_specs_after
+
+    # TODO: currently the einop rule requires that every character in the
+    # result notation has appeared in the inputs, so we temporarily design
+    # the cat notation as "aij,bij->aij". Once we lift this requirement,
+    # we can switch to the more logically reasonable notation
+    # "aij,bij->cij".
+    alphabet = "abcdefghijklmnopqrstuvwxyz"
+    einop_notation_list = []
+
+    l = len(tensor_list_specs)
+    free_dim = alphabet[l : l + ndim - 1]
+    for i, spec in enumerate(tensor_list_specs):
+        if spec.ndim == ndim:
+            # rewrite concat dim
+            dim_word = free_dim[:dim] + alphabet[i] + free_dim[dim:]
+            einop_notation_list.append(dim_word)
+        else:
+            einop_notation_list.append(alphabet[i])
Collaborator:
Is this the empty tensor annotation, where it has a single char?

Contributor (Author):
Not entirely: it applies to any input whose ndim is smaller than the result's ndim, with an empty tensor being the typical case. In that case an empty annotation may still work, but we want to ensure that the dim char for the cat dim is still present for those inputs.
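To make the notation discussed above concrete, here is a small standalone reconstruction of the string building for three inputs concatenated along dim 0, where the last input is 1-d (the empty-tensor case from the thread). The shapes and names are illustrative only; the real rule derives them from the `DTensorSpec`s:

```python
# Illustrative only: rebuild the einop equation the way cat_rule does,
# for 3 inputs where the last one has fewer dims than the result.
alphabet = "abcdefghijklmnopqrstuvwxyz"
ndim, dim, num_inputs = 3, 0, 3
input_ndims = [3, 3, 1]  # last input is 1-d, e.g. an empty tensor

free_dim = alphabet[num_inputs : num_inputs + ndim - 1]  # "de"
notation = []
for i, nd in enumerate(input_ndims):
    if nd == ndim:
        notation.append(free_dim[:dim] + alphabet[i] + free_dim[dim:])
    else:
        notation.append(alphabet[i])  # single-char annotation

out_word = free_dim[:dim] + alphabet[0] + free_dim[dim:]
equation = f"{','.join(notation)}->{out_word}"
assert equation == "ade,bde,c->ade"
```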
+    cat_dim_char = alphabet[0]
+    dim_word = free_dim[:dim] + cat_dim_char + free_dim[dim:]
+    einop_equation = f"{','.join(einop_notation_list)}->{dim_word}"
+    output_sharding = einop_rule(
+        einop_equation,
+        OpSchema(
+            func_schema=op_schema.func_schema,
+            args_schema=tuple(tensor_list_specs),
+            kwargs_schema={},
+        ),
+        linearity=False,
+    )
+
+    if (output_sharding.output_spec is not None) and need_reshard:
+        output_sharding.output_spec = None
+        output_sharding.schema_suggestions = [
+            OpSchema(
+                func_schema=op_schema.func_schema,
+                args_schema=tuple(tensor_list_specs),
+                kwargs_schema={},
+            ),
+        ]
+
+    if output_sharding.output_spec is None:
+        if output_sharding.schema_suggestions is not None:
+            # Convert args_schema from a tuple of DTensorSpec into a list
+            return _update_schema_suggestion_for_cat(
+                output_sharding,
+                op_schema,
+            )
+        else:
+            return output_sharding
+
+    # change output shape
+    new_size = 0
+    for spec in tensor_list_specs:
+        if dim < spec.ndim:
+            new_size += spec.shape[dim]
+    assert isinstance(output_sharding.output_spec, DTensorSpec)
+    output_sharding.output_spec.shape = torch.Size(
+        tuple(output_sharding.output_spec.shape[:dim])
+        + (new_size,)
+        + tuple(output_sharding.output_spec.shape[dim + 1 :])
+    )
+    return output_sharding
+
+
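For readers following the reshard path above: when any input is sharded along the cat dim, the rule rewrites those specs via `unshard_tensor_dim` (defined earlier in this file) and returns them as a schema suggestion instead of an output spec. A rough sketch of the expected placement rewrite, using stand-in placement classes rather than the real `Shard`/`Replicate` types:

```python
# Stand-ins purely to illustrate the expected behavior of
# unshard_tensor_dim; the real classes live in
# torch.distributed._tensor.placement_types.
class FakeShard:
    def __init__(self, dim: int) -> None:
        self.dim = dim


class FakeReplicate:
    pass


def unshard_tensor_dim_sketch(placements, dim):
    # Replace Shard(dim) with Replicate so no input stays sharded
    # along the concat dimension.
    return tuple(
        FakeReplicate() if isinstance(p, FakeShard) and p.dim == dim else p
        for p in placements
    )


out = unshard_tensor_dim_sketch((FakeShard(0), FakeReplicate()), dim=0)
assert isinstance(out[0], FakeReplicate)
```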
+def _update_schema_suggestion_for_cat(
Collaborator:
Can you tell me what exactly this function is doing? It looks like a lot of duplicated logic with the rule itself, and I am not quite sure what this function is used for.
+    output_sharding: OutputSharding,
+    op_schema: OpSchema,
+) -> OutputSharding:
+    assert output_sharding.schema_suggestions is not None
+    suggestion_specs = output_sharding.schema_suggestions[0].args_spec
+
+    args_schema = (suggestion_specs,) + op_schema.args_schema[1:]
+
+    output_sharding.schema_suggestions = [
+        OpSchema(
+            func_schema=op_schema.func_schema,
+            args_schema=args_schema,
+            kwargs_schema=op_schema.kwargs_schema,
+        )
+    ]
+    return output_sharding
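Finally, a hypothetical end-to-end use this rule enables. It assumes a 4-rank process group already initialized (e.g. via `torchrun`) and the private `torch.distributed._tensor` API as of this PR, so treat it as a sketch rather than a verified test:

```python
import torch
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

# Assumes 4 ranks and an initialized process group.
mesh = DeviceMesh("cpu", list(range(4)))
a = distribute_tensor(torch.randn(8, 4), mesh, [Shard(1)])
b = distribute_tensor(torch.randn(8, 4), mesh, [Shard(1)])

# Inputs are sharded on dim 1, not the cat dim, so cat_rule can propagate
# Shard(1) to the output without emitting a reshard suggestion.
c = torch.cat([a, b], dim=0)
assert c.shape == torch.Size([16, 4])  # global shape of the DTensor result
```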