Always track _local_scalar_dense output in tensorify_python_scalars. #166573
Always track _local_scalar_dense output in tensorify_python_scalars. #166573laithsakka wants to merge 4 commits intogh/laithsakka/320/basefrom
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/166573
Note: Links to docs will display an error until the docs builds have been completed. ✅ You can merge normally! (4 Unrelated Failures)As of commit b63a869 with merge base 1a67403 ( FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
work for
```
import torch
from torch._refs import full
import torch.nn.functional as F
from typing import Tuple
torch._dynamo.config.capture_scalar_outputs = True
torch.compile(fullgraph=True)
def _random_resize(image: torch.Tensor) -> Tuple[int, int, torch.Tensor]:
image_metanet = image
default_patch_size = 14
rand_cnn_resolution = (224, 256)
min_nump = rand_cnn_resolution[0] // default_patch_size
max_nump = rand_cnn_resolution[1] // default_patch_size
new_nump = torch.randint(min_nump, max_nump + 1, (1,)).item()
torch._check(new_nump>0)
torch._check(new_nump * default_patch_size >1)
# print(f"{new_nump=}, {self.default_patch_size=} {new_nump * self.default_patch_size=}")
image_metanet = F.interpolate(
image_metanet,
size=(new_nump * default_patch_size, new_nump * default_patch_size),
mode="bilinear",
align_corners=True,
)
img_h_new, img_w_new = image_metanet.shape[2:]
return (img_h_new, img_w_new), image_metanet
torch._logging.set_logs(graph_breaks=True)
_random_resize(torch.rand(1, 3, 224, 224))
```
used to fail
cc ezyang EikanWang jgong5 wenzhe-nrv
[ghstack-poisoned]
work for
```
import torch
from torch._refs import full
import torch.nn.functional as F
from typing import Tuple
torch._dynamo.config.capture_scalar_outputs = True
torch.compile(fullgraph=True)
def _random_resize(image: torch.Tensor) -> Tuple[int, int, torch.Tensor]:
image_metanet = image
default_patch_size = 14
rand_cnn_resolution = (224, 256)
min_nump = rand_cnn_resolution[0] // default_patch_size
max_nump = rand_cnn_resolution[1] // default_patch_size
new_nump = torch.randint(min_nump, max_nump + 1, (1,)).item()
torch._check(new_nump>0)
torch._check(new_nump * default_patch_size >1)
# print(f"{new_nump=}, {self.default_patch_size=} {new_nump * self.default_patch_size=}")
image_metanet = F.interpolate(
image_metanet,
size=(new_nump * default_patch_size, new_nump * default_patch_size),
mode="bilinear",
align_corners=True,
)
img_h_new, img_w_new = image_metanet.shape[2:]
return (img_h_new, img_w_new), image_metanet
torch._logging.set_logs(graph_breaks=True)
_random_resize(torch.rand(1, 3, 224, 224))
```
used to fail
cc ezyang EikanWang jgong5 wenzhe-nrv
[ghstack-poisoned]
|
The label |
|
The label |
|
The label |
|
The label |
|
|
||
| s = node.meta["val"].node.expr | ||
|
|
||
| # always track s. |
There was a problem hiding this comment.
remove comment or make it more clear?
…n_scalars. "
We need to track all symbols, we used to skip
u = item()
and fail with
```
File "/home/lsakka/pytorch10/pytorch/torch/fx/passes/_tensorify_python_scalars.py", line 149, in _sympy_interp
expr_to_sym_proxy[expr]
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
KeyError: u0
```
cc ezyang EikanWang jgong5 wenzhe-nrv voznesenskym penguinwu Guobing-Chen XiaobingSuper zhuhaozhe blzheng jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela
[ghstack-poisoned]
|
@pytorchbot merge |
Merge startedYour change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team |
…ytorch#166573) We need to track all symbols, we used to skip u = item() and fail with ``` File "/home/lsakka/pytorch10/pytorch/torch/fx/passes/_tensorify_python_scalars.py", line 149, in _sympy_interp expr_to_sym_proxy[expr] torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised: KeyError: u0 ``` Pull Request resolved: pytorch#166573 Approved by: https://github.com/bobrenjc93
Stack from ghstack (oldest at bottom):
We need to track all symbols, we used to skip
u = item()
and fail with
cc @ezyang @EikanWang @jgong5 @wenzhe-nrv @voznesenskym @penguinwu @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @jiayisunx @chenyang78 @kadeng @chauhang @amjames @Lucaskabela