Add view support for library custom Function #164520
albanD wants to merge 9 commits into gh/albanD/2/base
Conversation
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164520
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 77ab566 with merge base 87c9fbd. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
```python
self.assertIsInstance(args[0], MyTwoTensor)
res = args[0].a
# Always return a fresh Tensor!
return res.view_as(res)
```
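As an aside on the "fresh Tensor" comment above: `view_as(res)` on the same tensor still produces a new Tensor object that aliases the same storage, which is what makes the returned result "fresh" without copying data. A minimal illustration (assumes a recent PyTorch):

```python
import torch

res = torch.rand(2, 3)
fresh = res.view_as(res)  # same shape, but a distinct Tensor object

# The view aliases the same storage rather than copying data.
assert fresh is not res
assert fresh.data_ptr() == res.data_ptr()
```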
No action necessary, but now that custom ops have official support for views, I wonder if there are things we can do to make the inference-mode handling smoother.
With this PR today, I guess wrapper subclasses still need special handling for inference mode for views (ordinarily inference-ness would be preserved via this line, but it can be broken if the outer tensor must be rewrapped in a new subclass).
Do you have a code sample that shows this concern?
Because LoggingTensor doesn't do this special inference-mode handling for views, doing

```python
import torch
from torch.testing._internal.logging_tensor import LoggingTensor

base = LoggingTensor(torch.rand(1, 2, 3)).requires_grad_(True)
with torch.inference_mode(True):
    base.view(-1)
```

fails with
```
Traceback (most recent call last):
  File "/Users/jw3468/local/.tests/tst1.py", line 9412, in <module>
    base.view(-1)
RuntimeError: Cannot set version_counter for inference tensor
```
Example of how it is handled in nested tensor:

```python
def view_default(func, *args, **kwargs):
    ...
    with torch.inference_mode(inp.is_inference()):
        return NestedTensor(func(inp._values, inner_size), **extract_kwargs(inp))
```
This is unrelated to the change at hand, so I will move it to a follow-up.
```python
        keyset & _C._after_ADInplaceOrView_keyset, *args, **kwargs
    )
# Handle view + mutation that are in the schema
return original_kernel.call_boxed(keyset, *args, **kwargs)
```
If we're calling into the fallback now for the is_mutable case as well, removing the whole branch would be a nice simplification!
Is there something preventing us from doing that?
No, we cannot :/
The in-place logic is based on user annotations, and the schema is not actually made to reflect these in-place ops. So we do need this custom Python logic to do the increment_version, as the fallback won't do it for us.
We could refactor the whole thing to have an accurate schema and rely on the fallback to do all the version increments, but that is much more involved.
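For context on what the version increments protect: every in-place op bumps a tensor's version counter, which autograd uses to detect that a tensor saved for backward was mutated. A minimal sketch of the counter's behavior (reading the private `_version` attribute purely for illustration; assumes a recent PyTorch):

```python
import torch

t = torch.rand(3)
v0 = t._version   # version counter of a fresh tensor
t.add_(1)         # an in-place op bumps the version counter
assert t._version == v0 + 1
```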
soulitzer left a comment:
Cool! Added some minor comments.
```cpp
      &autogradNotImplementedFallbackImpl>();
}

struct GenericViewFunc : public ViewFunc {
```
So is this literally just "remember all non-Tensor arguments, and then rerun the function if you need to reapply the view"?
Yes, with the extra complexity that there may be other Tensor/SymInt arguments, which I bypass.
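To make the idea concrete, here is a hypothetical pure-Python sketch of that "remember the non-Tensor arguments and replay" behavior; the class and names below are illustrative only and stand in for the C++ `GenericViewFunc` discussed above:

```python
class ReplayableView:
    """Illustrative sketch: capture a view function plus its non-Tensor
    arguments so the same view can be reapplied to a different base."""

    def __init__(self, func, *non_tensor_args):
        self.func = func
        self.non_tensor_args = non_tensor_args

    def __call__(self, new_base):
        # Rerun the original view operation on the new base.
        return self.func(new_base, *self.non_tensor_args)


# Usage, with plain lists standing in for tensors:
slice_view = ReplayableView(lambda base, start, stop: base[start:stop], 1, 3)
assert slice_view([10, 20, 30, 40]) == [20, 30]
assert slice_view([5, 6, 7, 8, 9]) == [6, 7]
```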
@pytorchbot merge

Merge started. Your change will be merged once all checks pass (ETA 0-4 hours).
Pull Request resolved: pytorch#164520
Approved by: https://github.com/soulitzer, https://github.com/ezyang
Stack from ghstack (oldest at bottom):