Conversation

@wanchaol (Collaborator) commented Jul 23, 2019

Stack from ghstack:

Differential Revision: D16466587

@pytorchbot added the oncall: jit and module: nn labels Jul 23, 2019
@wanchaol wanchaol requested review from driazati, eellison and suo July 23, 2019 18:23
[jit] Support pack_padded_sequence and pad_packed_sequence

gh-metadata: pytorch pytorch 23249 gh/wanchaol/34/head
@eellison (Contributor) left a comment

LGTM. Maybe add a couple more test inputs.

test/test_jit.py Outdated
if seq_lens[b] < T:
x[seq_lens[b]:, b, :] = 0

eager_seq = pack_padded_pad_packed_script(x, seq_lens)
Contributor

Any reason not to do checkScript here instead?

Collaborator (Author)

checkScript also checks the string frontend, and the string frontend can't handle comments starting with r.
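For context, a minimal sketch of what the manual eager-vs-scripted comparison looks like in place of checkScript. The shapes and inputs here are hypothetical, not the PR's actual test values:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

def pack_padded_pad_packed_script(x, seq_lens):
    x = pack_padded_sequence(x, seq_lens)
    x, lengths = pad_packed_sequence(x)
    return x, lengths

# Hypothetical dimensions: T time steps, B batch entries, C features.
T, B, C = 5, 3, 2
x = torch.randn(T, B, C)
seq_lens = torch.tensor([5, 3, 2])  # sorted descending, as pack_padded_sequence expects
for b in range(B):
    if seq_lens[b] < T:
        x[seq_lens[b]:, b, :] = 0  # zero out the padding, as in the PR's test

# Run eagerly and as a scripted function, then compare both outputs directly.
eager_out, eager_lens = pack_padded_pad_packed_script(x, seq_lens)
scripted = torch.jit.script(pack_padded_pad_packed_script)
script_out, script_lens = scripted(x, seq_lens)
assert torch.equal(eager_out, script_out)
assert torch.equal(eager_lens, script_lens)
```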

test/test_jit.py Outdated

def pack_padded_pad_packed_script(x, seq_lens):
x = pack_padded_sequence(x, seq_lens)
x, _ = pad_packed_sequence(x)
Contributor

Check both outputs here?

'values, and it will treat them as constants, likely rendering '
'the trace incorrect for any other combination of lengths.',
category=torch.jit.TracerWarning, stacklevel=2)
stacklevel=2)
Contributor

Can't we make JIT ignore the category? It's not a good idea to remove it

Contributor

The category is hard to ignore at the moment because we don't have any way of typing torch.jit.TracerWarning.

Collaborator (Author)

Yeah, I tried that and it's kind of hard right now; plus, other call sites don't have this category anyway. Not sure if someone deleted them before.
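To illustrate the constraint being discussed, a hedged sketch (the function name and condition below are made up, not from the PR): under torch.jit.script, warnings.warn compiles with a plain message, but a category keyword like torch.jit.TracerWarning cannot be typed by the compiler, which is why the argument is dropped.

```python
import warnings
import torch

@torch.jit.script
def check_lengths(lengths: torch.Tensor) -> torch.Tensor:
    # A plain message scripts fine; passing category=... would not,
    # because the script compiler cannot type the Warning class.
    if bool((lengths <= 0).any()):
        warnings.warn("lengths should be positive")
    return lengths
```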

batch_dim = 0 if batch_first else 1
return padded_output.index_select(batch_dim, sequence.unsorted_indices), \
lengths[sequence.unsorted_indices]
return padded_output.index_select(batch_dim, unsorted_indices), lengths[unsorted_indices]
Contributor

Why did you have to rewrite this code? Because of optional unwrapping?

Collaborator (Author)

Yes, our type refinement is currently limited to TK_VAR only; we will need to expand it to the general case, see #23049.
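A minimal sketch of the refinement pattern behind the rewrite (function name and shapes are hypothetical, not the PR's code): the compiler only refines an Optional when it is checked through a plain variable, so an attribute access like sequence.unsorted_indices must first be bound to a local name.

```python
import torch
from typing import Optional

@torch.jit.script
def select_rows(unsorted_indices: Optional[torch.Tensor],
                x: torch.Tensor) -> torch.Tensor:
    # Binding to a plain variable lets the compiler refine
    # Optional[Tensor] -> Tensor inside the branch; an attribute
    # access such as `sequence.unsorted_indices` would not be refined.
    indices = unsorted_indices
    if indices is not None:
        return x.index_select(0, indices)
    return x
```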

.format(total_length, max_seq_length))
max_seq_length = total_length
padded_output, lengths = torch._C._VariableFunctions._pad_packed_sequence(
padded_output, lengths = _VF._pad_packed_sequence(
Contributor

Is this because the JIT can't handle torch._C._VariableFunctions?

Collaborator (Author)

Yes, we are using a wrapper for _VariableFunctions (called _VF) so that it is accessible in the JIT.

@zou3519 zou3519 deleted the gh/wanchaol/34/head branch July 29, 2019 22:39
@facebook-github-bot (Contributor)

@wanchaol merged this pull request in f4eb93f.
