[DTensor] Fix grouped_mm strategy for invalid stride cases #158245
Conversation
The local_tensor input to grouped_mm has a stride requirement (see `_meta_grouped_mm_common` in meta_registrations.py or `check_valid_strides_and_return_transposed` in native/cuda/Blas.cpp). Don't allow sharding a tensor if its shape would result in an incompatible local_tensor stride. [ghstack-poisoned]
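For reference, a minimal sketch of the kind of per-operand check this is about. The alignment constant and helper name are assumptions for illustration; the authoritative logic lives in the two functions cited above.

```python
import torch

def sketch_valid_grouped_mm_strides(
    stride: tuple[int, ...], dtype: torch.dtype
) -> bool:
    # Simplified rule (assumed): the operand must be contiguous in one of its
    # last two dims, and the stride of the other of those two dims must be
    # 16-byte aligned.
    align = 16 // dtype.itemsize
    row_major_ok = stride[-1] == 1 and stride[-2] % align == 0
    col_major_ok = stride[-2] == 1 and stride[-1] % align == 0
    return row_major_ok or col_major_ok

# A bf16 local tensor of shape (8, 20) laid out contiguously has stride (20, 1);
# 20 elements is not a multiple of 8 (16 bytes / 2 bytes per element), so it is rejected.
print(sketch_valid_grouped_mm_strides((20, 1), torch.bfloat16))  # False
```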
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158245.
Note: Links to docs will display an error until the docs builds have been completed.
⏳ 1 Pending, 1 Unrelated Failure. As of commit d56e755 with merge base 0d17029. UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
LGTM, but I have some suggestions and a few points of confusion to clarify.
        meta.shape, mesh, placements
    )
    return local_shape, local_stride, meta.dtype
Rename local_shape_stride -> compute_local_tensor_meta? And return a TensorMeta instead of a tuple.
fixed
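For readers following along, a toy sketch of what a helper with the suggested name and TensorMeta return type could look like. This is illustrative, not the PR's actual code: it defines its own TensorMeta namedtuple and handles only even sharding along a single dim, whereas the real helper works from (mesh, placements).

```python
from typing import NamedTuple
import torch

class TensorMeta(NamedTuple):
    # mirrors the fields DTensor's TensorMeta carries
    shape: torch.Size
    stride: tuple[int, ...]
    dtype: torch.dtype

def compute_local_tensor_meta(
    meta: TensorMeta, shard_dim: int, num_shards: int
) -> TensorMeta:
    # Toy version: even sharding along one dim only; no replication,
    # uneven shards, or multiple mesh dims.
    local_shape = list(meta.shape)
    local_shape[shard_dim] //= num_shards
    # Shrinking a dim also shrinks every stride larger than that dim's stride,
    # since fewer elements now sit behind each step of those dims.
    local_stride = tuple(
        s // num_shards if s > meta.stride[shard_dim] else s for s in meta.stride
    )
    return TensorMeta(torch.Size(local_shape), local_stride, meta.dtype)

# (64, 128) bf16, contiguous, sharded 4 ways on dim 1 -> shape (64, 32), stride (32, 1)
print(compute_local_tensor_meta(TensorMeta(torch.Size([64, 128]), (128, 1), torch.bfloat16), 1, 4))
```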
)

def valid_grouped_mm_strides(
    input_specs: list[DTensorSpec], output_specs: tuple[Optional[DTensorSpec], ...]
Just a style preference: have mat_a_spec and mat_b_spec instead of the input_specs list, which is clearer to me. WDYT?
No can do. This function has to have a generic signature that is not specific to grouped_mm. That's because this function's signature is defined by the API of expand_to_full_mesh_op_strategy, which is a generic util that can be used by any op.
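To illustrate, a sketch (not the exact PR code; the DTensorSpec import path is assumed): the callback keeps the generic (input_specs, output_specs) signature required by expand_to_full_mesh_op_strategy, while the grouped_mm-specific body can still name its operands on the first line.

```python
from typing import Optional

from torch.distributed.tensor._dtensor_spec import DTensorSpec

def valid_grouped_mm_strides(
    input_specs: list[DTensorSpec],
    output_specs: tuple[Optional[DTensorSpec], ...],
) -> bool:
    # Signature dictated by expand_to_full_mesh_op_strategy's callback API,
    # but the grouped_mm-specific body can still name its operands.
    mat_a_spec, mat_b_spec = input_specs[0], input_specs[1]
    # ... per-operand stride checks would go here ...
    return True  # placeholder
```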
    dtype: torch.dtype,
    new_local_stride: tuple[int, ...],
) -> bool:
    # copied from `_meta_grouped_mm_common` in meta_registrations.py
any reason we prefer replicating the function over calling it?
I can't literally call the checker because it's buried inside a meta-fn, but perhaps I should just refactor that one so I can.
I also considered a more direct approach: create actual meta tensors for the local tensors, then call the grouped_mm meta with them. I could do this under a try/except, and if it throws any error I'd call it an invalid sharding. This seems better in a way, so I might try that and see if it is easy to do.
create actual meta tensors for the local tensors, then call the grouped_mm meta with them
It could work. But I'm also okay with landing this approach first and revising later if that works out.
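A rough sketch of that alternative, assuming we can recover (shape, stride, dtype) for each operand's local tensor; `sharding_accepted_by_meta` and its arguments are hypothetical names, not PR code.

```python
import torch

def sharding_accepted_by_meta(meta_op, local_metas) -> bool:
    # meta_op: the op to probe (its meta kernel does the real validation)
    # local_metas: iterable of (shape, stride, dtype), one per input tensor
    try:
        args = [
            torch.empty_strided(shape, stride, dtype=dtype, device="meta")
            for shape, stride, dtype in local_metas
        ]
        meta_op(*args)
        return True
    except Exception:
        # any error from the meta kernel (strides, alignment, shapes) means
        # this sharding combination would produce invalid local tensors
        return False
```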
    is_valid_strategy_cb: Optional[
        Callable[[list[DTensorSpec], tuple[Optional[DTensorSpec], ...]], bool]
    ] = None,
I don't like that we restrict the callable type to be list[DTensorSpec], tuple[...] -> bool. Do we have to specify the argument types and return type? Can't we just use the type hint Optional[Callable]? If the type checker complains, I'd prefer something like list[Any/object] for the argument types.
Uh.. I think this is a perfect example of a case where it is important to restrict the function's types. Why? Because more than one user can write a different callback function, but they must all share the same type signature. So it should be well defined, as an API.
The idea here is to define an API that gives enough information to let the per-operator callback make its choices. If you think more info is needed, then I think we should add it to the type signature explicitly. If we find out later that more info is needed, we can add it to all the callbacks that have been implemented by that time.
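Concretely, that shared contract could be spelled once as a type alias so every op-specific callback is checked against the same signature. A sketch only: the alias name is made up and the DTensorSpec import path is assumed.

```python
from typing import Callable, Optional

from torch.distributed.tensor._dtensor_spec import DTensorSpec

# One shared contract for every per-op filter callback: given a candidate
# strategy's input specs and output specs, decide whether it is valid.
StrategyFilterCb = Callable[
    [list[DTensorSpec], tuple[Optional[DTensorSpec], ...]], bool
]

def accept_everything(
    input_specs: list[DTensorSpec],
    output_specs: tuple[Optional[DTensorSpec], ...],
) -> bool:
    # trivial callback that conforms to the contract
    return True

cb: Optional[StrategyFilterCb] = accept_everything  # type-checks against the alias
```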
spec_list: list[Optional[DTensorSpec]] = []
for specs in zip(*strategy_comb):
    if specs[0] is not None:
        # TODO: we should fill in tensor_meta here. If nothing else, it helps the filter strategy callback
cc @zpcore
else:
    if spec_list[0] is not None:
        output_specs = spec_list[0]  # type: ignore[assignment]
    else:
        raise RuntimeError("output spec is None")
Can/should input_index < 1 be legitimate? IMO it means there is no output, which contradicts the `if spec_list[0] is not None` branch, where the first spec is treated as the output spec.
FYI, I just moved this logic from below, so the logic you are complaining about is pre-existing.
That said, I agree it's not well written:
- input_index > 1 is checked above, but == 1 is not checked explicitly. That means the else branch has to cover both == 0 and == 1 (and perhaps also < 1, which would be bad). It could be written more explicitly; see the sketch below.
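One possible, more explicit arrangement of the quoted branch. This is a sketch of the suggestion, not PR code, and it assumes the multi-output branch takes the first `input_index` specs as outputs, as the surrounding diff suggests.

```python
output_specs: tuple[Optional[DTensorSpec], ...]
if input_index > 1:
    # multiple outputs: the first `input_index` specs are the outputs
    output_specs = tuple(spec_list[:input_index])
elif input_index == 1:
    # single output: it must be present
    if spec_list[0] is None:
        raise RuntimeError("output spec is None")
    output_specs = spec_list[0]  # type: ignore[assignment]
else:
    raise RuntimeError(f"unexpected input_index={input_index}, expected >= 1")
```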
I see, thanks for clarifying!
# check inputs shardable
inputs_shardable = all(
output_specs: tuple[Optional[DTensorSpec], ...]
if input_index > 1:
Not related to this PR, but I asked AI to figure out what input_index is used for :) A comment would be helpful.
Ask and you shall receive.
(Updated the docstring for expand_to_full_mesh_op_strategy to cover this.)
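For readers of this thread, the gist of that clarification might look something like the following. This is a paraphrased sketch with an abbreviated parameter list, not the exact docstring added in the PR.

```python
def expand_to_full_mesh_op_strategy(
    mesh,
    op_schema,
    single_mesh_dim_strategies,
    *,
    input_index: int = 1,
    inplace_op: bool = False,
):
    """Expand single-mesh-dim placement combinations to the full mesh.

    `input_index` is the number of output specs at the front of each
    placement list in `single_mesh_dim_strategies`: the first `input_index`
    entries describe the op's outputs, and the remaining entries describe
    its inputs. Most ops have a single output, hence the default of 1.
    """
    ...
```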
)

def valid_grouped_mm_strides(
    input_specs: list[DTensorSpec], output_specs: tuple[Optional[DTensorSpec], ...]
output_specs is not used.
I see, you are trying to match the is_valid_strategy_cb pattern, though it is not used here.
Correct, I added it to the signature because I thought it might be useful for some other ops, even though I did not use it for this op.
LGTM!
Thanks for addressing my question, the PR looks good to me.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Stack from ghstack (oldest at bottom):
local_tensor input to grouped_mm has a stride requirement (see `_meta_grouped_mm_common` in meta_registrations.py or `check_valid_strides_and_return_transposed` in native/cuda/Blas.cpp).
Don't allow sharding a tensor if its shape would result in an incompatible local_tensor stride.
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @d4l3k