-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Handle sequence lengths correctly when exporting RNNs to ONNX #4695
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@anderspapitto, thanks for your PR! We identified @zdevito to be a potential reviewer. |
houseroad
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should add some docs about OnnxElidedInput type node. We should also have a better story for multiple optional parameters.
torch/csrc/jit/export.cpp
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
e0c5aa8 to
08c9dd9
Compare
|
@anderspapitto and I discussed using Undefined instead of introducing a new ElidedOnnxNode. |
a6cb557 to
fd3989d
Compare
torch/csrc/jit/export.cpp
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/csrc/jit/export.cpp
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/modules/module.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/utils/rnn.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/utils/rnn.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
a265f8a to
c115e76
Compare
torch/nn/_functions/packing.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/_functions/rnn.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/_functions/rnn.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/utils/rnn.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/utils/rnn.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/utils/rnn.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
f4696a4 to
d57e226
Compare
|
[WIP] btw. There's one or two other components that I have to do to get this to export correctly |
e39ae15 to
aa29c97
Compare
|
@dzhulgakov yeah, need to wrap list of integers with Variable() Although, actually I expect that in many cases, people will already have the lengths stored in a Variable, and were having to explicitly convert to list of int (because semantically, lengths/batch sizes are just as much a feature of the dynamic input as the actual float values) - so actually they will just be able to remove any such code. |
|
@anderspapitto I don't think that's the case. |
|
The |
|
@apaszke I actually don't have many examples, but it is true in the one (unfortunately internal, so I can't link it) case that I have looked at. Anyway I think I didn't say that in the clearest way, so let me give it another go: Whenever a user calls |
torch/autograd/function.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
test/test_nn.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/_functions/packing.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/_functions/rnn.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/utils/rnn.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/modules/rnn.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/utils/rnn.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/utils/rnn.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/nn/utils/rnn.py
Outdated
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
841b98c to
3e2cd64
Compare
|
addressed all comments |
- when uniform sequence lengths are provided, correctly omit the argument when constructing the ONNX graph, so as to not fix the graph to the batch size. - handle PackedSequences by floating them through the graph and eliminating them in an optimization pass. ONNX does not have packed sequences, but operates on a representation equivalent to PaddedSequence, so we hide the representation-switching from ONNX - as a preliminary step towards handling PackedSequences, not directly tied to ONNX export, change batch_sizes from being an argument to the RNN operators into being an argument to the forward() function of those RNN operators. This more closely models the reality that batch_sizes are effectively part of the input sequences.
3e2cd64 to
c9c0a59
Compare
how can I know if this is spurious or something I need to address? 116.0 vs 115.5 seems small but maybe it isn't? |
This was accidentally lost while addressing review comments on pytorch#4695 pack_padded_sequence may be called either with a list or with a Variable. If called with a list we convert to Variable internally. I added to test_nn to test the new codepath. The bug was also caught by the onnx-fb-universe tests (which rely on passing in Variable).
This was accidentally lost while addressing review comments on #4695 pack_padded_sequence may be called either with a list or with a Variable. If called with a list we convert to Variable internally. I added to test_nn to test the new codepath. The bug was also caught by the onnx-fb-universe tests (which rely on passing in Variable).
PackedSequence: store batch_sizes as tensor rather
than converting to a list of python integers. This maintains
the invariant that module's inputs/outputs are collections of
Variables.
In particular, this causes the JIT to no longer choke when flattening
and unflattening arguments.
when uniform sequence lengths are provided, correctly omit the
argument when constructing the ONNX graph, so as to not fix the
graph to the batch size.
handle PackedSequences by floating them through the graph and
eliminating them in an optimization pass. ONNX does not have packed
sequences, but operates on a representation equivalent to
PaddedSequence, so we hide the representation-switching from ONNX
as a preliminary step towards handling PackedSequences, not directly
tied to ONNX export, change batch_sizes from being an argument to
the RNN operators into being an argument to the forward() function
of those RNN operators. This more closely models the reality that
batch_sizes are effectively part of the input sequences.