
Conversation

@bdhirsh
Contributor

@bdhirsh bdhirsh commented Jul 29, 2024

fixes #132016.

Right now, if you run an op for which DTensor has no sharding prop rule, and that op accepts non-trivial pytrees of input tensors as arguments, DTensor can end up infinite-looping before it gets the chance to error out about the missing sharding prop rule.

This PR doesn't fix the underlying problem, but it adds sharding rules for the culprit ops (the missing foreach ops).
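To make the failure mode concrete, here is a minimal sketch of the kind of repro described in #132016 (not copied from the issue; the single-rank gloo/CPU setup is just an assumption to keep it self-contained). Before this PR, the foreach call would hit the missing-sharding-rule path; with the added rules it dispatches normally:

```python
# Minimal sketch, assuming a single-rank gloo/CPU setup (not the exact repro
# from #132016): call an in-place foreach op on lists of DTensors.
import os
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import distribute_tensor, Replicate

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = init_device_mesh("cpu", (1,))
params = [distribute_tensor(torch.randn(4, 4), mesh, [Replicate()]) for _ in range(3)]
grads = [distribute_tensor(torch.randn(4, 4), mesh, [Replicate()]) for _ in range(3)]

# Foreach ops take (nested) lists of tensors; with the new sharding rules this
# dispatches through DTensor instead of hanging or erroring confusingly.
torch._foreach_mul_(params, grads)
print([type(p) for p in params])

dist.destroy_process_group()
```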

Stack from ghstack (oldest at bottom):

cc @XilunWu @H-Huang @awgu @kwen2501 @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @c-p-i-o

@pytorch-bot

pytorch-bot bot commented Jul 29, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/132066

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (8 Unrelated Failures)

As of commit 5036dd9 with merge base a356a03:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/inductor oncall: distributed Add this issue/PR to distributed oncall triage queue labels Jul 29, 2024
else:
    kwargs_schema[k] = v
    local_kwargs[k] = v

Contributor Author

Before landing this, I probably need to:

(1) try adding a sharding prop rule for aten._foreach_mul (rough sketch of the registration pattern below)
(2) see if that fixes the repro and add tests for it
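For reference, a rough sketch of what that registration could look like, following the pattern DTensor's pointwise-ops module uses for the existing foreach ops. The helper names (`register_op_strategy`, `RuntimeSchemaInfo`, `list_pointwise_strategy`) and module paths reflect my reading of the code around this change and may not match exactly:

```python
# Sketch only: roughly how foreach overloads get a sharding prop rule in
# torch/distributed/_tensor/ops/_pointwise_ops.py. Treat the imports and the
# needs_pytree flag as assumptions about the internals at the time, not a diff.
import torch
from torch.distributed._tensor._op_schema import RuntimeSchemaInfo
from torch.distributed._tensor.ops.utils import register_op_strategy
from torch.distributed._tensor.ops._pointwise_ops import list_pointwise_strategy

aten = torch.ops.aten

# Overloads that were missing a rule; anything not registered this way falls
# back to the "flat args" default path discussed in this thread.
missing_foreach_ops = [
    aten._foreach_mul_.List,
    aten._foreach_div_.List,
]

for op in missing_foreach_ops:
    # needs_pytree tells dispatch to tree-flatten the nested list-of-tensor args
    register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))(
        list_pointwise_strategy
    )
```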

I think this partially fixes #132016. It doesn't fully fix it though because:

(1) we don't have a sharding prop rule for `aten._foreach_mul_`
(2) when we don't have a sharding prop rule, we assume that the schema of our op just has a flat list of inputs that does not require any nested unflattening




cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
aten._foreach_addcmul_.Tensor,
aten._foreach_clamp_max_.Scalar,
aten._foreach_clamp_min_.Scalar,
aten._foreach_div_.List,
Contributor

Should we also add `aten._foreach_div_.Scalar` and `aten._foreach_div.Scalar` in this PR as well?

Contributor Author

yep good call

@albanD albanD removed their request for review July 29, 2024 17:57
args_schema.append(arg)
local_args.append(arg)

tree_map(arg_to_spec, args)
Collaborator

Hmm, trying to understand this more: we do have pytree to flatten the args input for certain ops, e.g. the foreach op list has pytree flattening enabled: https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/ops/_pointwise_ops.py#L641
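As a standalone illustration of what that pytree flattening buys us for foreach-style args (plain tensors here, no DTensor, and the shapes are arbitrary):

```python
# Foreach ops take nested containers of tensors, so iterating over the top-level
# args never reaches the individual tensors; tree_flatten does.
import torch
from torch.utils import _pytree as pytree

# args roughly as a foreach op sees them: two lists of tensors
args = ([torch.randn(2), torch.randn(2)], [torch.randn(2), torch.randn(2)])

leaves, spec = pytree.tree_flatten(args)
print(len(leaves))                                       # 4 tensor leaves
print(all(isinstance(t, torch.Tensor) for t in leaves))  # True

# transform every leaf, then rebuild the original nested structure
doubled = pytree.tree_unflatten([t * 2 for t in leaves], spec)
print(type(doubled), len(doubled[0]))                    # <class 'tuple'> 2
```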

Contributor Author

Oh right, is the idea that DTensor wants to manually specify which ops actually need the pytree machinery, so its usage is more limited (to avoid the perf impact)?

In that case, we can probably tweak this code to only do the tree_map here for ops that have opted into "requiring pytrees" (probably by checking the flag that you just linked)
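Something in this direction (hand-wavy sketch; `schema_info`, `needs_pytree`, and `arg_to_spec` here are placeholders mirroring my understanding of the dispatch code, not the confirmed attribute names):

```python
# Only pay the pytree cost for ops that opted in; everything else keeps the
# cheap flat-args path. Names are placeholders, not the real dispatch internals.
from torch.utils import _pytree as pytree

def collect_arg_specs(op_info, args, arg_to_spec):
    schema_info = getattr(op_info, "schema_info", None)
    if schema_info is not None and getattr(schema_info, "needs_pytree", False):
        # op opted into pytree handling: visit every nested tensor leaf
        return pytree.tree_map(arg_to_spec, args)
    # default path: assume a flat tuple of args, no nested unflattening needed
    return [arg_to_spec(arg) for arg in args]
```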

Collaborator

> Oh right, is the idea that DTensor wants to manually specify which ops actually need the pytree machinery, so its usage is more limited (to avoid the perf impact)?

Yep!

> In that case, we can probably tweak this code to only do the tree_map here for ops that have opted into "requiring pytrees" (probably by checking the flag that you just linked)

We're already sort of doing that a few lines above: https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/_dispatch.py#L291-L294

So I am wondering if there are some additional bugs that triggered that issue.

Contributor Author

oh you are totally right... let me take another look

Contributor Author

ah yep you're right. So:

(1) this logic does the right thing of operating on the flattened args, as long as we detect that the op's sharding rule has opted into using pytrees

(2) the only bug is that DTensor was missing sharding rules for a few foreach ops, causing DTensor to use the "default" path of not using pytrees, in a "bad" way that makes DTensor infinite-loop horribly (I think, before it gets a chance to error about not having a proper sharding rule).

Hmm @wanchaol, is there any easy way to know in advance if an op has no sharding prop rule (and thus will fail), so we can check it here? https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/_dispatch.py#L291

One option is maybe to detect when there's no sharding prop rule and error earlier, before we try to generate FakeTensor arguments.
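The kind of early check I am imagining (sketch only; the `op_strategy_funcs` / `op_to_rules` registry names are guesses at the ShardingPropagator internals, so treat them as placeholders):

```python
# Fail fast with a clear message before any FakeTensor/tensor-meta propagation
# when no sharding rule is registered for the op. Registry attribute names are
# assumptions, not the confirmed API.
def assert_has_sharding_rule(sharding_propagator, op_overload):
    registered = (
        op_overload in getattr(sharding_propagator, "op_strategy_funcs", {})
        or op_overload in getattr(sharding_propagator, "op_to_rules", {})
    )
    if not registered:
        raise NotImplementedError(
            f"{op_overload} has no DTensor sharding rule registered; "
            "continuing would loop or fail later, so error out here instead."
        )
```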

Contributor Author

(For now, I just "fixed" the issue by updating the foreach_ops list that DTensor relies on.)

Contributor

@wz337 wz337 Jul 29, 2024

> ah yep you're right. So:
>
> (1) this logic does the right thing of operating on the flattened args, as long as we detect that the op's sharding rule has opted into using pytrees
>
> (2) the only bug is that DTensor was missing sharding rules for a few foreach ops, causing DTensor to use the "default" path of not using pytrees, in a "bad" way that makes DTensor infinite-loop horribly (I think, before it gets a chance to error about not having a proper sharding rule).
>
> Hmm @wanchaol, is there any easy way to know in advance if an op has no sharding prop rule (and thus will fail), so we can check it here? https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/_dispatch.py#L291
>
> One option is maybe to detect when there's no sharding prop rule and error earlier, before we try to generate FakeTensor arguments.

For (2), I think the issue is that `_propagate_tensor_meta(op_schema)` (https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/_sharding_prop.py#L198) gets called first and errors out, so we can see all sorts of different errors (here are some more examples where we don't see the infinite loop but some other error: #124990), even though the true reason is that no sharding prop rule has been registered for the op.

I think this partially fixes #132016. It doesn't fully fix it though because:

(1) we don't have a sharding prop rule for `aten._foreach_mul_`
(2) when we don't have a sharding prop rule, we assume that the schema of our op just has a flat list of inputs that does not require any nested unflattening

UPDATE: the linked repro passes now that I added the missing foreach overloads, so DTensor registers proper sharding rules for them.




cc XilunWu H-Huang awgu kwen2501 wanchaol fegin fduwjj wz337 wconstab d4l3k c-p-i-o

[ghstack-poisoned]
bdhirsh added a commit that referenced this pull request Jul 29, 2024
ghstack-source-id: 17f9738
Pull Request resolved: #132066
@ezyang ezyang removed their request for review July 31, 2024 02:14
Collaborator

@wanchaol wanchaol left a comment

lgtm thanks!

@wanchaol
Collaborator

wanchaol commented Aug 1, 2024

Please change the title and summary to reflect the newest changes :)

@bdhirsh bdhirsh changed the title DTensor: use pytrees to convert DTensors into specs DTensor: add more foreach ops to supported sharding prop list Aug 1, 2024
@bdhirsh
Contributor Author

bdhirsh commented Aug 1, 2024

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Aug 1, 2024
@pytorchmergebot
Collaborator

Merge failed

Reason: This PR needs a release notes: label
If your changes are user facing and intended to be a part of release notes, please use a label starting with release notes:.

If not, please add the topic: not user facing label.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "topic: not user facing"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.


@bdhirsh bdhirsh added the release notes: distributed (dtensor) release notes category label Aug 2, 2024
@awgu awgu added release notes: distributed (dtensor) release notes category and removed release notes: distributed (dtensor) release notes category labels Aug 5, 2024
@awgu
Collaborator

awgu commented Aug 5, 2024

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.


Labels

ciflow/inductor, ciflow/trunk, Merged, oncall: distributed, release notes: distributed (dtensor)
