[DTensor] implement dist_split as a sharding prop rule #93306
Changes from all commits: 56c4cea, e5b8af3, 61c676a, e1cc5d5
The rule added by this PR (diff hunk @@ -598,3 +598,81 @@, immediately following `_update_schema_suggestion_for_cat`):
```python
@register_prop_rule(aten.split.Tensor)
def split_rule(op_schema: OpSchema) -> OutputSharding:
    output_spec_list: List[DTensorSpec] = []
    input_spec = cast(DTensorSpec, op_schema.args_schema[0])
    ndim = input_spec.ndim
    split_size_or_sections = op_schema.args_schema[1]
    dim = (
        cast(int, op_schema.args_schema[2])
        if len(op_schema.args_schema) > 2
        else 0
    )
    dim = normalize_dim(dim, ndim)

    # TODO: tensor to split cannot have _Partial
    # in its placements for now. Will need to
    # support in future.
    if input_spec.sums:
        raise NotImplementedError(
            f"splitting distributed tensor with "
            f"_Partial placement is not implemented!\n"
            f"DTensorSpec={input_spec}"
        )

    # TODO: just like slice op, split replicates before
    # splitting on a sharded dimension
    need_reshard = False
    if is_tensor_dim_sharded(input_spec, dim=dim):
        need_reshard = True
        input_spec = DTensorSpec(
            mesh=input_spec.mesh,
            placements=unshard_tensor_dim(input_spec.placements, dim=dim),
            shape=input_spec.shape,
            ndim=input_spec.ndim,
        )

    if need_reshard:
        return OutputSharding(
            None,
            schema_suggestions=[
                OpSchema(
                    func_schema=op_schema.func_schema,
                    args_schema=(input_spec,) + op_schema.args_schema[1:],
                    kwargs_schema=op_schema.kwargs_schema,
                ),
            ]
        )

    def size_split(N, i):
        # Last chunk will be smaller if the tensor size N
        # along the given dimension dim is not divisible by i.
        assert i > 0
        return [i] * (N // i) + ([N % i] if N % i != 0 else [])

    output_size_list = (
        size_split(input_spec.shape[dim], split_size_or_sections)
        if isinstance(split_size_or_sections, int)
        else split_size_or_sections
    )
    output_shape_list = [
        torch.Size(
            tuple(input_spec.shape[:dim])
            + (size,)
            + tuple(input_spec.shape[dim + 1 :])
        )
        for size in output_size_list
    ]
    output_spec_list = [
        DTensorSpec(
            mesh=input_spec.mesh,
            placements=input_spec.placements,
            shape=shape,
            ndim=input_spec.ndim,
        )
        for shape in output_shape_list
    ]
    return OutputSharding(output_spec_list)
```

Review thread on this diff:

Collaborator: nit: add a check to partial input_spec and raise NotImplementedError so we know to implement this later?

Contributor (Author): sounds good!
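For reference, here is a minimal standalone sketch (not part of the PR) of how the chunking logic above derives the per-output shapes: it copies the size_split helper, uses a hypothetical 8x10 global shape, and checks the result against what torch.split produces in eager mode.

```python
import torch


# Standalone copy of the helper from the rule above: chunk sizes for
# splitting a dimension of length N into pieces of size i; the last piece
# is smaller when N is not evenly divisible by i, mirroring torch.split.
def size_split(N, i):
    assert i > 0
    return [i] * (N // i) + ([N % i] if N % i != 0 else [])


shape = torch.Size([8, 10])   # hypothetical global shape of the input DTensor
dim, split_size = 1, 3

output_size_list = size_split(shape[dim], split_size)
output_shape_list = [
    torch.Size(tuple(shape[:dim]) + (size,) + tuple(shape[dim + 1:]))
    for size in output_size_list
]
print(output_size_list)    # [3, 3, 3, 1]
print(output_shape_list)   # [torch.Size([8, 3]), ..., torch.Size([8, 1])]

# Sanity check against the eager op whose output shardings this rule describes.
assert [t.shape for t in torch.split(torch.empty(shape), split_size, dim)] \
    == output_shape_list
```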
Further review comments:

Comment: This somehow broke TP's code logic. A common technique is to keep the DTensor sharded on the last dim and then call split on that same last dim; we still want the result to be sharded on dim=-1.

Comment: Currently, after split we get a replicated DTensor.
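Below is a minimal sketch of the scenario these comments describe, assuming a two-rank mesh and the torch.distributed._tensor API of this era (DeviceMesh, Shard, distribute_tensor). It is illustrative only and not code from the PR; calling torch.split on a DTensor dispatches to aten.split.Tensor and therefore exercises the rule above.

```python
import torch
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

# Assumes torch.distributed is already initialized with 2 ranks
# (e.g. launched via torchrun --nproc_per_node=2).
mesh = DeviceMesh("cuda", [0, 1])
full = torch.randn(4, 8)

# Tensor-parallel code commonly keeps activations sharded on the last dim...
dt = distribute_tensor(full, mesh, [Shard(1)])

# ...and then splits along that same, already-sharded last dim.
chunks = torch.split(dt, 4, dim=-1)

# Desired: each chunk keeps Shard(1). Per the comments above, this rule
# replicates the input before splitting, so each chunk comes back Replicate.
for c in chunks:
    print(c.placements)
```

If the reported behavior holds, each print shows (Replicate(),) rather than the expected (Shard(1),).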