Skip to content

Conversation

@ColinPeppler
Copy link
Contributor

@ColinPeppler ColinPeppler commented Jul 25, 2025

Switch from guard_size_oblivious to guard_or_false if you encounter a DDE, this would then avoid folding this 3d bmm into a mm.

elif should_fold(tensor1, tensor2, is_out):
# dim_tensor1 >=3 && (dim_tensor2 == 1 || dim_tensor2 == 2) ||
# dim_tensor2 >=3 && (dim_tensor1 == 1 || dim_tensor1 == 2)
# and some condition on the strides is fulfilled
# optimization: use mm instead of bmm by folding the batch of the larger tensor
# into its leading matrix dimension

DDE

  File "/data/users/colinpeppler/pytorch/torch/_decomp/decompositions.py", line 4506, in matmul
    elif should_fold(tensor1, tensor2, is_out):
  File "/data/users/colinpeppler/pytorch/torch/_decomp/decompositions.py", line 4472, in should_fold
    if guard_size_oblivious(t1.numel() == 0):
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(12*((u0//2)), 0) (unhinted: Eq(12*((u0//2)), 0)).  (Size-like symbols: none)

Caused by: (_decomp/decompositions.py:4472 in should_fold)
  File "/data/users/colinpeppler/pytorch/torch/_decomp/decompositions.py", line 4506, in matmul
    elif should_fold(tensor1, tensor2, is_out):
  File "/data/users/colinpeppler/pytorch/torch/_decomp/decompositions.py", line 4483, in should_fold
    return all(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(3*((u0//2)), 3) (unhinted: Eq(3*((u0//2)), 3)).  (Size-like symbols: none)

Caused by: (_decomp/decompositions.py:4483 in should_fold)

Stack from ghstack (oldest at bottom):

cc @ezyang @penguinwu @bobrenjc93

@pytorch-bot
Copy link

pytorch-bot bot commented Jul 25, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/159184

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (3 Unrelated Failures)

As of commit 470fa9a with merge base 255a04b (image):

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

ColinPeppler added a commit that referenced this pull request Jul 25, 2025
@ColinPeppler ColinPeppler changed the title gso to guard_or_false when checking should_fold on 3D matmul (should_fold) gso to guard_or_false when checking folding whether to 3d bmm into 2d mm Jul 25, 2025
@ColinPeppler
Copy link
Contributor Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Jul 29, 2025
@pytorchmergebot
Copy link
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status
here

yangw-dev pushed a commit that referenced this pull request Aug 1, 2025
…3d bmm into 2d mm (#159184)

Switch from guard_size_oblivious to guard_or_false if you encounter a DDE, this would then avoid folding this 3d bmm into a mm.

https://github.com/pytorch/pytorch/blob/806d9e3fe70ec250a1fb3823841d16c61b7d1b02/torch/_decomp/decompositions.py#L4506-L4512

## DDE
```
  File "/data/users/colinpeppler/pytorch/torch/_decomp/decompositions.py", line 4506, in matmul
    elif should_fold(tensor1, tensor2, is_out):
  File "/data/users/colinpeppler/pytorch/torch/_decomp/decompositions.py", line 4472, in should_fold
    if guard_size_oblivious(t1.numel() == 0):
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(12*((u0//2)), 0) (unhinted: Eq(12*((u0//2)), 0)).  (Size-like symbols: none)

Caused by: (_decomp/decompositions.py:4472 in should_fold)
```

```
  File "/data/users/colinpeppler/pytorch/torch/_decomp/decompositions.py", line 4506, in matmul
    elif should_fold(tensor1, tensor2, is_out):
  File "/data/users/colinpeppler/pytorch/torch/_decomp/decompositions.py", line 4483, in should_fold
    return all(
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(3*((u0//2)), 3) (unhinted: Eq(3*((u0//2)), 3)).  (Size-like symbols: none)

Caused by: (_decomp/decompositions.py:4483 in should_fold)
```

Pull Request resolved: #159184
Approved by: https://github.com/ezyang
ghstack dependencies: #158894
@github-actions github-actions bot deleted the gh/ColinPeppler/78/head branch August 30, 2025 02:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants