
Add view support for library custom Function#164520

Closed
albanD wants to merge 9 commits into gh/albanD/2/base from gh/albanD/2/head

Conversation


@albanD albanD commented Oct 2, 2025

[ghstack-poisoned]
@albanD albanD requested a review from soulitzer as a code owner October 2, 2025 22:42

pytorch-bot bot commented Oct 2, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/164520

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 77ab566 with merge base 87c9fbd:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

[ghstack-poisoned]
albanD added 2 commits October 3, 2025 11:27
[ghstack-poisoned]
[ghstack-poisoned]
self.assertIsInstance(args[0], MyTwoTensor)
res = args[0].a
# Always return a fresh Tensor!
return res.view_as(res)
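As a language-level analogy (not PyTorch-specific), the "always return a fresh Tensor" rule in the snippet above is about identity versus storage: each call should return a distinct object that still aliases the same underlying data, much like `res.view_as(res)` does. A minimal sketch using Python's `memoryview`:

```python
# Hypothetical analogy: a "fresh view" is a brand-new object aliasing the same
# storage, much like res.view_as(res) returns a new Tensor sharing res's data.
buf = bytearray(b"hello")

def fresh_view(b):
    # Each call returns a distinct memoryview object over the same buffer.
    return memoryview(b)

v1 = fresh_view(buf)
v2 = fresh_view(buf)
assert v1 is not v2           # distinct objects...
v1[0] = ord("H")
assert bytes(v2) == b"Hello"  # ...that alias the same underlying storage
```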
@soulitzer soulitzer (Contributor) Oct 3, 2025

No action necessary, but now that custom ops have official support for views, I wonder if there's things we can do to make the inference mode handling smoother.

With this PR today, I guess wrapper subclasses still need special handling for inference mode for views (ordinarily inference-ness is preserved via this line, but it can be broken if the outer tensor must be rewrapped in a new subclass)

albanD (Collaborator, Author)

Do you have a code sample that shows this concern?

@soulitzer soulitzer (Contributor) Oct 3, 2025

Because LoggingTensor doesn't do this special inference-mode handling for views, doing

import torch
from torch.testing._internal.logging_tensor import LoggingTensor

base = LoggingTensor(torch.rand(1, 2, 3)).requires_grad_(True)

with torch.inference_mode(True):
    base.view(-1)

Fails with

Traceback (most recent call last):
  File "/Users/jw3468/local/.tests/tst1.py", line 9412, in <module>
    base.view(-1)
RuntimeError: Cannot set version_counter for inference tensor

Example of how it is handled in nested tensor:

def view_default(func, *args, **kwargs):
    ...
    with torch.inference_mode(inp.is_inference()):
        return NestedTensor(func(inp._values, inner_size), **extract_kwargs(inp))

albanD (Collaborator, Author)

This is unrelated to the change at hand, so I will move that to a follow-up.

albanD added 3 commits October 3, 2025 14:33
[ghstack-poisoned]
[ghstack-poisoned]
[ghstack-poisoned]
keyset & _C._after_ADInplaceOrView_keyset, *args, **kwargs
)
# Handle view + mutation that are in the schema
return original_kernel.call_boxed(keyset, *args, **kwargs)
@soulitzer soulitzer (Contributor) Oct 8, 2025

If we're calling into the fallback now for the is_mutable case as well, being able to remove the whole branch would be a good simplification!
Is there something preventing us from doing that?

albanD (Collaborator, Author)

No we cannot :/
The in-place logic is based on user annotations, and the schema is not actually made to reflect these in-place ops. So we do need this custom Python logic to do the increment_version, as the fallback won't do it.

We could refactor the whole thing to have an accurate schema and rely on the fallback to do all the increments, but that is much more involved.
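A minimal sketch (hypothetical toy wrapper, not the real autograd machinery) of why the explicit bump matters: views share their base's version counter, and if the schema does not mark an op as mutating, nothing bumps that counter unless the custom logic calls increment_version itself.

```python
# Hypothetical sketch: a version counter shared between a tensor-like object
# and its views, bumped manually for in-place ops the schema doesn't declare.
class VersionCounter:
    def __init__(self):
        self.value = 0

def increment_version(t):
    t._vc.value += 1

class FakeTensor:
    def __init__(self, data, vc=None):
        self.data = data
        self._vc = vc if vc is not None else VersionCounter()

    def view(self):
        # Views share the base's version counter.
        return FakeTensor(self.data, self._vc)

    def add_(self, x):
        # Pretend the schema does NOT flag this op as mutating: a generic
        # fallback would not bump the counter, so we must do it ourselves.
        self.data = [d + x for d in self.data]
        increment_version(self)
        return self

t = FakeTensor([1, 2])
v = t.view()
assert v._vc.value == 0
t.add_(1)
assert v._vc.value == 1  # the view observes the bump, so staleness is detectable
```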

@soulitzer soulitzer (Contributor) left a comment

Cool! Added some minor comments

[ghstack-poisoned]
&autogradNotImplementedFallbackImpl>();
}

struct GenericViewFunc : public ViewFunc {
Contributor

So is this literally just "remember all non-Tensor arguments, and then rerun the function if you need to reapply the view"?

albanD (Collaborator, Author)

Yes.
With extra complexity if there are other Tensor/SymInt arguments, which I bypass.
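The idea described above can be sketched in plain Python (hypothetical names, not the C++ `ViewFunc` API): capture the non-Tensor arguments at view-creation time, then rerun the function against a new base whenever the view must be reapplied.

```python
# Hypothetical Python sketch of the GenericViewFunc idea: remember the
# non-tensor arguments, then rerun the function on a new base to reapply
# the same view.
class GenericViewFunc:
    def __init__(self, fn, *non_tensor_args, **non_tensor_kwargs):
        self.fn = fn
        self.args = non_tensor_args       # captured at view-creation time
        self.kwargs = non_tensor_kwargs

    def __call__(self, new_base):
        # Reapply the recorded view on a different base.
        return self.fn(new_base, *self.args, **self.kwargs)

def narrow(seq, start, length):
    # Stand-in for a view op parameterized by non-tensor args (start, length).
    return seq[start:start + length]

vf = GenericViewFunc(narrow, 1, 2)
assert vf([10, 20, 30, 40]) == [20, 30]
assert vf("abcd") == "bc"  # same captured args, replayed on a new base
```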

[ghstack-poisoned]

albanD commented Oct 9, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 9, 2025
@pytorchmergebot

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.


Labels

ciflow/trunk (Trigger trunk jobs on your pull request), Merged, release notes: python_frontend, topic: improvements
