[jit] Support pack_padded_sequence and pad_packed_sequence #23249
Conversation
[jit] Support pack_padded_sequence and pad_packed_sequence gh-metadata: pytorch pytorch 23249 gh/wanchaol/34/head
**eellison** left a comment:
LGTM. maybe add a couple more test inputs
test/test_jit.py (outdated)
```python
            if seq_lens[b] < T:
                x[seq_lens[b]:, b, :] = 0

        eager_seq = pack_padded_pad_packed_script(x, seq_lens)
```
Reason not to do checkScript here instead?
checkScript also exercises the string frontend, and the string frontend could not handle comments starting with `r`.
test/test_jit.py (outdated)
```python
        def pack_padded_pad_packed_script(x, seq_lens):
            x = pack_padded_sequence(x, seq_lens)
            x, _ = pad_packed_sequence(x)
```
check both outputs here?
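To make the request concrete, here is a minimal eager-mode sketch of checking both return values of `pad_packed_sequence` (the padded tensor and the recovered lengths) after a pack/pad round trip; the helper name is illustrative, not the PR's exact test:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

def pack_padded_pad_packed(x, seq_lens):
    # Round-trip: pack the padded batch, then pad it back.
    packed = pack_padded_sequence(x, seq_lens)
    padded, out_lens = pad_packed_sequence(packed)
    return padded, out_lens

T, B, C = 3, 2, 4
x = torch.randn(T, B, C)         # time-major layout: (seq_len, batch, features)
seq_lens = torch.tensor([3, 2])  # sorted descending, as pack_padded_sequence expects by default

# Zero out the padding positions so the round trip reproduces x exactly
# (pad_packed_sequence fills padding with 0.0).
for b in range(B):
    if seq_lens[b] < T:
        x[seq_lens[b]:, b, :] = 0

padded, out_lens = pack_padded_pad_packed(x, seq_lens)
assert torch.equal(padded, x)           # first output: the padded data
assert torch.equal(out_lens, seq_lens)  # second output: the recovered lengths
```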
```diff
              'values, and it will treat them as constants, likely rendering '
              'the trace incorrect for any other combination of lengths.',
-             category=torch.jit.TracerWarning, stacklevel=2)
+             stacklevel=2)
```
Can't we make JIT ignore the category? It's not a good idea to remove it
category is hard to ignore at the moment because we don't have any way of typing torch.jit.TracerWarning.
Yeah, I tried that and it's kind of hard right now; besides, other call sites do not have this category either. Not sure if someone deleted them before.
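For background on why dropping `category=` is a real loss: a warning's category is what lets users filter that specific class of warning. This stdlib-only sketch uses a hypothetical `TracerWarning` stand-in (the real class lives in `torch.jit`):

```python
import warnings

class TracerWarning(Warning):
    """Hypothetical stand-in for torch.jit.TracerWarning."""

def trace_lengths():
    warnings.warn(
        "the trace may not generalize to other sequence lengths",
        category=TracerWarning,
        stacklevel=2,
    )

# With a category, users can silence exactly this class of warning:
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    warnings.simplefilter("ignore", TracerWarning)
    trace_lengths()
# `caught` is empty: the TracerWarning was filtered by its category.

# Without category=, warnings.warn falls back to UserWarning, which can
# only be filtered together with every other UserWarning:
with warnings.catch_warnings(record=True) as caught_all:
    warnings.simplefilter("always")
    warnings.warn("no category given", stacklevel=2)
# caught_all[0].category is UserWarning
```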
```diff
          batch_dim = 0 if batch_first else 1
-     return padded_output.index_select(batch_dim, sequence.unsorted_indices), \
-         lengths[sequence.unsorted_indices]
+     return padded_output.index_select(batch_dim, unsorted_indices), lengths[unsorted_indices]
```
Why did you have to rewrite this code? Because of optional unwrapping?
Yes, our type refinement is currently limited to TK_VAR only; we will need to extend it to the general case, see #23049.
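A small sketch of the limitation: `is not None` checks refine only plain variables (TK_VAR), not attribute expressions like `sequence.unsorted_indices`, so the attribute is copied into a local before the check. The function name here is illustrative:

```python
import torch
from typing import Optional

@torch.jit.script
def select_lengths(indices: Optional[torch.Tensor],
                   lengths: torch.Tensor) -> torch.Tensor:
    # Pull the (attribute-like) Optional value into a local variable first;
    # the `is not None` check then refines the local to Tensor.
    unsorted_indices = indices
    if unsorted_indices is not None:
        return lengths[unsorted_indices]
    return lengths
```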
```diff
                  .format(total_length, max_seq_length))
          max_seq_length = total_length
-     padded_output, lengths = torch._C._VariableFunctions._pad_packed_sequence(
+     padded_output, lengths = _VF._pad_packed_sequence(
```
Is this because the JIT can't handle torch._C._VariableFunctions?
Yes, we are using a wrapper around _VariableFunctions (namely _VF) so that it is accessible in the JIT.
Differential Revision: [D16466587](https://our.internmc.facebook.com/intern/diff/D16466587)