
Fix various bugs in subclass input in export#163770

Closed
tugsbayasgalan wants to merge 5 commits into gh/tugsbayasgalan/40/base from
gh/tugsbayasgalan/40/head

Conversation

@tugsbayasgalan
Contributor

@tugsbayasgalan tugsbayasgalan commented Sep 24, 2025

Stack from ghstack (oldest at bottom):

This adds basic support for subclass inputs in export (specifically for non-strict). I had to make fakify a little more complicated, which risks further divergence from Dynamo fakification. But the Dynamo version is so complex that I feel it is better to do it this way. I also improved the fake-mode detection logic to recursively look into subclass inner tensors.

Differential Revision: D83156489
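The recursive fake-mode detection described above can be sketched with toy stand-ins (`FakeTensor`, `TwoTensor`, and `is_traceable_wrapper_subclass` here are hypothetical mocks of the real PyTorch types; only the recursion pattern is meant to match the PR):

```python
class FakeTensor:
    """Mock of torch._subclasses.FakeTensor: carries a fake_mode."""
    def __init__(self, fake_mode):
        self.fake_mode = fake_mode

class TwoTensor:
    """Mock traceable wrapper subclass holding two inner tensors."""
    def __init__(self, a, b):
        self.a, self.b = a, b
    def __tensor_flatten__(self):
        return ["a", "b"], None

def is_traceable_wrapper_subclass(t):
    # Mock check: the real one inspects the subclass protocol.
    return hasattr(t, "__tensor_flatten__")

def collect_fake_modes(flat_inputs):
    """Recursively walk inputs, descending into subclass inner tensors."""
    fake_modes = []

    def visit(t, desc, idx):
        if isinstance(t, FakeTensor):
            fake_modes.append((t.fake_mode, desc, idx))
        elif is_traceable_wrapper_subclass(t):
            attrs, _ = t.__tensor_flatten__()
            for attr in attrs:
                visit(getattr(t, attr), f"subclass inner tensor {attr}", idx)

    for i, t in enumerate(flat_inputs):
        visit(t, "fake tensor input", i)
    return fake_modes
```

With this shape, a fake tensor nested arbitrarily deep inside wrapper subclasses is still found, which is the behavior the PR description claims for the improved detection logic.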

@pytorch-bot

pytorch-bot bot commented Sep 24, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163770

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 07b7ae8 with merge base 082eaf4:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

tugsbayasgalan added a commit that referenced this pull request Sep 24, 2025
ghstack-source-id: 5f14d99
Pull Request resolved: #163770
@tugsbayasgalan
Contributor Author

@tugsbayasgalan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 24, 2025
@tugsbayasgalan tugsbayasgalan requested a review from suo September 24, 2025 15:39

if is_traceable_wrapper_subclass(t):
# Get symbolic context for outer tensor
outer_context = _create_symbolic_context_for_tensor(
Contributor

what does specifying dynamic shapes for tensor subclasses look like? could you add a test or add some casing here to error if we try to have dynamic shapes with tensor subclasses?

Contributor Author

I don't think it works today, haha. Yeah, sure, I will add a test case.

Contributor Author

apparently it works!

Contributor

@avikchaudhuri avikchaudhuri left a comment

Lots of questions, mostly for my own education, although more comments would be nice.

for i, flat_input in enumerate(flat_inputs):
if isinstance(flat_input, FakeTensor):
fake_modes.append((flat_input.fake_mode, "fake tensor input", i))
if is_traceable_wrapper_subclass(flat_input):
Contributor

Is this an elif? We expect flat_input to be a tensor subclass instance here, and a fake tensor is not one of those subclasses (because it doesn't implement flatten/unflatten?).

Contributor Author

Yeah, FakeTensor doesn't have flatten/unflatten.

torch/_guards.py Outdated
out: list[torch.Tensor] = []
get_plain_tensors(flat_input, out=out) # type: ignore[arg-type]
fake_tensors: list[FakeTensor] = list(filter(
    lambda x: isinstance(x, FakeTensor), out
))
Contributor

Why is the filter needed? Confused about what get_plain_tensors outputs.

Contributor

In particular, looking at that function we may return symint as well, shouldn't we look at fake modes for them too?

Contributor Author

SymInts only have a shape_env, I think.
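A toy illustration of why the filter is needed, per the exchange above (all class names here are stand-ins for the real PyTorch types): the flattened output can contain SymInt leaves alongside plain tensors, and a SymInt carries a shape_env rather than a fake_mode, so only the FakeTensor leaves are useful for fake-mode detection.

```python
class FakeTensor:
    """Mock fake tensor: carries a fake_mode."""
    def __init__(self, fake_mode):
        self.fake_mode = fake_mode

class SymInt:
    """Mock symbolic int: carries a shape_env, not a fake_mode."""
    def __init__(self, shape_env):
        self.shape_env = shape_env

def filter_fake_tensors(out):
    """Keep only the FakeTensor leaves from a flattened output list."""
    return [x for x in out if isinstance(x, FakeTensor)]
```

Usage: given a mixed flattened list, the SymInt entries drop out and only tensors whose fake_mode can be inspected remain.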

"""\
graph():
%x : [num_users=1] = placeholder[target=x]
%add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, 1), kwargs = {})
Contributor

what is this checking, that the input is not desugared? perhaps you should add another part that runs decompositions.

Contributor Author

We don't support subclass inputs today in post-dispatch IR for export. We should implement it, though.


# Get symbolic contexts for inner tensors
inner_contexts = {} # mapping from attr -> symbolic context
attrs, _ = type(t).__tensor_flatten__(t)
Contributor

get_plain_tensors recursively calls flatten, which suggests you should too here. Or maybe create some util for it
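The recursive util the reviewer suggests could look roughly like this sketch (a hypothetical helper, not code from the PR; `TwoTensor` and the flatten protocol are mocked, and plain strings stand in for inner tensors):

```python
def is_traceable_wrapper_subclass(t):
    # Mock check for the wrapper-subclass protocol.
    return hasattr(t, "__tensor_flatten__")

def iter_inner_tensors(t, prefix=""):
    """Yield (attr_path, leaf) for every plain inner tensor, recursing
    through nested wrapper subclasses instead of stopping at one level."""
    attrs, _ = t.__tensor_flatten__()
    for attr in attrs:
        inner = getattr(t, attr)
        path = f"{prefix}.{attr}" if prefix else attr
        if is_traceable_wrapper_subclass(inner):
            yield from iter_inner_tensors(inner, path)
        else:
            yield path, inner

class TwoTensor:
    """Mock wrapper subclass: a pair of inner tensors."""
    def __init__(self, a, b):
        self.a, self.b = a, b
    def __tensor_flatten__(self):
        return ["a", "b"], None
```

The dotted attr path gives each inner leaf a stable key, which is the kind of handle a per-inner-tensor symbolic-context mapping would need.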

t, source, t_constraints, sources, mode
)

fake = mode.from_tensor(t, source=source, symbolic_context=symbolic_context)
Contributor

So...what happens again when t is a subclass here? is the tensor subclass instance wrapped in a fake or the flattened tensors faked? Confused.

Contributor Author

It is going to look like TwoTensor(TwoTensor(FakeTensor, FakeTensor), TwoTensor(FakeTensor, FakeTensor)).
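A toy version of the shape described in the reply above: faking a wrapper subclass rebuilds the wrapper structure and fakes only the plain leaf tensors. All classes here are mocks of the real PyTorch types (the real `__tensor_unflatten__` also takes outer size/stride arguments, omitted here).

```python
class FakeTensor:
    """Stand-in for a faked plain tensor."""

class PlainTensor:
    """Stand-in for a real plain tensor."""

class TwoTensor:
    """Mock wrapper subclass: a pair of inner tensors."""
    def __init__(self, a, b):
        self.a, self.b = a, b
    def __tensor_flatten__(self):
        return ["a", "b"], None
    @staticmethod
    def __tensor_unflatten__(inner, ctx):
        return TwoTensor(inner["a"], inner["b"])

def fakify(t):
    """Rebuild wrapper structure recursively; fake only the plain leaves."""
    if hasattr(t, "__tensor_flatten__"):
        attrs, ctx = t.__tensor_flatten__()
        inner = {a: fakify(getattr(t, a)) for a in attrs}
        return type(t).__tensor_unflatten__(inner, ctx)
    return FakeTensor()
```

So a depth-2 TwoTensor of plain tensors comes back as a TwoTensor of TwoTensors of FakeTensors, matching the structure in the comment.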

tugsbayasgalan added a commit that referenced this pull request Sep 26, 2025
ghstack-source-id: 75d003f
Pull Request resolved: #163770
@tugsbayasgalan
Contributor Author

@tugsbayasgalan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

tugsbayasgalan added a commit that referenced this pull request Sep 26, 2025
ghstack-source-id: 960acec
Pull Request resolved: #163770
@tugsbayasgalan
Contributor Author

@tugsbayasgalan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

tugsbayasgalan added a commit that referenced this pull request Sep 26, 2025
ghstack-source-id: 3e51da4
Pull Request resolved: #163770
tugsbayasgalan added a commit that referenced this pull request Sep 26, 2025
ghstack-source-id: abe0ae2
Pull Request resolved: #163770
@tugsbayasgalan
Contributor Author

@tugsbayasgalan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@pytorchbot merge

(Initiating merge automatically since Phabricator Diff has merged)

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

@pytorchmergebot
Collaborator

Merge failed

Reason: 2 jobs have failed, first few of them are: inductor / unit-test / inductor-cpu-test / test (inductor_avx2, 2, 2, linux.10xlarge.avx2), trunk / verify-cachebench-cpu-build / build

Details for Dev Infra team Raised by workflow job

@Camyll
Contributor

Camyll commented Sep 28, 2025

@pytorchmergebot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status.

jainapurva pushed a commit that referenced this pull request Sep 29, 2025
Pull Request resolved: #163770
Approved by: https://github.com/avikchaudhuri
maggiemoss pushed a commit to maggiemoss/pytorch that referenced this pull request Sep 29, 2025
@github-actions github-actions bot deleted the gh/tugsbayasgalan/40/head branch October 29, 2025 02:17