[DTensor] ignore fresh unbacked symbols in shard prop #166989
pianpwk wants to merge 2 commits into gh/pianpwk/29/base
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/166989
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit 5cdeade with merge base 82fa2aa. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Quoted diff context:

from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing
...
with FakeTensorMode(), disable_proxy_modes_tracing():
fake_mode = detect_fake_mode() or FakeTensorMode()
Should we also initialize a dummy ShapeEnv here as well? When there is no fake mode from the tracing context, the lines below will fail with errors like "NoneType doesn't have create_unbacked_symint". It seems to me that even in eager you would do this fake tensor prop thing, right?
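A minimal sketch of what this suggestion might look like (not the PR's actual change; whether a throwaway ShapeEnv is acceptable here is an assumption):

```python
from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

# If no fake mode is active (e.g. plain eager DTensor dispatch), fall back to a
# fresh FakeTensorMode that owns a dummy ShapeEnv, so data-dependent ops can
# still call create_unbacked_symint during shard propagation.
fake_mode = detect_fake_mode() or FakeTensorMode(shape_env=ShapeEnv())
```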
Quoted diff context:

y = torch.randint(1, (10,)).bool()
x_dt = distribute_tensor(x, device_mesh, placements=[Replicate()])
y_dt = distribute_tensor(y, device_mesh, placements=[Replicate()])
_dynamo_graph_capture_for_export(Foo())(x_dt, y_dt)
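For context, a self-contained sketch of what such a data-dependent DTensor test might look like; `Foo`'s body, the mesh setup, and the definition of `x` are assumptions here, not the PR's exact test:

```python
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate, distribute_tensor


class Foo(nn.Module):
    def forward(self, x, y):
        # Boolean-mask indexing has a data-dependent output size, so fake-tensor
        # propagation allocates a fresh unbacked SymInt for it.
        return x[y]


# Assumes a process group is already initialized
# (e.g. via torch.distributed.init_process_group).
device_mesh = init_device_mesh("cpu", (1,))
x = torch.randn(10)
y = torch.randint(1, (10,)).bool()
x_dt = distribute_tensor(x, device_mesh, placements=[Replicate()])
y_dt = distribute_tensor(y, device_mesh, placements=[Replicate()])
# The PR's test then captures the module, exercising shard prop on the
# data-dependent op:
# _dynamo_graph_capture_for_export(Foo())(x_dt, y_dt)
```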
|
@laithsakka ptal
|
|
Quoted diff context:

with FakeTensorMode(), disable_proxy_modes_tracing():
fake_mode = detect_fake_mode() or FakeTensorMode()
suppress_fresh_symbols_ctx = (
Can you add a comment explaining why `ignore_fresh_unbacked_symbols()` is safe here?
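A sketch of the pattern in question, with a comment answering the safety question based on the PR description; `_propagate_fake`, `op_call`, `fake_args`, and `fake_kwargs` are hypothetical names, not the actual ShardingPropagator code:

```python
import contextlib

from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import disable_proxy_modes_tracing


def _propagate_fake(op_call, fake_args, fake_kwargs):
    # Reuse the outer tracing fake mode if there is one; otherwise fall back to a
    # fresh FakeTensorMode (the pre-existing behavior).
    fake_mode = detect_fake_mode() or FakeTensorMode()
    suppress_fresh_symbols_ctx = (
        # Why this is safe: this fake prop only decides the output sharding. Its
        # result is never surfaced to the tracer, and dispatch runs a second fake
        # prop right after that allocates and binds its own unbacked symbols, so
        # symbols created here would otherwise stay pending forever.
        fake_mode.shape_env.ignore_fresh_unbacked_symbols()
        if fake_mode.shape_env is not None
        else contextlib.nullcontext()
    )
    with fake_mode, disable_proxy_modes_tracing(), suppress_fresh_symbols_ctx:
        return op_call(*fake_args, **fake_kwargs)
```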
|
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
ghstack-source-id: 0f80220
Pull Request resolved: pytorch/pytorch#166989
Approved by: https://github.com/ezyang
This fixes 2 issues with the DTensor data-dependent test case:

1) ShapeEnv not found when doing shard prop on data-dependent ops. The fix is to detect the outer tracing fake mode; the previous behavior was to initialize a new fake mode on every call. Maybe ShardingPropagator should just own a FakeTensorMode & ShapeEnv for these purposes?

2) Pending unbacked symbols not found. This happens because DTensor dispatch runs fake prop twice: once while figuring out the output sharding (torch/distributed/tensor/_sharding_prop.py, line 175 in 2bba373: https://github.com/pytorch/pytorch/blob/2bba37309bc8996fc6a190592e5ad9aac53761c9/torch/distributed/tensor/_sharding_prop.py#L175), and again to actually get the resulting local tensor (torch/distributed/tensor/_dispatch.py, lines 254 to 255 in 2bba373: https://github.com/pytorch/pytorch/blob/2bba37309bc8996fc6a190592e5ad9aac53761c9/torch/distributed/tensor/_dispatch.py#L254-L255). With data-dependent ops, both calls produce an unbacked symbol, but the symbols from the first invocation are never surfaced, producing this error, so we ignore pending symbols from this site.
Stack from ghstack (oldest at bottom):
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci
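To make the "fake prop runs twice" point concrete, here is a small standalone sketch (not from the PR) using `torch.nonzero` as a stand-in data-dependent op. It shows that each fake propagation allocates a fresh unbacked symbol under a `FakeTensorMode` with a `ShapeEnv`, and that `ignore_fresh_unbacked_symbols()` keeps the first, throwaway propagation from leaving its symbols pending:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

shape_env = ShapeEnv()
fake_mode = FakeTensorMode(shape_env=shape_env)

with fake_mode:
    x = torch.randint(1, (10,)).bool()
    # First propagation (analogous to the sharding-prop pass): the unbacked
    # SymInt allocated for nonzero's output size is dropped from the pending list.
    with shape_env.ignore_fresh_unbacked_symbols():
        _ = torch.nonzero(x)
    # Second propagation (analogous to computing the local tensor): this one's
    # unbacked symbol is the one that actually gets surfaced and bound.
    out = torch.nonzero(x)
    print(out.shape)  # e.g. torch.Size([u1, 1]) -- a fresh unbacked size
```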