[DTensor] Fix slow sharding prop for stack #169519
wconstab wants to merge 1 commit into gh/wconstab/467/base
Conversation
As identified in the original issue, there is quadratic complexity in the number of input tensors, due to an improperly written sharding prop rule. The previous code generated N output strategies for the stack op, one based on each of the original N input strategies. However, each of the N output strategies was the same: the heuristic in the stack rule is to pick one of the N inputs and follow that one. We now generate just one output strategy.
Fixes #169445
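For illustration, a hedged, self-contained toy sketch (invented helper names and a made-up cost model, not the actual DTensor rule) of why the old pattern is quadratic in the number of inputs while the fix is linear:

```python
def pick_followed(placements: list[str]) -> str:
    # Stand-in for the stack rule's heuristic: follow one chosen input.
    return placements[0]

def toy_cost(src: str, dst: str) -> float:
    # Made-up redistribute cost: free if already matching, else one unit.
    return 0.0 if src == dst else 1.0

def stack_outputs_old(placements: list[str]) -> list[tuple[str, list[float]]]:
    # Old shape of the rule: one output strategy per input. Every iteration
    # follows the same input, so all N outputs are identical, and each one
    # pays O(N) cost generation, i.e. O(N^2) work overall.
    outputs = []
    for _ in placements:
        followed = pick_followed(placements)
        costs = [toy_cost(p, followed) for p in placements]
        outputs.append((followed, costs))
    return outputs

def stack_outputs_new(placements: list[str]) -> list[tuple[str, list[float]]]:
    # New shape of the rule: compute the followed placement once and emit a
    # single output strategy, so the rule does O(N) work.
    followed = pick_followed(placements)
    costs = [toy_cost(p, followed) for p in placements]
    return [(followed, costs)]

inputs = ["Shard(0)", "Replicate()", "Shard(0)"]
assert len(stack_outputs_old(inputs)) == len(inputs)  # N duplicate strategies
assert len(stack_outputs_new(inputs)) == 1            # one strategy
```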
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/169519
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure as of commit fca2760 with merge base e3f24fd. The following job has failed: trunk / linux-jammy-rocm-py3.10 / test (default, 5, 6, linux.rocm.gpu.gfx942.1)
This comment was automatically generated by Dr. CI and updates every 15 minutes.
first_input_strategy = input_tuple_strategy.children[0]
if not isinstance(first_input_strategy, OpStrategy):
    raise AssertionError(f"Expected OpStrategy, got {first_input_strategy}")
input_strategies: list[OpStrategy] = []
this part was just to make mypy happy below: since children are typed as 'StrategyType', which can be a TupleStrategy or an OpStrategy, we need to ensure they are all OpStrategy...
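For context, a minimal sketch of the kind of isinstance narrowing being discussed (simplified class definitions, not the actual DTensor types):

```python
from __future__ import annotations

class StrategyType:
    """Simplified stand-in for the base strategy class."""

class OpStrategy(StrategyType):
    pass

class TupleStrategy(StrategyType):
    def __init__(self, children: list[StrategyType]) -> None:
        self.children = children

def collect_op_strategies(tup: TupleStrategy) -> list[OpStrategy]:
    # children is typed as list[StrategyType]; the isinstance check both
    # validates the runtime assumption and narrows the type so mypy accepts
    # appending each child to a list[OpStrategy].
    input_strategies: list[OpStrategy] = []
    for child in tup.children:
        if not isinstance(child, OpStrategy):
            raise AssertionError(f"Expected OpStrategy, got {child}")
        input_strategies.append(child)
    return input_strategies
```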
Hmm, this seems to be a good case where we need detect_exists_identical_opspec to verify the op strategy and prevent generating the same OpSpec.
Why? I don't follow.
output_spec = DTensorSpec(mesh, tuple(follow_placements))
redistribute_cost = []
for input_spec in input_specs:
    cost = generate_redistribute_costs(strategy, input_spec)
@zpcore one thing I would like to confirm: this old code looks incorrect to me, in addition to being slower.
We should never be generating the redistribute cost from input 2's placement to input 1's dst spec, right? So using 'strategy' here was a bug?
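To spell out the direction being questioned, a self-contained toy sketch (made-up placements and cost function, not the DTensor API): each input's cost should be computed from that input's own current placement to the shared followed placement, not from the followed input's placement.

```python
def toy_cost(src: str, dst: str) -> float:
    # Made-up cost model: zero if the input already matches the target.
    return 0.0 if src == dst else 1.0

def per_input_redistribute_costs(input_placements: list[str], followed: str) -> list[float]:
    # One entry per input: that input's own placement -> followed placement.
    return [toy_cost(p, followed) for p in input_placements]

# The first input already matches Shard(0) and is free; the Replicate() input
# must pay to reshard. Using some other input's placement as the cost source
# would not reflect that.
assert per_input_redistribute_costs(["Shard(0)", "Replicate()"], "Shard(0)") == [0.0, 1.0]
```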
To clarify, this is a necessary but not sufficient test to say the strategy is not generating duplicated OpSpecs.
@pytorchbot merge

Merge started: Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Merge failed. Reason: 1 job has failed, first few of them are: trunk / linux-jammy-rocm-py3.10 / test (default, 5, 6, linux.rocm.gpu.gfx942.1). Details for Dev Infra team: raised by workflow job.
albanD left a comment:
Sounds good!
Any reason this strategy is not shared with cat()?
Historically, not sure. If they can be shared, I'll do it as part of a bigger rewrite I'm working on.
@pytorchbot merge -i

Merge started: Your change will be merged while ignoring the following 1 checks: trunk / linux-jammy-rocm-py3.10 / test (default, 5, 6, linux.rocm.gpu.gfx942.1). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@pytorchbot merge -f

❌ 🤖 pytorchbot command failed: Try
@pytorchbot merge -f "merge -i got stuck?"
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.

Merge started: Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Approved by: https://github.com/zpcore, https://github.com/malfet, https://github.com/albanD