DTensor: add more foreach ops to supported sharding prop list #132066
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/132066
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (8 unrelated failures) As of commit 5036dd9 with merge base a356a03. BROKEN TRUNK: the following jobs failed but were already present on the merge base. 👉 Rebase onto the `viable/strict` branch to avoid these failures.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
else:
    kwargs_schema[k] = v
    local_kwargs[k] = v
before landing this, I probably need to:
(1) try adding a sharding prop rule for aten._foreach_mul
(2) see if that fixes the repro and add tests for it
I think this partially fixes #132016. It doesn't fully fix it, though, because: (1) we don't have a sharding prop rule for `aten._foreach_mul_`, and (2) when we don't have a sharding prop rule, we assume that the op's schema has a flat list of inputs that doesn't require any nested unflattening.

cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o
    aten._foreach_addcmul_.Tensor,
    aten._foreach_clamp_max_.Scalar,
    aten._foreach_clamp_min_.Scalar,
    aten._foreach_div_.List,
Should we also add `aten._foreach_div_.Scalar` and `aten._foreach_div.Scalar` in this PR?
yep good call
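(For context, the change boils down to extending the foreach overload table that DTensor's pointwise-op module iterates over when registering sharding strategies. Below is a rough sketch of that kind of table; the variable name and the exact set of overloads are illustrative rather than a copy of `_pointwise_ops.py`.)

```python
import torch

aten = torch.ops.aten

# Illustrative table of foreach overloads of the kind this PR adds; DTensor
# keeps similar lists in _pointwise_ops.py and registers a pointwise-style
# sharding strategy for each entry.
extra_foreach_overloads = [
    aten._foreach_addcmul_.Tensor,
    aten._foreach_clamp_max_.Scalar,
    aten._foreach_clamp_min_.Scalar,
    aten._foreach_div_.List,
    aten._foreach_div_.Scalar,
    aten._foreach_div.Scalar,
    aten._foreach_mul_.List,
    aten._foreach_mul_.Scalar,
]
```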
args_schema.append(arg)
local_args.append(arg)

tree_map(arg_to_spec, args)
hmm, trying to understand this more: we do have pytree to flatten the args input for certain ops, i.e. the foreach op list has pytree flattening enabled https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/_pointwise_ops.py#L641
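(As a rough standalone illustration, not DTensor's actual dispatch code, this is what the pytree flattening buys for foreach-style args: `tree_map` recurses into the nested lists of tensors instead of treating each list as an opaque argument.)

```python
import torch
from torch.utils._pytree import tree_map

def arg_to_spec(arg):
    # Toy stand-in for DTensor's real arg_to_spec: just report tensor shapes.
    return tuple(arg.shape) if isinstance(arg, torch.Tensor) else arg

# foreach-style args: two lists of tensors plus a scalar
args = (
    [torch.randn(2, 3), torch.randn(4)],
    [torch.randn(2, 3), torch.randn(4)],
    2.0,
)

# tree_map recurses into the lists, so each tensor element gets its own "spec"
print(tree_map(arg_to_spec, args))
# -> ([(2, 3), (4,)], [(2, 3), (4,)], 2.0)
```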
oh right, is the idea that DTensor wants to manually specify which ops actually need the pytree machinery, so their usage is more limited (to avoid the perf impact?)
In that case, we can probably tweak this code to only do the tree_map here for ops that have opted into "requiring pytrees" (probably by checking the flag that you just linked)
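(A minimal sketch of that idea; the dataclass and flag names below are meant to mirror DTensor's pytree opt-in flag, but they are standalone stand-ins rather than the real dispatch code.)

```python
from dataclasses import dataclass
from torch.utils._pytree import tree_map

@dataclass
class RuntimeSchemaInfo:
    # Mirrors the idea of an op opting into pytree handling (e.g. foreach ops).
    needs_pytree: bool = False

def args_to_specs(schema_info, arg_to_spec, args):
    # Only pay the pytree-flattening cost for ops that opted in; treat every
    # other op's args as a flat tuple.
    if schema_info is not None and schema_info.needs_pytree:
        return tree_map(arg_to_spec, args)
    return tuple(arg_to_spec(a) for a in args)
```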
> oh right, is the idea that DTensor wants to manually specify which ops actually need the pytree machinery, so their usage is more limited (to avoid the perf impact?)

Yep!

> In that case, we can probably tweak this code to only do the tree_map here for ops that have opted into "requiring pytrees" (probably by checking the flag that you just linked)
We're already sorta doing that a few lines above: https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/_dispatch.py#L291-L294
So I am wondering if there are some additional bugs that triggered that issue.
oh you are totally right... let me take another look
ah yep you're right. So:
(1) this logic does the right thing of operating on the flattened args, as long as we detect that the op's sharding rule has opted into using pytrees
(2) the only bug is that DTensor was missing sharding rules for a few foreach ops, causing DTensor to use the "default" path of not using pytrees, in a "bad" way that causes DTensor to loop infinitely (I think, before it gets a chance to error about not having a proper sharding rule).
hmm @wanchaol - is there any easy way to know ahead of time that an op has no sharding prop rule and thus will fail, so we can check it here? https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/_dispatch.py#L291
One option is maybe to detect when there's no sharding prop rule and error earlier, before we try to generate FakeTensor arguments.
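(A hedged sketch of that fail-fast idea; the registry attribute names below are assumptions about the sharding propagator's internals, included only for illustration.)

```python
def assert_has_sharding_rule(op_overload, sharding_propagator):
    # Check the (assumed) strategy/rule registries before building fake-tensor
    # args, so a missing rule surfaces as one clear error instead of whatever
    # the fallback path happens to do.
    registered = op_overload in getattr(
        sharding_propagator, "op_strategy_funcs", {}
    ) or op_overload in getattr(sharding_propagator, "op_to_rules", {})
    if not registered:
        raise NotImplementedError(
            f"DTensor has no sharding propagation rule registered for {op_overload}"
        )
```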
(for now, I just "fixed" the issue by updating the foreach_ops list that DTensor relies on)
> ah yep you're right. So:
> (1) this logic does the right thing of operating on the flattened args, as long as we detect that the op's sharding rule has opted into using pytrees
> (2) the only bug is that DTensor was missing sharding rules for a few foreach ops, causing DTensor to use the "default" path of not using pytrees, in a "bad" way that causes DTensor to loop infinitely (I think, before it gets a chance to error about not having a proper sharding rule).
> hmm @wanchaol - is there any easy way to know ahead of time that an op has no sharding prop rule and thus will fail, so we can check it here? https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/_dispatch.py#L291
> One option is maybe to detect when there's no sharding prop rule and error earlier, before we try to generate FakeTensor arguments.
For (2), I think the issue is that `_propagate_tensor_meta(op_schema)` (https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/_sharding_prop.py#L198) gets called first and errors out, so we could see all sorts of different errors (here are some more examples where we don't see the infinite loop but some other error: #124990) even though the true reason is that no sharding prop rule has been registered for the op.
I think this partially fixes #132016. It doesn't fully fix it, though, because: (1) we don't have a sharding prop rule for `aten._foreach_mul_`, and (2) when we don't have a sharding prop rule, we assume that the op's schema has a flat list of inputs that doesn't require any nested unflattening.

UPDATE: the linked repro passes now that I added the missing foreach overloads, so DTensor registers proper sharding rules for them.

cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o
wanchaol left a comment:
lgtm thanks!
Please change the title and summary to reflect the newest changes :)
@pytorchbot merge
Merge failed. Reason: This PR needs a label; to add one, you can comment to pytorchbot. (Details for Dev Infra team: raised by workflow job.)
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Fixes #132016.
Right now, if you run an op for which DTensor has no sharding prop rule, and that op accepts non-trivial pytrees of input tensors as arguments, DTensor can end up in an infinite loop before it has the chance to error out about the missing sharding prop rule.
This PR doesn't fix that general problem, but it adds rules for the culprit ops (the missing foreach ops).
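For reference, here is a single-rank sketch of the kind of case that hits this path (assumes a world-size-1 gloo setup; this is illustrative, not the exact repro from #132016):

```python
import os
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

# single-process "distributed" setup so we can build a DeviceMesh
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = DeviceMesh("cpu", [0])
self_tensors = [distribute_tensor(torch.randn(8, 4), mesh, [Shard(0)]) for _ in range(3)]
other_tensors = [distribute_tensor(torch.randn(8, 4), mesh, [Shard(0)]) for _ in range(3)]

# Foreach ops take nested lists of tensors as args; without a sharding prop
# rule for the overload, DTensor falls back to the flat-args path described
# above. With the rule registered, this runs and mutates self_tensors in place.
torch._foreach_mul_(self_tensors, other_tensors)

dist.destroy_process_group()
```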
Stack from ghstack (oldest at bottom):
cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o