You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Handle sequence lengths correctly when exporting RNNs to ONNX (#4695)
* PackedSequence: store batch_sizes as tensor
rather than converting to a list of python integers. This maintains
the invariant that module's inputs/outputs are collections of
Variables.
In particular, this causes the JIT to no longer choke when flattening
and unflattening arguments.
* Handle sequence lengths correctly when exporting RNNs to ONNX
- when uniform sequence lengths are provided, correctly omit the
argument when constructing the ONNX graph, so as to not fix the
graph to the batch size.
- handle PackedSequences by floating them through the graph and
eliminating them in an optimization pass. ONNX does not have packed
sequences, but operates on a representation equivalent to
PaddedSequence, so we hide the representation-switching from ONNX
- as a preliminary step towards handling PackedSequences, not directly
tied to ONNX export, change batch_sizes from being an argument to
the RNN operators into being an argument to the forward() function
of those RNN operators. This more closely models the reality that
batch_sizes are effectively part of the input sequences.
0 commit comments