
Conversation

@wconstab (Contributor) commented Jul 14, 2025

Stack from ghstack (oldest at bottom):

The local_tensor input to grouped_mm has a stride requirement.

(See `_meta_grouped_mm_common` in meta_registrations.py or
`check_valid_strides_and_return_transposed` in native/cuda/Blas.cpp.)

Don't allow sharding a tensor if its shape would result in an
incompatible local_tensor stride.

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @d4l3k
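
For context, here is a rough sketch of the kind of layout check involved. It is paraphrased from memory rather than copied from `_meta_grouped_mm_common` / `check_valid_strides_and_return_transposed`, so the helper name, exact bounds, and alignment handling are illustrative:

```python
import torch

def grouped_mm_strides_ok(
    shape: tuple[int, ...], stride: tuple[int, ...], dtype: torch.dtype
) -> bool:
    # Roughly: one of the last two dims must be contiguous (stride 1), and the
    # stride of the other one must be a multiple of 16 bytes / element_size.
    end = len(shape) - 1
    alignment = 16 // torch.empty(0, dtype=dtype).element_size()
    if stride[end] == 1 and stride[end - 1] >= max(1, shape[end]):
        return stride[end - 1] % alignment == 0  # row-major-like layout
    if stride[end - 1] == 1 and stride[end] >= max(1, shape[end - 1]):
        return stride[end] % alignment == 0  # column-major-like layout
    return False

# With bf16 (alignment = 8 elements), a contiguous local shard whose last dim
# ends up with size 10 has strides (10, 1) and fails the alignment check.
print(grouped_mm_strides_ok((64, 16), (16, 1), torch.bfloat16))  # True
print(grouped_mm_strides_ok((64, 10), (10, 1), torch.bfloat16))  # False
```

A sharding whose local shape produces strides that fail this kind of check is what this PR filters out when enumerating strategies.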

pytorch-bot bot commented Jul 14, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/158245

Note: Links to docs will display an error until the docs builds have been completed.

⏳ 1 Pending, 1 Unrelated Failure

As of commit d56e755 with merge base 0d17029:

UNSTABLE - One job is marked as unstable, possibly due to flakiness on trunk.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/inductor oncall: distributed Add this issue/PR to distributed oncall triage queue labels Jul 14, 2025
wconstab added a commit that referenced this pull request Jul 14, 2025
ghstack-source-id: 8e05bf9
Pull Request resolved: #158245
@pytorch-bot pytorch-bot bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Jul 14, 2025
@wconstab wconstab added the release notes: distributed (dtensor) release notes category label Jul 14, 2025
@XilunWu (Contributor) left a comment

LGTM, but I have some suggestions and a few points of confusion to clarify.

meta.shape, mesh, placements
)
return local_shape, local_stride, meta.dtype

Contributor:

local_shape_stride -> compute_local_tensor_meta? And return TensorMeta instead of a Tuple.

Contributor Author (wconstab):

fixed

)

def valid_grouped_mm_strides(
input_specs: list[DTensorSpec], output_specs: tuple[Optional[DTensorSpec], ...]
Contributor:

Just a style preference: have mat_a_spec and mat_b_spec instead of the input_specs list, which is clearer to me. WDYT?

Contributor Author (wconstab):

No can do. This function has to have a generic signature that is not specific to grouped_mm, because its signature is defined by the API of expand_to_full_mesh_op_strategy, which is a generic util that can be used by any op.

dtype: torch.dtype,
new_local_stride: tuple[int, ...],
) -> bool:
# copied from `_meta_grouped_mm_common` in meta_registrations.py
Contributor:

Any reason we prefer replicating the function over calling it?

Contributor Author (wconstab):

I can't literally call the checker because it's buried inside a meta-fn, but perhaps I should just refactor that one so I can.

I also considered a more direct approach: create actual meta tensors for the local tensors, then call the grouped_mm meta function with them. I could do this under a try/except, and if it throws any error I'd call it an invalid sharding. This seems better in a way, so I might try that and see if it is easy to do.
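
For illustration, a minimal sketch of that alternative, under the assumption that the grouped-mm op is exposed as `torch._grouped_mm` with an `offs` argument (names and signature written from memory and may differ by PyTorch version; this is not the PR's code):

```python
import torch

def sharding_accepted_by_grouped_mm(a_shape, a_stride, b_shape, b_stride, dtype, num_groups):
    # Build meta tensors with the candidate local shapes/strides and let the
    # grouped_mm meta kernel do the validation; any error is treated as an
    # invalid sharding (bad strides, shapes, or dtype).
    a = torch.empty_strided(a_shape, a_stride, dtype=dtype, device="meta")
    b = torch.empty_strided(b_shape, b_stride, dtype=dtype, device="meta")
    offs = torch.empty(num_groups, dtype=torch.int32, device="meta")
    try:
        torch._grouped_mm(a, b, offs=offs)
        return True
    except Exception:
        return False
```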

Contributor:

> create actual meta tensors for the local tensors, then call the grouped_mm meta with them

It could work. But I'm also okay with landing this approach first and then revising if that works out.

Comment on lines +244 to +246
is_valid_strategy_cb: Optional[
Callable[[list[DTensorSpec], tuple[Optional[DTensorSpec], ...]], bool]
] = None,
Contributor:

I don't like that we restrict the callable type to be list[DTensorSpec], tuple[...] -> bool. Do we have to specify the argument types and return type? Can't we just use the type hint Optional[Callable]? If the type checker complains, I'd prefer something like list[Any/object] for the argument types.

Contributor Author (wconstab):

Uh.. I think this is a perfect example of a case where it is important to restrict the types of the function. Why? Because more than one user can write a different callback function, but they must all share the same type signature, so it should be well defined as an API.

The idea here is to define an API that gives enough information to let the per-operator callback make its choices. If you think more info is needed, then I think we should add it to the type signature explicitly. If we find out later that more info is needed, we can add it to all the callbacks that have been implemented by that time.
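
As an illustration of that contract, a callback with the `list[DTensorSpec], tuple[Optional[DTensorSpec], ...] -> bool` signature could look like the sketch below. The rejection rule is a made-up placeholder (not the grouped_mm check), and the private `DTensorSpec` import path may vary between PyTorch versions:

```python
from typing import Optional

from torch.distributed.tensor import Shard
from torch.distributed.tensor._dtensor_spec import DTensorSpec  # private; path may vary

def reject_last_dim_sharding(
    input_specs: list[DTensorSpec],
    output_specs: tuple[Optional[DTensorSpec], ...],
) -> bool:
    # Placeholder rule: drop any strategy that shards the last dim of an input.
    for spec in input_specs:
        if spec.tensor_meta is None:
            continue  # no shape/stride info available to check against
        last_dim = len(spec.tensor_meta.shape) - 1
        if any(isinstance(p, Shard) and p.dim == last_dim for p in spec.placements):
            return False
    return True
```

An op strategy would then pass this as `is_valid_strategy_cb=reject_last_dim_sharding` to `expand_to_full_mesh_op_strategy`.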

spec_list: list[Optional[DTensorSpec]] = []
for specs in zip(*strategy_comb):
    if specs[0] is not None:
        # TODO: we should fill in tensor_meta here. If nothing else, it helps the filter strategy callback
Contributor:

cc @zpcore

Comment on lines +284 to +288
else:
    if spec_list[0] is not None:
        output_specs = spec_list[0]  # type: ignore[assignment]
    else:
        raise RuntimeError("output spec is None")
Contributor:

Can/should input_index < 1 be legitimate? IMO it means that there's no output, which contradicts the `if spec_list[0] is not None` branch, where the first spec is treated as output_spec.

Contributor Author (wconstab):

FYI I just moved this logic from below, so it is pre-existing logic you are complaining about.

That said, I agree it's not well written:

1. `input_index > 1` is checked above, but `== 1` is not. That means the else branch has to handle both `== 0` and `== 1` (and perhaps also `< 1`, which would be bad). It could be written more explicitly, as sketched below.
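
For concreteness, a more explicit version of that branch might look like the following fragment (illustrative only, not code from this PR):

```python
if input_index > 1:
    # multiple outputs: the first `input_index` specs are the output specs
    output_specs = tuple(spec_list[:input_index])
elif input_index == 1:
    # single output: the first spec is the output spec and must be present
    if spec_list[0] is None:
        raise RuntimeError("output spec is None")
    output_specs = spec_list[0]
else:
    # input_index < 1 would mean "no outputs", which is not supported
    raise RuntimeError(f"unexpected input_index: {input_index}")
```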

Contributor:

I see, thanks for clarifying!

wconstab added a commit that referenced this pull request Jul 14, 2025
ghstack-source-id: 14ca70f
Pull Request resolved: #158245
# check inputs shardable
inputs_shardable = all(
output_specs: tuple[Optional[DTensorSpec], ...]
if input_index > 1:
Member:

Not related to this PR, but I asked AI to figure out what input_index is used for :) A comment would be helpful.

Contributor Author (wconstab):

Ask and you shall receive.
(Updated the docstring for expand_to_full_mesh_op_strategy to cover this.)
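
The updated docstring itself isn't quoted in this thread. As I understand it, each strategy combination lists output placements first and input placements after, and input_index marks where the inputs start (i.e. the number of outputs). A toy illustration under that assumption:

```python
# Placeholder strings stand in for DTensorSpecs in one strategy combination.
spec_list = ["out_spec", "mat_a_spec", "mat_b_spec"]
input_index = 1  # single-output op, so inputs start at index 1

output_specs = spec_list[:input_index]  # specs describing the op's output(s)
input_specs = spec_list[input_index:]   # specs describing the op's inputs
print(output_specs, input_specs)  # ['out_spec'] ['mat_a_spec', 'mat_b_spec']
```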

)

def valid_grouped_mm_strides(
input_specs: list[DTensorSpec], output_specs: tuple[Optional[DTensorSpec], ...]
Member:

output_specs is not used.

Member:

I see, you are trying to match the is_valid_strategy_cb pattern, though it is not used here.

Contributor Author (wconstab):

Correct, I added it to the signature because I thought it might be useful for some other ops, even though I did not use it for this op.

@zpcore (Member) left a comment

LGTM!

wconstab added a commit that referenced this pull request Jul 14, 2025
ghstack-source-id: 468444c
Pull Request resolved: #158245
@XilunWu (Contributor) left a comment

Thanks for addressing my question, the PR looks good to me.

@wconstab (Contributor Author) commented:

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 14, 2025
@pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@github-actions github-actions bot deleted the gh/wconstab/425/head branch August 15, 2025 02:20
Labels

ciflow/inductor, ciflow/trunk, Merged, oncall: distributed, release notes: distributed (dtensor)
