-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[ONNX] Support tuples in ScriptModule inputs/outputs #20784
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
31913e0 to
57b3715
Compare
72a71bb to
91aa7d4
Compare
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
Fix the lint error please |
91aa7d4 to
5b1da10
Compare
Fixed, thanks. |
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
32d045c to
baf2107
Compare
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
torch/csrc/jit/script/init.cpp
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I find it's a bit hard to understand the meaning of isDimensionedTensor's meaning. Can we get a better name for it? And add some comments why we need it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, changed to type_kind instead, marking what is the desired kind of tensor type.
houseroad
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good. Could you address my inline comments?
baf2107 to
a0b42ae
Compare
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
houseroad
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, thanks!
d9e9558 to
9f53df3
Compare
Thanks! rebased to resolve conflict. |
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@houseroad has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
The test failures are unrelated. So landing |
|
@houseroad merged this pull request in a3db284. |
Add tests after [ONNX] Fix bug in exporting node with multiple outputs by scripting #20256 is merged
Support exporting ScriptModule with inputs/outputs of arbitrarily constructed tuples.
Moved the assigning of output shapes to after graph conversion to ONNX is completed. By then all tuples in the IR has already been lowered by the pass
_jit_pass_lower_all_tuples. If assigning output shapes is required to happen before that, we'll need to hand parse the tuple structures in the graph, and repeat the same logic in_jit_pass_lower_all_tuples. Handling inputs is easier because all tuple information is encoded within the input tensor type.Swap the order of
_jit_pass_lower_all_tuplesand_jit_pass_erase_number_types. Ops likeprim::TupleIndexrelies on index being a scalar._jit_pass_erase_number_typeswill convert these kind of scalars to tensors.