Update tls logic to work better with guarded call #73925
Conversation
Dr. CI: As of commit 3477503, 1 new CI failure was recognized by patterns; it does not appear to be due to an upstream breakage.
soulitzer left a comment:
Pretty cool! Just one small comment on naming:
I think some of the meaning of the terms shifted during this PR, which confused me for a bit. Previously, TLSState referred to the current local dispatch key set. In this PR, TLSState seems to refer to tls_on_entry. Stash also means something slightly different now.
Also, TLSState = Thread Local State State? :P
```diff
  self.assertEqual(counter[0], 1)
  fwAD.make_dual(torch.rand_like(s), s)
- self.assertEqual(counter[0], 2)
+ self.assertEqual(counter[0], 1)
```
I think I realize what is happening here now; my test last time was completely oblivious to what torch function does.
Before this PR, `__torch_function__` was wrapping the output of `torch.rand_like` in the subclass. Now that we disable torch function for this subclass, that no longer happens, and the counter is only triggered once, as expected.
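To make that mechanism concrete, here is a minimal standalone sketch (the subclass names are invented for illustration): with the default `__torch_function__`, functions like `torch.rand_like` re-wrap their output in the subclass, while the disabled impl leaves outputs as plain Tensors.

```python
import torch

class Wrapping(torch.Tensor):
    # Inherits the default __torch_function__, which re-wraps the
    # outputs of torch functions in the subclass.
    pass

class NonWrapping(torch.Tensor):
    # Opts out of __torch_function__: outputs stay plain Tensors.
    __torch_function__ = torch._C._disabled_torch_function_impl

w = torch.Tensor._make_subclass(Wrapping, torch.rand(2))
n = torch.Tensor._make_subclass(NonWrapping, torch.rand(2))

print(type(torch.rand_like(w)).__name__)  # Wrapping
print(type(torch.rand_like(n)).__name__)  # Tensor
```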
Yep
```python
def __new__(cls, data=None):
    return torch.Tensor._make_subclass(cls, data)

__torch_function__ = torch._C._disabled_torch_function_impl
```
(although I think we can get rid of this after #73942)
Yes, depending on which one lands first.
FYI @ezyang as well, since you're thinking about how to replace this system altogether.

This is ready for a final review.
soulitzer left a comment:
LGTM
@albanD has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@albanD @soulitzer this PR broke functorch: https://app.circleci.com/pipelines/github/pytorch/functorch/2032/workflows/2a0f2124-3870-4c6e-b95b-00951c1e9332/jobs/12098 (I confirmed by bisecting to this PR that it is the one causing the failure). I don't know the exact reason, but I suspect it has to do with the following:

This doesn't seem easy to resolve. Could you help take a look please? Minimal repro:
Hmmm, not sure why this PR would change this behavior, as this was happening before?
It might not look this way because of our lack of test coverage, but this PR specifically seems to break the `__torch_dispatch__` interaction with functorch. The main consequence right now is that we cannot compose AOTAutograd with functorch eager-mode transforms (e.g. vmap, grad), so this is a really large breakage for us. It would be a huge favor to us if we could figure out how to resolve this soon.
Here's my current understanding of what is happening. The purpose of this PR is to save the TLS when going from Python into the dispatcher so that, when dispatch gets back into Python (e.g. in `__torch_dispatch__`), the TLS can be restored to what it was on entry. In the example, we are:
So I think the question now is: how should the TLS logic in this PR interact with mode-style dispatch keys?
Imagine some Python code that looks like the sketch below. This would end up running through HypotheticalMode twice. Would that be the desired behavior?
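The original snippet is not preserved in this transcript; below is a guess at its shape, written against the current `TorchDispatchMode` helper (the mode API available at the time of this PR differed), with `HypotheticalMode` taken from the surrounding discussion.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class HypotheticalMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        print(f"HypotheticalMode saw {func}")
        # The question in this thread: with the TLS restored on re-entry,
        # does this inner call dispatch through HypotheticalMode again,
        # i.e. does each op run through the mode twice?
        return func(*args, **(kwargs or {}))

with HypotheticalMode():
    torch.add(torch.ones(2), torch.ones(2))
```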
Yes. The high level idea of this TLS logic is the following: if you call into the dispatcher from Python, then by the time dispatch gets back to Python (via `__torch_dispatch__`), the TLS should match what it was when you made the call. So in this case, that looks like the expected behavior to me? (Not that we want to keep it necessarily, but at least it does follow the high level idea.)
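To spell that rule out, here is a toy Python model of it (the real implementation is C++, in PythonFallbackKernel.cpp; every name below is illustrative, not PyTorch API):

```python
import contextlib

# Pretend thread-local dispatch state; PyTorch's real one lives in C++.
tls = {"include": set(), "exclude": set()}

@contextlib.contextmanager
def with_entry_tls(tls_on_entry):
    """Restore the caller's TLS while control is back in Python,
    then put the dispatcher's TLS back afterwards."""
    saved = {k: set(v) for k, v in tls.items()}
    tls.update({k: set(v) for k, v in tls_on_entry.items()})
    try:
        yield
    finally:
        tls.update(saved)

tls_on_entry = {k: set(v) for k, v in tls.items()}  # stashed at dispatch entry
tls["exclude"].add("Python")  # dispatcher excludes the Python key mid-call
with with_entry_tls(tls_on_entry):
    # __torch_dispatch__ would run here, seeing the caller's TLS.
    assert "Python" not in tls["exclude"]
assert "Python" in tls["exclude"]  # dispatcher state restored
```

Under this model, re-entered Python code sees the same TLS as the original call site, which is why the mode key is still active and HypotheticalMode runs a second time above.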
Description of the new behavior is in `PythonFallbackKernel.cpp`.
The updated test makes sure that we only call `alias` on the first Tensor.