Skip to content

[dynamo] Rehaul the autograd.Function support#166788

Closed
anijain2305 wants to merge 51 commits intogh/anijain2305/940/basefrom
gh/anijain2305/940/head
Closed

[dynamo] Rehaul the autograd.Function support#166788
anijain2305 wants to merge 51 commits intogh/anijain2305/940/basefrom
gh/anijain2305/940/head

Conversation

@anijain2305
Copy link
Contributor

@anijain2305 anijain2305 commented Nov 1, 2025

Stack from ghstack (oldest at bottom):

We make a rehaul because
(1) we want to support non-proxyable outputs in the fwd method
(2) we saw general softness in the support.

I have put lot of comments in the code.

Follow up

  • Graph break on backward stride dependent computation. This is BC breaking, so needs care.
  • Use DynamoAutogradFunctionTraceHelper for backward tracer.
  • Add test cases for module input, pytree input/outputs, pg groups
  • Consider unifying automatic and automatic_with_forced_placeholders
  • Better error messages - especially nonstrict trace for stride dependent backward computation.

cc @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben @jataylo @Lucaskabela @chenyang78

@pytorch-bot
Copy link

pytorch-bot bot commented Nov 1, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/166788

Note: Links to docs will display an error until the docs builds have been completed.

❌ 6 New Failures, 11 Unrelated Failures

As of commit cad48c9 with merge base 033659b (image):

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

anijain2305 added a commit that referenced this pull request Nov 1, 2025
ghstack-source-id: 073736f
Pull Request resolved: #166788
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Nov 1, 2025
ghstack-source-id: 2b2c4d2
Pull Request resolved: #166788
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Nov 2, 2025
ghstack-source-id: 26579df
Pull Request resolved: #166788
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Nov 2, 2025
ghstack-source-id: 2003a64
Pull Request resolved: #166788
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Nov 2, 2025
ghstack-source-id: 75c7530
Pull Request resolved: #166788
@anijain2305 anijain2305 added the keep-going Don't stop on first failure, keep running tests until the end label Nov 2, 2025
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Nov 2, 2025
ghstack-source-id: 89c93ea
Pull Request resolved: #166788
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Nov 3, 2025
ghstack-source-id: 3976ea3
Pull Request resolved: #166788
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Nov 3, 2025
ghstack-source-id: 1e6f854
Pull Request resolved: #166788
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx chenyang78 kadeng chauhang amjames Lucaskabela

[ghstack-poisoned]
anijain2305 added a commit that referenced this pull request Nov 3, 2025
ghstack-source-id: 669b27b
Pull Request resolved: #166788
@anijain2305 anijain2305 added the topic: not user facing topic category label Nov 3, 2025
[ghstack-poisoned]
umechand-amd pushed a commit to ROCm/pytorch that referenced this pull request Dec 8, 2025
We make a rehaul because
(1) we want to support non-proxyable outputs in the fwd method
(2) we saw general softness in the support.

I have put lot of comments in the code.

Follow up
* Graph break on backward stride dependent computation. This is BC breaking, so needs care.
* Use DynamoAutogradFunctionTraceHelper for backward tracer.
* Add test cases for module input, pytree input/outputs, pg groups
* Consider unifying `automatic` and `automatic_with_forced_placeholders`
* Better error messages - especially nonstrict trace for stride dependent backward computation.

Pull Request resolved: pytorch#166788
Approved by: https://github.com/zou3519
umechand-amd pushed a commit to ROCm/pytorch that referenced this pull request Dec 8, 2025
…d graph tracing (pytorch#169399)"

This reverts commit f874542.

Reverted pytorch#169399 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I need to revert this to revert pytorch#166788 ([comment](pytorch#169399 (comment)))
umechand-amd pushed a commit to ROCm/pytorch that referenced this pull request Dec 8, 2025
This reverts commit a84798e.

Reverted pytorch#166788 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to cause some numerical error in trunk ([comment](pytorch#166788 (comment)))
JacobSzwejbka pushed a commit that referenced this pull request Dec 8, 2025
We make a rehaul because
(1) we want to support non-proxyable outputs in the fwd method
(2) we saw general softness in the support.

I have put lot of comments in the code.

Follow up
* Graph break on backward stride dependent computation. This is BC breaking, so needs care.
* Use DynamoAutogradFunctionTraceHelper for backward tracer.
* Add test cases for module input, pytree input/outputs, pg groups
* Consider unifying `automatic` and `automatic_with_forced_placeholders`
* Better error messages - especially nonstrict trace for stride dependent backward computation.

Pull Request resolved: #166788
Approved by: https://github.com/zou3519
JacobSzwejbka pushed a commit that referenced this pull request Dec 8, 2025
…d graph tracing (#169399)"

This reverts commit f874542.

Reverted #169399 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I need to revert this to revert #166788 ([comment](#169399 (comment)))
JacobSzwejbka pushed a commit that referenced this pull request Dec 8, 2025
This reverts commit a84798e.

Reverted #166788 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to cause some numerical error in trunk ([comment](#166788 (comment)))
tiendatngcs pushed a commit to tiendatngcs/pytorch-Dec25 that referenced this pull request Dec 10, 2025
tiendatngcs pushed a commit to tiendatngcs/pytorch-Dec25 that referenced this pull request Dec 10, 2025
[ghstack-poisoned]
[ghstack-poisoned]
@anijain2305
Copy link
Contributor Author

anijain2305 commented Dec 11, 2025

@zou3519 I am debugging why test/inductor/test_inplacing_pass.py::TestReinplacingPassCorrectness::test_partitioner_recomputes_factory_ones_like_sin_op is failing for my PR

Observations

On main, this is the tlparse. With this PR, this is the tlparse.

  • On main, there is a graph break. That cause the only the forward graph of the autograd.Function to be compiled. There is no backward compilation. I think we just run the backward method as is. I would argue that the test might not even be completely valid here because partitioner is not even taking effect.
  • With my PR, I fixed a source bug which causes the whole graph to be captured. The test deviates from main branch right here. So the problem exists on main as well, it was just hidden. Moving on. Full graph is nice, but now we have partitioner coming into picture. Here ones_like is already decomposed into full, but the backward graph does not have full after the partitioning. Just for a comparison, here is the tlparse for the empty_like test which passes.

So my recommendation is to skip the test for now for this PR, because this is an existing issue. And handle that as a separate issue, opened here - #170160

[ghstack-poisoned]
@zou3519
Copy link
Contributor

zou3519 commented Dec 11, 2025

@anijain2305 can you add a test case that defends against the source bug? I've seen this issue actually but wasn't able to repro it. It can just be test_partitioner_recomputes_factory_ones_like_sin_op with a fullgraph=True and cleaned up.

I'm curious about what the tlparse looks like for the original PR - like, was the autograd.Function captured in whole? If it was captured in whole, did it run into the problem? Was the source bug introduced later?

I'm not completely sure I'd classify this as "a preexisting problem" without that analysis. This PR does increase the memory usage of that autograd.Function because inductor fails to reinplace the custom op in the backward, it trades-off the graph break for additional memory.

That being said, it's pretty rare for someone to use ones_like to allocate a tensor for a kernel, so I'm fine if you want to merge this. It's also a good time to yolo merge this (it's after the branch cut)

[ghstack-poisoned]
@pytorchmergebot
Copy link
Collaborator

Starting merge as part of PR stack under #169399

1 similar comment
@pytorchmergebot
Copy link
Collaborator

Starting merge as part of PR stack under #169399

pytorchmergebot pushed a commit that referenced this pull request Dec 12, 2025
vishalgoyal316 pushed a commit to vishalgoyal316/pytorch that referenced this pull request Dec 17, 2025
We make a rehaul because
(1) we want to support non-proxyable outputs in the fwd method
(2) we saw general softness in the support.

I have put lot of comments in the code.

Follow up
* Graph break on backward stride dependent computation. This is BC breaking, so needs care.
* Use DynamoAutogradFunctionTraceHelper for backward tracer.
* Add test cases for module input, pytree input/outputs, pg groups
* Consider unifying `automatic` and `automatic_with_forced_placeholders`
* Better error messages - especially nonstrict trace for stride dependent backward computation.

Pull Request resolved: pytorch#166788
Approved by: https://github.com/zou3519
vishalgoyal316 pushed a commit to vishalgoyal316/pytorch that referenced this pull request Dec 17, 2025
krastogi-in pushed a commit to krastogi-in/pytorch that referenced this pull request Jan 9, 2026
We make a rehaul because
(1) we want to support non-proxyable outputs in the fwd method
(2) we saw general softness in the support.

I have put lot of comments in the code.

Follow up
* Graph break on backward stride dependent computation. This is BC breaking, so needs care.
* Use DynamoAutogradFunctionTraceHelper for backward tracer.
* Add test cases for module input, pytree input/outputs, pg groups
* Consider unifying `automatic` and `automatic_with_forced_placeholders`
* Better error messages - especially nonstrict trace for stride dependent backward computation.

Pull Request resolved: pytorch#166788
Approved by: https://github.com/zou3519
krastogi-in pushed a commit to krastogi-in/pytorch that referenced this pull request Jan 9, 2026
@github-actions github-actions bot deleted the gh/anijain2305/940/head branch January 12, 2026 02:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci-no-td Do not run TD on this PR ciflow/inductor ciflow/trunk Trigger trunk jobs on your pull request keep-going Don't stop on first failure, keep running tests until the end Merged module: dynamo module: inductor Reverted topic: not user facing topic category

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants