|
11 | 11 | from ..._jit_internal import _parameter_list |
12 | 12 |
|
13 | 13 | _rnn_impls = { |
14 | | - 'GRU': _VF.gru, |
15 | 14 | 'RNN_TANH': _VF.rnn_tanh, |
16 | 15 | 'RNN_RELU': _VF.rnn_relu, |
17 | 16 | } |
@@ -167,12 +166,14 @@ def check_hidden_size(self, hx, expected_hidden_size, msg='Expected hidden size |
167 | 166 | raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size()))) |
168 | 167 |
|
169 | 168 | def check_forward_args(self, input, hidden, batch_sizes): |
| 169 | + # type: (Tensor, Tensor, Optional[Tensor]) -> None |
170 | 170 | self.check_input(input, batch_sizes) |
171 | 171 | expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) |
172 | 172 |
|
173 | 173 | self.check_hidden_size(hidden, expected_hidden_size) |
174 | 174 |
|
175 | 175 | def permute_hidden(self, hx, permutation): |
| 176 | + # type: (Tensor, Optional[Tensor]) -> Tensor |
176 | 177 | if permutation is None: |
177 | 178 | return hx |
178 | 179 | return apply_permutation(hx, permutation) |
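The annotations added in this hunk (e.g. `# type: (Tensor, Optional[Tensor]) -> Tensor`) are MyPy-style type comments that TorchScript reads as a method's static signature. A small, hypothetical sketch of the idea (the Checker module is illustrative only, not part of this change):

```python
import torch
from torch import Tensor
from typing import Optional


class Checker(torch.nn.Module):
    def forward(self, x, lengths=None):
        # type: (Tensor, Optional[Tensor]) -> Tensor
        # The type comment gives TorchScript the static signature: `lengths`
        # is Optional[Tensor] and is refined to Tensor after the None check.
        if lengths is None:
            return x
        return x[:int(lengths[0])]


m = Checker()
m(torch.randn(4, 3))            # eager call, lengths omitted
scripted = torch.jit.script(m)  # the type comment is honored when scripting
scripted(torch.randn(4, 3), torch.tensor([2]))
```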
@@ -369,6 +370,16 @@ def __init__(self, *args, **kwargs): |
369 | 370 | super(RNN, self).__init__(mode, *args, **kwargs) |
370 | 371 |
|
371 | 372 |
|
| 373 | +# XXX: The LSTM and GRU implementations differ from RNNBase because: |
| 374 | +# 1. we want to support nn.LSTM and nn.GRU in TorchScript, and TorchScript in |
| 375 | +# its current state cannot support the Python Union or Any types. |
| 376 | +# 2. TorchScript static typing does not allow a Function or Callable type in |
| 377 | +# Dict values, so we have to call _VF directly instead of using _rnn_impls. |
| 378 | +# 3. This is a temporary, transitional state so that we can make the release |
| 379 | +# on time. |
| 380 | +# |
| 381 | +# TODO: remove the overriding implementations for LSTM and GRU when TorchScript |
| 382 | +# supports expressing these two modules generally. |
372 | 383 | class LSTM(RNNBase): |
373 | 384 | r"""Applies a multi-layer long short-term memory (LSTM) RNN to an input |
374 | 385 | sequence. |
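The workaround described in the XXX comment above amounts to an overload table plus statically typed entry points. A minimal, hypothetical sketch of the same pattern (the Toy module and its trivial bodies are illustrative only; the real GRU methods appear later in this diff):

```python
import torch
from torch import Tensor
from typing import Optional, Tuple
from torch.nn.utils.rnn import PackedSequence


class Toy(torch.nn.Module):
    # Maps the Python-only `forward` to the statically typed overloads that
    # the TorchScript compiler is allowed to see.
    __overloads__ = {'forward': ['forward_packed', 'forward_tensor']}

    @torch._jit_internal.export
    def forward_tensor(self, input, hx=None):
        # type: (Tensor, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
        return input, hx

    @torch._jit_internal.export
    def forward_packed(self, input, hx=None):
        # type: (PackedSequence, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]
        return input.data, hx

    @torch._jit_internal.ignore
    def forward(self, input, hx=None):
        # Eager mode dispatches on the runtime type; the compiler never sees
        # this Union-typed function, only the two overloads above.
        if isinstance(input, PackedSequence):
            return self.forward_packed(input, hx)
        return self.forward_tensor(input, hx)
```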
@@ -655,10 +666,66 @@ class GRU(RNNBase): |
655 | 666 | >>> h0 = torch.randn(2, 3, 20) |
656 | 667 | >>> output, hn = rnn(input, h0) |
657 | 668 | """ |
| 669 | + __overloads__ = {'forward': ['forward_packed', 'forward_tensor']} |
658 | 670 |
|
659 | 671 | def __init__(self, *args, **kwargs): |
660 | 672 | super(GRU, self).__init__('GRU', *args, **kwargs) |
661 | 673 |
|
| 674 | + def run_impl(self, input, hx, batch_sizes): |
| 675 | + # type: (Tensor, Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor] |
| 676 | + if batch_sizes is None: |
| 677 | + result = _VF.gru(input, hx, self._get_flat_weights(), self.bias, self.num_layers, |
| 678 | + self.dropout, self.training, self.bidirectional, self.batch_first) |
| 679 | + else: |
| 680 | + result = _VF.gru(input, batch_sizes, hx, self._get_flat_weights(), self.bias, |
| 681 | + self.num_layers, self.dropout, self.training, self.bidirectional) |
| 682 | + return result |
| 683 | + |
| 684 | + def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices): |
| 685 | + # type: (Tensor, Optional[Tensor], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tensor] # noqa |
| 686 | + if hx is None: |
| 687 | + num_directions = 2 if self.bidirectional else 1 |
| 688 | + hx = torch.zeros(self.num_layers * num_directions, |
| 689 | + max_batch_size, self.hidden_size, |
| 690 | + dtype=input.dtype, device=input.device) |
| 691 | + else: |
| 692 | + # Each batch of the hidden state should match the input sequence that |
| 693 | + # the user believes he/she is passing in. |
| 694 | + hx = self.permute_hidden(hx, sorted_indices) |
| 695 | + |
| 696 | + self.check_forward_args(input, hx, batch_sizes) |
| 697 | + result = self.run_impl(input, hx, batch_sizes) |
| 698 | + output = result[0] |
| 699 | + hidden = result[1] |
| 700 | + return output, hidden |
| 701 | + |
| 702 | + @torch._jit_internal.export |
| 703 | + def forward_packed(self, input, hx=None): |
| 704 | + # type: (PackedSequence, Optional[Tensor]) -> Tuple[PackedSequence, Tensor] |
| 705 | + input, batch_sizes, sorted_indices, unsorted_indices = input |
| 706 | + max_batch_size = batch_sizes[0] |
| 707 | + max_batch_size = int(max_batch_size) |
| 708 | + output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices) |
| 709 | + output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices) |
| 710 | + return output, self.permute_hidden(hidden, unsorted_indices) |
| 711 | + |
| 712 | + @torch._jit_internal.export |
| 713 | + def forward_tensor(self, input, hx=None): |
| 714 | + # type: (Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor] |
| 715 | + batch_sizes = None |
| 716 | + max_batch_size = input.size(0) if self.batch_first else input.size(1) |
| 717 | + sorted_indices = None |
| 718 | + unsorted_indices = None |
| 719 | + output, hidden = self.forward_impl(input, hx, batch_sizes, max_batch_size, sorted_indices) |
| 720 | + return output, self.permute_hidden(hidden, unsorted_indices) |
| 721 | + |
| 722 | + @torch._jit_internal.ignore |
| 723 | + def forward(self, input, hx=None): |
| 724 | + if isinstance(input, PackedSequence): |
| 725 | + return self.forward_packed(input, hx) |
| 726 | + else: |
| 727 | + return self.forward_tensor(input, hx) |
| 728 | + |
662 | 729 |
|
663 | 730 | class RNNCellBase(Module): |
664 | 731 | __constants__ = ['input_size', 'hidden_size', 'bias'] |
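For reference, the two exported entry points added above correspond to the two ways nn.GRU is already called in eager mode; a short usage sketch (shapes chosen arbitrarily) exercising both paths:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

rnn = torch.nn.GRU(input_size=10, hidden_size=20, num_layers=2)

# Plain Tensor input: dispatched to forward_tensor (batch_sizes stays None).
x = torch.randn(5, 3, 10)            # (seq_len, batch, input_size)
output, h_n = rnn(x)                  # output: (5, 3, 20), h_n: (2, 3, 20)

# PackedSequence input: dispatched to forward_packed, which unpacks the
# batch_sizes/indices, runs the same forward_impl, and re-packs the output.
lengths = torch.tensor([5, 4, 2])
packed = pack_padded_sequence(x, lengths)
packed_output, h_n = rnn(packed)
padded_output, out_lengths = pad_packed_sequence(packed_output)
```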
|