Skip to content

minor proxy_tensor reorg#165266

Closed
aorenste wants to merge 4 commits intogh/aorenste/140/basefrom
gh/aorenste/140/head
Closed

minor proxy_tensor reorg#165266
aorenste wants to merge 4 commits intogh/aorenste/140/basefrom
gh/aorenste/140/head

Conversation

@aorenste
Copy link
Contributor

@aorenste aorenste commented Oct 12, 2025

Stack from ghstack (oldest at bottom):

Moving some code around in proxy_tensor in preparation for the next PR. There we
no actual changes (other than simple relabeling such as self.tracer ->
tracer):

  • Move _compute_proxy() out of ProxyTorchDispatchMode.

  • Give sympy_expr_tracker a structured type instead of object.

  • Split SymNode registration out of ProxyTorchDispatchMode.sym_dispatch() so
    it can be reused.

cc @ezyang @EikanWang @jgong5 @wenzhe-nrv

Moving some code around in proxy_tensor in preparation for the next PR. There we
no actual changes (other than simple relabeling such as `self.tracer` ->
`tracer`):

- Move _compute_proxy() out of ProxyTorchDispatchMode.

- Give `sympy_expr_tracker` a structured type instead of `object`.

- Split SymNode registration out of ProxyTorchDispatchMode.__sym_dispatch__() so
  it can be reused.

[ghstack-poisoned]
@pytorch-bot
Copy link

pytorch-bot bot commented Oct 12, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/165266

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit e539924 with merge base 992857e (image):

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Moving some code around in proxy_tensor in preparation for the next PR. There we
no actual changes (other than simple relabeling such as `self.tracer` ->
`tracer`):

- Move _compute_proxy() out of ProxyTorchDispatchMode.

- Give `sympy_expr_tracker` a structured type instead of `object`.

- Split SymNode registration out of ProxyTorchDispatchMode.__sym_dispatch__() so
  it can be reused.

[ghstack-poisoned]
Moving some code around in proxy_tensor in preparation for the next PR. There we
no actual changes (other than simple relabeling such as `self.tracer` ->
`tracer`):

- Move _compute_proxy() out of ProxyTorchDispatchMode.

- Give `sympy_expr_tracker` a structured type instead of `object`.

- Split SymNode registration out of ProxyTorchDispatchMode.__sym_dispatch__() so
  it can be reused.

[ghstack-poisoned]
Moving some code around in proxy_tensor in preparation for the next PR. There we
no actual changes (other than simple relabeling such as `self.tracer` ->
`tracer`):

- Move _compute_proxy() out of ProxyTorchDispatchMode.

- Give `sympy_expr_tracker` a structured type instead of `object`.

- Split SymNode registration out of ProxyTorchDispatchMode.__sym_dispatch__() so
  it can be reused.

[ghstack-poisoned]
@aorenste aorenste added the suppress-bc-linter Suppresses the failures of API backward-compatibility linter (Lint/bc_linter) label Oct 16, 2025
@aorenste
Copy link
Contributor Author

The BC failure:

Warning: Function PythonKeyTracer: sympy_expr_tracker changed from dict[sympy.Symbol, object] to dict[sympy.Symbol, _SympyExprTrackerValue]

is on a variable that probably shouldn't be public...

@pytorchmergebot
Copy link
Collaborator

Starting merge as part of PR stack under #164869

@pytorchmergebot
Copy link
Collaborator

Starting merge as part of PR stack under #164718

pytorchmergebot pushed a commit that referenced this pull request Oct 16, 2025
…nputs (#164717)

In a training library we hit a weird conflict between dtensor, dynamic shapes, and proxy tensor.

The problem is occuring because in sharding_prop we use FakeTensors to compute an operation size (so we don't have to  use the full "real" data). We turn off proxy tracing while we're doing that because we don't want the FakeTensor ops to end up in the graph.  We then use that size when doing later operations.

Normally this is no problem - but when those sizes are dynamic shapes then we have a problem - the proxy tracer wants to track the provenance of all shape operations (`s1*s2`) but since tracing is disabled it doesn't see the operation and when we then use the result shape later on the proxy tracer gets all confused (because the SymNode appeared out of nowhere).

At first we were thinking to never disable shape tracing - but that caused a slew of other downstream problems (lots of code that actually needs the shape tracing to be disabled) so instead we enable having a "sym tracing override" and surgically when we disable proxy tracing we leave shape tracing enabled.

After this change the dtensor embedding is "fixed" but then runs afoul of a FakeTensor cache bug - which is fixed in the next PR.

Pull Request resolved: #164717
Approved by: https://github.com/bobrenjc93, https://github.com/ezyang
ghstack dependencies: #165266
pytorchmergebot pushed a commit that referenced this pull request Oct 16, 2025
Improve FakeTensor cache to handle SymNode and tracing properly.

For now, when we're proxy tracing just don't bother caching operations that contain SymNodes in the output. The problem is that the proxy tracer relies on SymNode identity and our cache doesn't preserve that. It can be fixed (and I left some notes in _validate_symbolic_output_for_caching() how) but it's not worth it for now.

If we aren't proxy tracing then caching is fine.

Thus these changes:

1. Our cache key needs to include whether we were actively tracing or not - this way if we create a cache entry when we weren't tracing and then we try to use it when we ARE tracing it gets rerun.

2. If there's a SymNode in the output then bypass tracing.

3. Some general cleanup of the output validation - we were unnecessarily doing it as a two-step process when it could just be a single step (it's still two parts internally but only a single outer try/except).

Pull Request resolved: #164718
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #165266, #164717
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
Moving some code around in proxy_tensor in preparation for the next PR. There we
no actual changes (other than simple relabeling such as `self.tracer` ->
`tracer`):

- Move _compute_proxy() out of ProxyTorchDispatchMode.

- Give `sympy_expr_tracker` a structured type instead of `object`.

- Split SymNode registration out of ProxyTorchDispatchMode.__sym_dispatch__() so
  it can be reused.

Pull Request resolved: pytorch#165266
Approved by: https://github.com/ezyang, https://github.com/mlazos
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
…nputs (pytorch#164717)

In a training library we hit a weird conflict between dtensor, dynamic shapes, and proxy tensor.

The problem is occuring because in sharding_prop we use FakeTensors to compute an operation size (so we don't have to  use the full "real" data). We turn off proxy tracing while we're doing that because we don't want the FakeTensor ops to end up in the graph.  We then use that size when doing later operations.

Normally this is no problem - but when those sizes are dynamic shapes then we have a problem - the proxy tracer wants to track the provenance of all shape operations (`s1*s2`) but since tracing is disabled it doesn't see the operation and when we then use the result shape later on the proxy tracer gets all confused (because the SymNode appeared out of nowhere).

At first we were thinking to never disable shape tracing - but that caused a slew of other downstream problems (lots of code that actually needs the shape tracing to be disabled) so instead we enable having a "sym tracing override" and surgically when we disable proxy tracing we leave shape tracing enabled.

After this change the dtensor embedding is "fixed" but then runs afoul of a FakeTensor cache bug - which is fixed in the next PR.

Pull Request resolved: pytorch#164717
Approved by: https://github.com/bobrenjc93, https://github.com/ezyang
ghstack dependencies: pytorch#165266
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
Improve FakeTensor cache to handle SymNode and tracing properly.

For now, when we're proxy tracing just don't bother caching operations that contain SymNodes in the output. The problem is that the proxy tracer relies on SymNode identity and our cache doesn't preserve that. It can be fixed (and I left some notes in _validate_symbolic_output_for_caching() how) but it's not worth it for now.

If we aren't proxy tracing then caching is fine.

Thus these changes:

1. Our cache key needs to include whether we were actively tracing or not - this way if we create a cache entry when we weren't tracing and then we try to use it when we ARE tracing it gets rerun.

2. If there's a SymNode in the output then bypass tracing.

3. Some general cleanup of the output validation - we were unnecessarily doing it as a two-step process when it could just be a single step (it's still two parts internally but only a single outer try/except).

Pull Request resolved: pytorch#164718
Approved by: https://github.com/bobrenjc93
ghstack dependencies: pytorch#165266, pytorch#164717
zhudada0120 pushed a commit to zhudada0120/pytorch that referenced this pull request Oct 22, 2025
Moving some code around in proxy_tensor in preparation for the next PR. There we
no actual changes (other than simple relabeling such as `self.tracer` ->
`tracer`):

- Move _compute_proxy() out of ProxyTorchDispatchMode.

- Give `sympy_expr_tracker` a structured type instead of `object`.

- Split SymNode registration out of ProxyTorchDispatchMode.__sym_dispatch__() so
  it can be reused.

Pull Request resolved: pytorch#165266
Approved by: https://github.com/ezyang, https://github.com/mlazos
zhudada0120 pushed a commit to zhudada0120/pytorch that referenced this pull request Oct 22, 2025
…nputs (pytorch#164717)

In a training library we hit a weird conflict between dtensor, dynamic shapes, and proxy tensor.

The problem is occuring because in sharding_prop we use FakeTensors to compute an operation size (so we don't have to  use the full "real" data). We turn off proxy tracing while we're doing that because we don't want the FakeTensor ops to end up in the graph.  We then use that size when doing later operations.

Normally this is no problem - but when those sizes are dynamic shapes then we have a problem - the proxy tracer wants to track the provenance of all shape operations (`s1*s2`) but since tracing is disabled it doesn't see the operation and when we then use the result shape later on the proxy tracer gets all confused (because the SymNode appeared out of nowhere).

At first we were thinking to never disable shape tracing - but that caused a slew of other downstream problems (lots of code that actually needs the shape tracing to be disabled) so instead we enable having a "sym tracing override" and surgically when we disable proxy tracing we leave shape tracing enabled.

After this change the dtensor embedding is "fixed" but then runs afoul of a FakeTensor cache bug - which is fixed in the next PR.

Pull Request resolved: pytorch#164717
Approved by: https://github.com/bobrenjc93, https://github.com/ezyang
ghstack dependencies: pytorch#165266
zhudada0120 pushed a commit to zhudada0120/pytorch that referenced this pull request Oct 22, 2025
Improve FakeTensor cache to handle SymNode and tracing properly.

For now, when we're proxy tracing just don't bother caching operations that contain SymNodes in the output. The problem is that the proxy tracer relies on SymNode identity and our cache doesn't preserve that. It can be fixed (and I left some notes in _validate_symbolic_output_for_caching() how) but it's not worth it for now.

If we aren't proxy tracing then caching is fine.

Thus these changes:

1. Our cache key needs to include whether we were actively tracing or not - this way if we create a cache entry when we weren't tracing and then we try to use it when we ARE tracing it gets rerun.

2. If there's a SymNode in the output then bypass tracing.

3. Some general cleanup of the output validation - we were unnecessarily doing it as a two-step process when it could just be a single step (it's still two parts internally but only a single outer try/except).

Pull Request resolved: pytorch#164718
Approved by: https://github.com/bobrenjc93
ghstack dependencies: pytorch#165266, pytorch#164717
@github-actions github-actions bot deleted the gh/aorenste/140/head branch November 16, 2025 02:20
Khanaksahu pushed a commit to Khanaksahu/pytorch-fork that referenced this pull request Nov 17, 2025
Moving some code around in proxy_tensor in preparation for the next PR. There we
no actual changes (other than simple relabeling such as `self.tracer` ->
`tracer`):

- Move _compute_proxy() out of ProxyTorchDispatchMode.

- Give `sympy_expr_tracker` a structured type instead of `object`.

- Split SymNode registration out of ProxyTorchDispatchMode.__sym_dispatch__() so
  it can be reused.

ghstack-source-id: 68b8dd8
Pull Request resolved: pytorch/pytorch#165266
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/inductor fx Merged release notes: fx release notes category suppress-bc-linter Suppresses the failures of API backward-compatibility linter (Lint/bc_linter)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants