Fix various bugs in subclass input in export #163770

tugsbayasgalan wants to merge 5 commits into gh/tugsbayasgalan/40/base
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163770
Note: Links to docs will display an error until the docs builds have been completed. ✅ No Failures as of commit 07b7ae8 with merge base 082eaf4. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@tugsbayasgalan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
torch/_export/non_strict_utils.py
```python
if is_traceable_wrapper_subclass(t):
    # Get symbolic context for outer tensor
    outer_context = _create_symbolic_context_for_tensor(
```
What does specifying dynamic shapes for tensor subclasses look like? Could you add a test, or add some casing here to error if we try to use dynamic shapes with tensor subclasses?
I don't think it works today, haha. Yeah, sure, I will add a test case.
Apparently it works!
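For reference, a minimal sketch of what such a test could look like (the module `M` is hypothetical; `TwoTensor` is the reference wrapper subclass from PyTorch's test utilities):

```python
import torch
from torch.export import Dim, export
from torch.testing._internal.two_tensor import TwoTensor


class M(torch.nn.Module):
    def forward(self, x):
        return x + 1


# A subclass input whose outer dim 0 is marked dynamic; per the thread,
# this "apparently works", i.e. the inner tensors pick up matching
# symbolic shapes during fakification.
x = TwoTensor(torch.randn(4, 3), torch.randn(4, 3))
ep = export(M(), (x,), dynamic_shapes={"x": {0: Dim("batch")}})
print(ep.graph)
```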
avikchaudhuri left a comment:
Lots of questions, mostly for my own education, although more comments would be nice.
```python
for i, flat_input in enumerate(flat_inputs):
    if isinstance(flat_input, FakeTensor):
        fake_modes.append((flat_input.fake_mode, "fake tensor input", i))
    if is_traceable_wrapper_subclass(flat_input):
```
Is this an elif? We expect flat_input to be a tensor subclass instance here, and a fake tensor is not one of those subclasses (because it doesn't implement flatten/unflatten?).
Yeah, FakeTensor doesn't have flatten/unflatten.
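For context, roughly what `is_traceable_wrapper_subclass` checks — a simplified sketch, not the exact implementation:

```python
import torch


def is_traceable_wrapper_subclass_sketch(t) -> bool:
    # A tensor subclass is "traceable" only if it implements both halves
    # of the flatten/unflatten protocol; FakeTensor implements neither,
    # so it takes the isinstance(flat_input, FakeTensor) branch above.
    return (
        isinstance(t, torch.Tensor)
        and type(t) is not torch.Tensor
        and hasattr(type(t), "__tensor_flatten__")
        and hasattr(type(t), "__tensor_unflatten__")
    )
```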
torch/_guards.py
```python
out: list[torch.Tensor] = []
get_plain_tensors(flat_input, out=out)  # type: ignore[arg-type]
fake_tensors: list[FakeTensor] = filter(
    lambda x: isinstance(x, FakeTensor), out
```
Why is the filter needed? I'm confused about what get_plain_tensors outputs.
In particular, looking at that function, it may return SymInts as well; shouldn't we look at fake modes for them too?
SymInts only have a shape_env, I think.
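A standalone sketch of why the filter matters, assuming `get_plain_tensors` lives in `torch.utils._python_dispatch` (the hunk above does not show the import): the flattened leaves can include values that are not `FakeTensor`s, such as SymInts, so we filter before reading `.fake_mode`.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils._python_dispatch import get_plain_tensors

with FakeTensorMode() as mode:
    # Tensors created under the mode are FakeTensors, so this builds a
    # TwoTensor whose inner tensors are fake.
    t = TwoTensor(torch.randn(2, 2), torch.randn(2, 2))

out: list = []
get_plain_tensors(t, out=out)  # recursively flattens to plain leaves
fake_tensors = [x for x in out if isinstance(x, FakeTensor)]
# Every inner FakeTensor should report the same fake mode.
assert {ft.fake_mode for ft in fake_tensors} == {mode}
```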
| """\ | ||
| graph(): | ||
| %x : [num_users=1] = placeholder[target=x] | ||
| %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, 1), kwargs = {}) |
What is this checking, that the input is not desugared? Perhaps you should add another part that runs decompositions.
We don't support subclass inputs today in post-dispatch IR for export. We should implement it, though.
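For the record, a hedged sketch of the suggested extra check, reusing `M` and `TwoTensor` from the sketch further up: per the reply, lowering to post-dispatch IR with a subclass input is expected to be unsupported today.

```python
ep = export(M(), (TwoTensor(torch.randn(3), torch.randn(3)),))
try:
    # Running decompositions would desugar the subclass into dense ops;
    # per the reply above, this is not supported yet for subclass inputs.
    ep.run_decompositions()
except Exception as e:
    print(f"post-dispatch subclass inputs unsupported today: {e}")
```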
torch/_export/non_strict_utils.py
```python
# Get symbolic contexts for inner tensors
inner_contexts = {}  # mapping from attr -> symbolic context
attrs, _ = type(t).__tensor_flatten__(t)
```
get_plain_tensors recursively calls flatten, which suggests you should too here. Or maybe create a util for it.
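A hedged sketch of the util being suggested, with hypothetical helper names: recurse through `__tensor_flatten__` so that nested subclasses (e.g. a TwoTensor of TwoTensors) get contexts for every level of inner tensor.

```python
from torch.utils._python_dispatch import is_traceable_wrapper_subclass


def create_symbolic_contexts_recursive(t, make_context):
    # make_context is a stand-in for _create_symbolic_context_for_tensor;
    # for a plain tensor we build a leaf context, for a wrapper subclass
    # we recurse into each flattened attribute.
    if not is_traceable_wrapper_subclass(t):
        return make_context(t)
    attrs, _ = type(t).__tensor_flatten__(t)
    return {
        attr: create_symbolic_contexts_recursive(getattr(t, attr), make_context)
        for attr in attrs
    }
```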
```python
    t, source, t_constraints, sources, mode
)

fake = mode.from_tensor(t, source=source, symbolic_context=symbolic_context)
```
So... what happens again when t is a subclass here? Is the tensor subclass instance wrapped in a fake, or are the flattened tensors faked? Confused.
It is gonna look like `TwoTensor(TwoTensor(FakeTensor, FakeTensor), TwoTensor(FakeTensor, FakeTensor))`.
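A small sketch that checks exactly that claimed shape, assuming `FakeTensorMode.from_tensor` preserves subclass structure as described:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.testing._internal.two_tensor import TwoTensor

nested = TwoTensor(
    TwoTensor(torch.randn(2), torch.randn(2)),
    TwoTensor(torch.randn(2), torch.randn(2)),
)

fake = FakeTensorMode().from_tensor(nested)

# Subclass structure is preserved; only the innermost leaves become fake.
assert type(fake) is TwoTensor
assert type(fake.a) is TwoTensor
assert isinstance(fake.a.a, FakeTensor)
```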
Differential Revision: [D83156489](https://our.internmc.facebook.com/intern/diff/D83156489) [ghstack-poisoned]
@tugsbayasgalan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
This adds basic support for subclass inputs in export (specifically for non-strict). I had to make fakify a little more complicated, which risks further divergence from dynamo fakification, but the dynamo one is so complex that I feel it is better to do it this way. Also improved the fake mode detection logic to recursively look into subclass inner tensors. Differential Revision: [D83156489](https://our.internmc.facebook.com/intern/diff/D83156489) [ghstack-poisoned]
@tugsbayasgalan has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@pytorchbot merge (Initiating merge automatically since Phabricator Diff has merged)
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 2 jobs have failed; the first few of them are: inductor / unit-test / inductor-cpu-test / test (inductor_avx2, 2, 2, linux.10xlarge.avx2), trunk / verify-cachebench-cpu-build / build. Details for Dev Infra team: raised by workflow job.
@pytorchmergebot merge |
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
This adds basic support for subclass inputs in export (specifically for non-strict). I had to make fakify a little more complicated, which risks further divergence from dynamo fakification, but the dynamo one is so complex that I feel it is better to do it this way. Also improved the fake mode detection logic to recursively look into subclass inner tensors. Differential Revision: [D83156489](https://our.internmc.facebook.com/intern/diff/D83156489) Pull Request resolved: #163770 Approved by: https://github.com/avikchaudhuri
Stack from ghstack (oldest at bottom):
This adds basic support for subclass inputs in export (specifically for non-strict). I had to make fakify a little more complicated, which risks further divergence from dynamo fakification, but the dynamo one is so complex that I feel it is better to do it this way. Also improved the fake mode detection logic to recursively look into subclass inner tensors.
Differential Revision: D83156489
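A hedged sketch of the improved detection described above (illustrative names, not the actual implementation): when scanning inputs for an ambient fake mode, recurse into subclass inner tensors instead of looking only at the top level.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._python_dispatch import is_traceable_wrapper_subclass


def collect_fake_modes(t, out):
    # Plain fake tensor: record its mode directly.
    if isinstance(t, FakeTensor):
        out.append(t.fake_mode)
    # Wrapper subclass: recurse into every flattened inner tensor, so a
    # TwoTensor(TwoTensor(FakeTensor, ...), ...) input is still detected.
    elif is_traceable_wrapper_subclass(t):
        attrs, _ = type(t).__tensor_flatten__(t)
        for attr in attrs:
            collect_fake_modes(getattr(t, attr), out)
```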