
Commit 9d2cc2c

wanchaol authored and facebook-github-bot committed
Support nn.GRU in script
Summary: Pull Request resolved: #23266

Test Plan: Imported from OSS

Differential Revision: D16466586

Pulled By: wanchaol

fbshipit-source-id: 0f5b8013167bb7b246bd7e28d87a4a9e9c3b34d5
1 parent b22c88b commit 9d2cc2c

File tree

3 files changed: +103 -1 lines


test/test_jit.py

Lines changed: 34 additions & 0 deletions
@@ -12897,6 +12897,40 @@ def forward(self, input):
 
         self.assertEqual(eager_out, script_out)
 
+    def test_nn_GRU(self):
+        from torch.nn.utils.rnn import PackedSequence
+        seq_input = torch.nn.utils.rnn.pack_sequence([torch.randn(5, 5)])
+        tensor_input = torch.randn(5, 5, 5)
+
+        class SeqLengthGRU(torch.jit.ScriptModule):
+            def __init__(self):
+                super(SeqLengthGRU, self).__init__()
+                self.x = torch.nn.GRU(5, 5)
+
+            @torch.jit.script_method
+            def forward(self, input):
+                # type: (PackedSequence) -> Tuple[PackedSequence, Tensor]
+                return self.x(input)
+
+        class TensorGRU(torch.jit.ScriptModule):
+            def __init__(self):
+                super(TensorGRU, self).__init__()
+                self.x = torch.nn.GRU(5, 5)
+
+            @torch.jit.script_method
+            def forward(self, input):
+                # type: (Tensor) -> Tuple[Tensor, Tensor]
+                return self.x(input)
+
+        seq_eager_out = self.runAndSaveRNG(lambda x: torch.nn.GRU(5, 5)(x), (seq_input,))[0]
+        seq_script_out = self.runAndSaveRNG(lambda x: SeqLengthGRU()(x), (seq_input,))[0]
+        tensor_eager_out = self.runAndSaveRNG(lambda x: torch.nn.GRU(5, 5)(x), (tensor_input,))[0]
+        tensor_script_out = self.runAndSaveRNG(lambda x: TensorGRU()(x), (tensor_input,))[0]
+
+        self.assertEqual(seq_eager_out, seq_script_out)
+        self.assertEqual(tensor_eager_out, tensor_script_out)
+
+
     def test_torchscript_multi_head_attn(self):
         @torch.jit.script
         def jit_multihead_attn_forward(query,  # type: Tensor
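The runAndSaveRNG helper used in the test is defined elsewhere in test_jit.py; it appears to save and restore the RNG state around each call so that the eager nn.GRU and the scripted wrapper are built with identical random weights. A rough standalone sketch of the tensor case, under that assumption (illustrative only, not part of the diff):

import torch

torch.manual_seed(0)
eager_gru = torch.nn.GRU(5, 5)

class TensorGRU(torch.jit.ScriptModule):
    def __init__(self):
        super(TensorGRU, self).__init__()
        self.x = torch.nn.GRU(5, 5)

    @torch.jit.script_method
    def forward(self, input):
        # type: (Tensor) -> Tuple[Tensor, Tensor]
        return self.x(input)

torch.manual_seed(0)  # same seed -> same randomly initialized weights as eager_gru
script_gru = TensorGRU()

tensor_input = torch.randn(5, 5, 5)
eager_out, _ = eager_gru(tensor_input)
script_out, _ = script_gru(tensor_input)
assert torch.allclose(eager_out, script_out)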

torch/nn/modules/rnn.py

Lines changed: 68 additions & 1 deletion
@@ -11,7 +11,6 @@
 from ..._jit_internal import _parameter_list
 
 _rnn_impls = {
-    'GRU': _VF.gru,
     'RNN_TANH': _VF.rnn_tanh,
     'RNN_RELU': _VF.rnn_relu,
 }
@@ -167,12 +166,14 @@ def check_hidden_size(self, hx, expected_hidden_size, msg='Expected hidden size
             raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size())))
 
     def check_forward_args(self, input, hidden, batch_sizes):
+        # type: (Tensor, Tensor, Optional[Tensor]) -> None
         self.check_input(input, batch_sizes)
         expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
 
         self.check_hidden_size(hidden, expected_hidden_size)
 
     def permute_hidden(self, hx, permutation):
+        # type: (Tensor, Optional[Tensor]) -> Tensor
         if permutation is None:
             return hx
         return apply_permutation(hx, permutation)
@@ -369,6 +370,16 @@ def __init__(self, *args, **kwargs):
         super(RNN, self).__init__(mode, *args, **kwargs)
 
 
+# XXX: The LSTM and GRU implementations differ from RNNBase because:
+# 1. we want to support nn.LSTM and nn.GRU in TorchScript, and TorchScript in
+#    its current state cannot support the Python Union or Any types.
+# 2. TorchScript static typing does not allow a Function or Callable type in
+#    Dict values, so we have to call _VF directly instead of using _rnn_impls.
+# 3. This is temporary only, a transitional state so that we can make it in
+#    time for the release.
+#
+# TODO: remove the overriding implementations for LSTM and GRU when TorchScript
+# supports expressing these two modules generally.
 class LSTM(RNNBase):
     r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input
     sequence.
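Point 2 above is why the first hunk in this file drops 'GRU' from _rnn_impls: a dict whose values are functions has no TorchScript type, so the scripted GRU added below calls _VF.gru directly rather than looking the op up in the dict. A small hypothetical illustration of the pattern that cannot be scripted (not part of this diff):

from torch import _VF

# Fine as ordinary Python, but the value type (a builtin function) is not
# expressible in TorchScript's static type system, so a scripted method
# cannot dispatch through this dict.
impls = {
    'RNN_TANH': _VF.rnn_tanh,
    'RNN_RELU': _VF.rnn_relu,
}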
@@ -655,10 +666,66 @@ class GRU(RNNBase):
         >>> h0 = torch.randn(2, 3, 20)
         >>> output, hn = rnn(input, h0)
     """
+    __overloads__ = {'forward': ['forward_packed', 'forward_tensor']}
 
     def __init__(self, *args, **kwargs):
         super(GRU, self).__init__('GRU', *args, **kwargs)
 
+    def run_impl(self, input, hx, batch_sizes):
+        # type: (Tensor, Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor]
+        if batch_sizes is None:
+            result = _VF.gru(input, hx, self._get_flat_weights(), self.bias, self.num_layers,
+                             self.dropout, self.training, self.bidirectional, self.batch_first)
+        else:
+            result = _VF.gru(input, batch_sizes, hx, self._get_flat_weights(), self.bias,
+                             self.num_layers, self.dropout, self.training, self.bidirectional)
+        return result
+
+    def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
+        # type: (Tensor, Optional[Tensor], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tensor]  # noqa
+        if hx is None:
+            num_directions = 2 if self.bidirectional else 1
+            hx = torch.zeros(self.num_layers * num_directions,
+                             max_batch_size, self.hidden_size,
+                             dtype=input.dtype, device=input.device)
+        else:
+            # Each batch of the hidden state should match the input sequence that
+            # the user believes he/she is passing in.
+            hx = self.permute_hidden(hx, sorted_indices)
+
+        self.check_forward_args(input, hx, batch_sizes)
+        result = self.run_impl(input, hx, batch_sizes)
+        output = result[0]
+        hidden = result[1]
+        return output, hidden
+
+    @torch._jit_internal.export
+    def forward_packed(self, input, hx=None):
+        # type: (PackedSequence, Optional[Tensor]) -> Tuple[PackedSequence, Tensor]
+        input, batch_sizes, sorted_indices, unsorted_indices = input
+        max_batch_size = batch_sizes[0]
+        max_batch_size = int(max_batch_size)
+        output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
+        output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
+        return output, self.permute_hidden(hidden, unsorted_indices)
+
+    @torch._jit_internal.export
+    def forward_tensor(self, input, hx=None):
+        # type: (Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor]
+        batch_sizes = None
+        max_batch_size = input.size(0) if self.batch_first else input.size(1)
+        sorted_indices = None
+        unsorted_indices = None
+        output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices)
+        return output, self.permute_hidden(hidden, unsorted_indices)
+
+    @torch._jit_internal.ignore
+    def forward(self, input, hx=None):
+        if isinstance(input, PackedSequence):
+            return self.forward_packed(input, hx)
+        else:
+            return self.forward_tensor(input, hx)
+
 
 class RNNCellBase(Module):
     __constants__ = ['input_size', 'hidden_size', 'bias']
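Taken together, the @torch._jit_internal.ignore'd forward keeps eager-mode dispatch working through isinstance, while __overloads__ lets the TorchScript compiler bind a call such as self.gru(input) to forward_packed or forward_tensor according to the input type the caller declares. A minimal usage sketch, mirroring the new test rather than quoting it (the wrapper name and sizes are illustrative):

import torch
from torch.nn.utils.rnn import PackedSequence, pack_sequence

class PackedGRU(torch.jit.ScriptModule):
    def __init__(self):
        super(PackedGRU, self).__init__()
        self.gru = torch.nn.GRU(5, 5)

    @torch.jit.script_method
    def forward(self, input):
        # type: (PackedSequence) -> Tuple[PackedSequence, Tensor]
        # The declared PackedSequence input routes this call to GRU.forward_packed.
        return self.gru(input)

packed = pack_sequence([torch.randn(3, 5), torch.randn(2, 5)])
output, hidden = PackedGRU()(packed)
print(type(output).__name__)   # PackedSequence
print(hidden.shape)            # torch.Size([1, 2, 5])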

torch/nn/utils/rnn.py

Lines changed: 1 addition & 0 deletions
@@ -387,6 +387,7 @@ def pad_sequence(sequences, batch_first=False, padding_value=0):
 
 
 def pack_sequence(sequences, enforce_sorted=True):
+    # type: (List[Tensor], bool) -> PackedSequence
     r"""Packs a list of variable length Tensors
 
     ``sequences`` should be a list of Tensors of size ``L x *``, where `L` is
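The added type comment documents the existing behavior: pack_sequence takes a list of Tensors (longest first when enforce_sorted=True, the default) and returns a PackedSequence. A small illustrative example (not part of the diff):

import torch
from torch.nn.utils.rnn import pack_sequence

# Three sequences of lengths 3, 2 and 1, each with 5 features per step.
packed = pack_sequence([torch.randn(3, 5), torch.randn(2, 5), torch.randn(1, 5)])
print(type(packed).__name__)   # PackedSequence
print(packed.batch_sizes)      # tensor([3, 2, 1])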
