Skip to content

Commit afe4b29

Browse files
committed
[jit] make nn.LSTM accept PackedSequence instead of Tuples
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 811724a Pull Request resolved: #23643
1 parent bab7d7a commit afe4b29

File tree

4 files changed

+5
-9
lines changed

4 files changed

+5
-9
lines changed

test/test_jit.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12786,6 +12786,7 @@ def forward(self, x, lengths, h0, c0):
         self.assertEqual(eager_out, script_out)

     def test_nn_LSTM(self):
+        from torch.nn.utils.rnn import PackedSequence
         input = torch.nn.utils.rnn.pack_sequence([torch.randn(5, 5)])

         class S(torch.jit.ScriptModule):
@@ -12795,7 +12796,7 @@ def __init__(self):

             @torch.jit.script_method
             def forward(self, input):
-                # type: (Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]) -> Tuple[Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]], Tuple[Tensor, Tensor]] # noqa
+                # type: (PackedSequence) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]] # noqa
                 return self.x(input)

         eager_out = self.runAndSaveRNG(lambda x: torch.nn.LSTM(5, 5)(x), (input,))[0]

torch/jit/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2002,7 +2002,6 @@ def register_all(mod):
         (torch.nn.init._no_grad_normal_, "aten::_no_grad_normal_"),
         (torch.nn.init._no_grad_uniform_, "aten::_no_grad_uniform_"),
         (torch.nn.init._no_grad_zero_, "aten::_no_grad_zero_"),
-        (torch.nn.utils.rnn.get_packed_sequence, "aten::_pack_sequence"),
         (torch._C._get_tracing_state, "aten::_get_tracing_state"),
         (warnings.warn, "aten::warn"),
     ]

torch/nn/modules/rnn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@

 from .module import Module
 from ..parameter import Parameter
-from ..utils.rnn import PackedSequence, get_packed_sequence
+from ..utils.rnn import PackedSequence
 from .. import init
 from .. import _VF
 from ..._jit_internal import _parameter_list
@@ -544,14 +544,14 @@ def forward_tensor(self, input, hx=None):

     @torch._jit_internal.export
     def forward_packed(self, input, hx=None):
-        # type: (Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]], Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]], Tuple[Tensor, Tensor]] # noqa
+        # type: (PackedSequence, Optional[Tuple[Tensor, Tensor]]) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]] # noqa
         input, batch_sizes, sorted_indices, unsorted_indices = input
         max_batch_size = batch_sizes[0]
         max_batch_size = int(max_batch_size)

         output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)

-        output = get_packed_sequence(output, batch_sizes, sorted_indices, unsorted_indices)
+        output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
         return output, self.permute_hidden(hidden, unsorted_indices)

     @torch._jit_internal.ignore

torch/nn/utils/rnn.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,3 @@ def pack_sequence(sequences, enforce_sorted=True):
     """
     lengths = [v.size(0) for v in sequences]
     return pack_padded_sequence(pad_sequence(sequences), lengths, enforce_sorted=enforce_sorted)
-
-
-def get_packed_sequence(data, batch_sizes, sorted_indices, unsorted_indices):
-    return PackedSequence(data, batch_sizes, sorted_indices, unsorted_indices)

0 commit comments

Comments
 (0)