
Always track _local_scalar_dense output in tensorify_python_scalars. #166573

Closed
laithsakka wants to merge 4 commits into gh/laithsakka/320/base from gh/laithsakka/320/head

Conversation


laithsakka (Contributor) commented Oct 29, 2025

Stack from ghstack (oldest at bottom):

We need to track all symbols; we used to skip `u = item()` and then fail with:

```
 File "/home/lsakka/pytorch10/pytorch/torch/fx/passes/_tensorify_python_scalars.py", line 149, in _sympy_interp
    expr_to_sym_proxy[expr]
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
KeyError: u0
```
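To illustrate the failure mode with a minimal stand-in (not the pass's real code; the names below are hypothetical and only mimic the lookup pattern in `_sympy_interp`): the pass keeps a map from symbolic expressions to FX proxies, and skipping the `u = item()` output means the symbol `u0` is never recorded, so any later lookup of an expression containing it raises `KeyError`.

```python
# Hypothetical sketch: the pass maps symbolic expressions to FX proxies.
expr_to_sym_proxy = {}

# Stand-in for the unbacked symbol produced by `u = tensor.item()`.
u0 = "u0"

def sympy_interp(expr):
    # Mirrors the _sympy_interp lookup from the traceback above:
    # resolve an expression via the tracking map.
    return expr_to_sym_proxy[expr]

# Before the fix: the item() output was skipped, so the symbol was never
# recorded and the lookup failed.
try:
    sympy_interp(u0)
    failed = False
except KeyError:
    failed = True
assert failed

# After the fix: the symbol is always tracked when produced, so the
# lookup succeeds.
expr_to_sym_proxy[u0] = "proxy_for_u0"
assert sympy_interp(u0) == "proxy_for_u0"
```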

cc @ezyang @EikanWang @jgong5 @wenzhe-nrv @voznesenskym @penguinwu @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela

[ghstack-poisoned]

pytorch-bot bot commented Oct 29, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/166573

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (4 Unrelated Failures)

As of commit b63a869 with merge base 1a67403:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the release notes: fx release notes category label Oct 29, 2025
laithsakka added a commit that referenced this pull request Oct 29, 2025
ghstack-source-id: 0587136
Pull Request resolved: #166573
The following example, which used to fail, now works:
```
import torch
import torch.nn.functional as F
from typing import Tuple

torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True)
def _random_resize(image: torch.Tensor) -> Tuple[Tuple[int, int], torch.Tensor]:
    image_metanet = image
    default_patch_size = 14
    rand_cnn_resolution = (224, 256)
    min_nump = rand_cnn_resolution[0] // default_patch_size
    max_nump = rand_cnn_resolution[1] // default_patch_size
    new_nump = torch.randint(min_nump, max_nump + 1, (1,)).item()
    torch._check(new_nump > 0)
    torch._check(new_nump * default_patch_size > 1)
    image_metanet = F.interpolate(
        image_metanet,
        size=(new_nump * default_patch_size, new_nump * default_patch_size),
        mode="bilinear",
        align_corners=True,
    )
    img_h_new, img_w_new = image_metanet.shape[2:]

    return (img_h_new, img_w_new), image_metanet

torch._logging.set_logs(graph_breaks=True)

_random_resize(torch.rand(1, 3, 224, 224))
```


cc ezyang EikanWang jgong5 wenzhe-nrv

[ghstack-poisoned]
laithsakka added a commit that referenced this pull request Oct 30, 2025
ghstack-source-id: 8ff3e74
Pull Request resolved: #166573
laithsakka added a commit that referenced this pull request Nov 13, 2025
ghstack-source-id: 64f747e
Pull Request resolved: #166573

pytorch-bot bot commented Nov 13, 2025

The label module: dynamo is only applicable to issues and has been removed. Please only use this label on issues.


@laithsakka laithsakka changed the title WIP Always track _local_scalar_dense output in tensorify_python_scalars. Nov 13, 2025


```
s = node.meta["val"].node.expr

# always track s.
```
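A minimal sketch of what "always track s" means at this point in the pass (the stub classes and helper below are hypothetical stand-ins that only mimic the FX node layout from the snippet, not PyTorch's actual types):

```python
# Hypothetical stand-ins for the FX node structure in the snippet above.
class SymNode:
    def __init__(self, expr):
        self.expr = expr  # the sympy expression, e.g. u0

class SymVal:
    def __init__(self, expr):
        self.node = SymNode(expr)

class Node:
    def __init__(self, expr, name):
        self.meta = {"val": SymVal(expr)}
        self.name = name

expr_to_sym_proxy = {}

def track_local_scalar_dense(node):
    # Pull the symbolic expression off the node, as in the snippet...
    s = node.meta["val"].node.expr
    # ...and always record it, even when no float arithmetic consumes it,
    # so later _sympy_interp-style lookups cannot hit a KeyError.
    expr_to_sym_proxy.setdefault(s, node.name)

n = Node("u0", "proxy_u0")
track_local_scalar_dense(n)
assert expr_to_sym_proxy["u0"] == "proxy_u0"
```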
Reviewer (Contributor): remove this comment or make it clearer?

laithsakka (author): ok

laithsakka added a commit that referenced this pull request Nov 14, 2025
ghstack-source-id: 175e65f
Pull Request resolved: #166573
laithsakka (author):

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Nov 14, 2025
pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status.

Silv3S pushed a commit to Silv3S/pytorch that referenced this pull request Nov 18, 2025
…ytorch#166573)

We need to track all symbols; we used to skip `u = item()` and then fail with:
```
 File "/home/lsakka/pytorch10/pytorch/torch/fx/passes/_tensorify_python_scalars.py", line 149, in _sympy_interp
    expr_to_sym_proxy[expr]
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
KeyError: u0
```

Pull Request resolved: pytorch#166573
Approved by: https://github.com/bobrenjc93
@github-actions github-actions bot deleted the gh/laithsakka/320/head branch December 15, 2025 02:21

Labels

ciflow/inductor, ciflow/trunk, fx, Merged, module: dynamo, release notes: fx

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants