
Commit 63a0bb0

Add typing annotations for torch.nn.quantized.dynamic.modules.rnn (#43186)
Summary: Fixes #43185. xref: gh-43072

Pull Request resolved: #43186
Reviewed By: ezyang
Differential Revision: D23441259
Pulled By: malfet
fbshipit-source-id: 80265ae7f3a70f0087e620969dbd4aa8ca17c317
1 parent 8ca3913 commit 63a0bb0

2 files changed (+48, -46 lines)

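The whole change follows one pattern: each Python 2 style `# type:` comment is rewritten as an inline PEP 484 annotation, so mypy can now check the module directly (the mypy.ini override below is deleted for the same reason). A minimal sketch of the pattern on a hypothetical helper, not code from this diff:

```python
from typing import Optional

from torch import Tensor

# Before: the signature is typed only through a comment.
def scale(t, factor=None):
    # type: (Tensor, Optional[float]) -> Tensor
    return t if factor is None else t * factor

# After: the same contract as inline annotations, which mypy enforces.
def scale_annotated(t: Tensor, factor: Optional[float] = None) -> Tensor:
    return t if factor is None else t * factor
```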

mypy.ini

Lines changed: 0 additions & 3 deletions
@@ -170,9 +170,6 @@ ignore_errors = True
 [mypy-torch.nn.qat.modules.conv]
 ignore_errors = True
 
-[mypy-torch.nn.quantized.dynamic.modules.rnn]
-ignore_errors = True
-
 [mypy-torch.nn.quantized.dynamic.modules.linear]
 ignore_errors = True

torch/nn/quantized/dynamic/modules/rnn.py

Lines changed: 48 additions & 43 deletions
@@ -6,12 +6,11 @@
 import torch
 import torch.nn as nn
 from torch import Tensor  # noqa: F401
-from torch._jit_internal import Tuple, Optional, List  # noqa: F401
+from torch._jit_internal import Tuple, Optional, List, Union, Dict  # noqa: F401
 from torch.nn.utils.rnn import PackedSequence
 from torch.nn.quantized.modules.utils import _quantize_weight
 
-def apply_permutation(tensor, permutation, dim=1):
-    # type: (Tensor, Tensor, int) -> Tensor
+def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
     return tensor.index_select(dim, permutation)
 
 class PackedParameter(torch.nn.Module):
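Aside: `apply_permutation` is a thin wrapper over `Tensor.index_select`; with the default `dim=1` it reorders the batch dimension of a `(num_layers * num_directions, batch, hidden_size)` hidden state. A small standalone demonstration, shapes chosen arbitrarily:

```python
import torch

hx = torch.arange(12.).reshape(1, 3, 4)  # (layers, batch=3, hidden=4)
perm = torch.tensor([2, 0, 1])           # a sorted_indices-style permutation

out = hx.index_select(1, perm)           # what apply_permutation(hx, perm) does
print(out[0, 0])                         # batch row 2: tensor([ 8.,  9., 10., 11.])
```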
@@ -53,12 +52,14 @@ def __init__(self, mode, input_size, hidden_size,
         self.training = False
         num_directions = 2 if bidirectional else 1
 
-        if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \
-                isinstance(dropout, bool):
+        # "type: ignore" is required since ints and Numbers are not fully comparable
+        # https://github.com/python/mypy/issues/8566
+        if not isinstance(dropout, numbers.Number) \
+                or not 0 <= dropout <= 1 or isinstance(dropout, bool):  # type: ignore
             raise ValueError("dropout should be a number in range [0, 1] "
                              "representing the probability of an element being "
                              "zeroed")
-        if dropout > 0 and num_layers == 1:
+        if dropout > 0 and num_layers == 1:  # type: ignore
             warnings.warn("dropout option adds dropout after all but last "
                           "recurrent layer, so non-zero dropout expects "
                           "num_layers greater than 1, but got dropout={} and "
@@ -149,8 +150,7 @@ def __repr__(self):
         main_str += ')'
         return main_str
 
-    def check_input(self, input, batch_sizes):
-        # type: (Tensor, Optional[Tensor]) -> None
+    def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
         expected_input_dim = 2 if batch_sizes is not None else 3
         if input.dim() != expected_input_dim:
             raise RuntimeError(
@@ -161,33 +161,31 @@ def check_input(self, input, batch_sizes):
                 'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
                     self.input_size, input.size(-1)))
 
-    def get_expected_hidden_size(self, input, batch_sizes):
-        # type: (Tensor, Optional[Tensor]) -> Tuple[int, int, int]
+    def get_expected_hidden_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]:
         if batch_sizes is not None:
-            mini_batch = batch_sizes[0]
-            mini_batch = int(mini_batch)
+            mini_batch = int(batch_sizes[0])
         else:
             mini_batch = input.size(0) if self.batch_first else input.size(1)
         num_directions = 2 if self.bidirectional else 1
         expected_hidden_size = (self.num_layers * num_directions,
                                 mini_batch, self.hidden_size)
         return expected_hidden_size
 
-    def check_hidden_size(self, hx, expected_hidden_size, msg='Expected hidden size {}, got {}'):
-        # type: (Tensor, Tuple[int, int, int], str) -> None
+    def check_hidden_size(
+        self, hx: Tensor, expected_hidden_size: Tuple[int, int, int],
+        msg: str = 'Expected hidden size {}, got {}'
+    ) -> None:
         if hx.size() != expected_hidden_size:
             raise RuntimeError(msg.format(
                 expected_hidden_size, list(hx.size())))
 
-    def check_forward_args(self, input, hidden, batch_sizes):
-        # type: (Tensor, Tensor, Optional[Tensor]) -> None
+    def check_forward_args(self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]) -> None:
         self.check_input(input, batch_sizes)
         expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
         self.check_hidden_size(hidden, expected_hidden_size,
                                msg='Expected hidden size {}, got {}')
 
-    def permute_hidden(self, hx, permutation):
-        # type: (Tensor, Optional[Tensor]) -> Tensor
+    def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor:
         if permutation is None:
             return hx
         return apply_permutation(hx, permutation)
@@ -287,7 +285,7 @@ def quantize_and_pack(w, b):
 
     def _weight_bias(self):
         # Returns a dict of weights and biases
-        weight_bias_dict = {'weight' : {}, 'bias' : {}}
+        weight_bias_dict: Dict[str, Dict] = {'weight' : {}, 'bias' : {}}
         count = 0
         num_directions = 2 if self.bidirectional else 1
         for layer in range(self.num_layers):
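The explicit `Dict[str, Dict]` annotation is needed because mypy infers a dict literal's value type from its contents, and the empty inner dicts in `{'weight': {}, 'bias': {}}` give it nothing to infer from, so the later `weight_bias_dict['weight'][...] = ...` writes would not type-check. A standalone sketch:

```python
from typing import Dict

import torch

# Without an annotation mypy cannot assign a usable type to the empty
# inner dicts and flags later insertions into them.
inferred = {'weight': {}, 'bias': {}}

# Annotated up front, the inner dicts stay open to any key/value types.
annotated: Dict[str, Dict] = {'weight': {}, 'bias': {}}
annotated['weight']['weight_ih'] = torch.zeros(2)  # fine for mypy
```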
@@ -337,8 +335,11 @@ def __init__(self, *args, **kwargs):
     def _get_name(self):
         return 'DynamicQuantizedLSTM'
 
-    def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
-        # type: (Tensor, Optional[Tuple[Tensor, Tensor]], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
+    def forward_impl(
+        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]],
+        batch_sizes: Optional[Tensor], max_batch_size: int,
+        sorted_indices: Optional[Tensor]
+    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
         if hx is None:
             num_directions = 2 if self.bidirectional else 1
             zeros = torch.zeros(self.num_layers * num_directions,
@@ -367,8 +368,9 @@ def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
         return output, hidden
 
     @torch.jit.export
-    def forward_tensor(self, input, hx=None):
-        # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+    def forward_tensor(
+        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
+    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
         batch_sizes = None
         max_batch_size = input.size(0) if self.batch_first else input.size(1)
         sorted_indices = None
@@ -380,27 +382,32 @@ def forward_tensor(self, input, hx=None):
         return output, self.permute_hidden(hidden, unsorted_indices)
 
     @torch.jit.export
-    def forward_packed(self, input, hx=None):
-        # type: (PackedSequence, Optional[Tuple[Tensor, Tensor]]) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]  # noqa
-        input, batch_sizes, sorted_indices, unsorted_indices = input
+    def forward_packed(
+        self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
+    ) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]:  # noqa
+        input_, batch_sizes, sorted_indices, unsorted_indices = input
         max_batch_size = batch_sizes[0]
         max_batch_size = int(max_batch_size)
 
-        output, hidden = self.forward_impl(
-            input, hx, batch_sizes, max_batch_size, sorted_indices)
+        output_, hidden = self.forward_impl(
+            input_, hx, batch_sizes, max_batch_size, sorted_indices)
 
-        output = PackedSequence(output, batch_sizes,
+        output = PackedSequence(output_, batch_sizes,
                                 sorted_indices, unsorted_indices)
         return output, self.permute_hidden(hidden, unsorted_indices)
 
-    def permute_hidden(self, hx, permutation):
-        # type: (Tuple[Tensor, Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
+    # "type: ignore" is required due to issue #43072
+    def permute_hidden(  # type: ignore
+        self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor]
+    ) -> Tuple[Tensor, Tensor]:
         if permutation is None:
             return hx
         return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)
 
-    def check_forward_args(self, input, hidden, batch_sizes):
-        # type: (Tensor, Tuple[Tensor, Tensor], Optional[Tensor]) -> None
+    # "type: ignore" is required due to issue #43072
+    def check_forward_args(  # type: ignore
+        self, input: Tensor, hidden: Tuple[Tensor, Tensor], batch_sizes: Optional[Tensor]
+    ) -> None:
         self.check_input(input, batch_sizes)
         expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
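Two mypy rules drive this hunk. The `input_`/`output_` renames exist because mypy fixes a variable's type at its first binding, so unpacking the `PackedSequence` parameter back into the name `input` (now a `Tensor`) would be an incompatible assignment. The `# type: ignore` comments exist because `LSTM.permute_hidden` and `LSTM.check_forward_args` change the hidden-state type of the methods inherited from `RNNBase` from `Tensor` to `Tuple[Tensor, Tensor]`, which mypy reports as an incompatible override; gh-43072 tracks a cleaner fix. A compact sketch of both points, using toy classes of my own:

```python
from typing import Optional, Tuple

from torch import Tensor

class Base:
    def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor:
        return hx

class Child(Base):
    # mypy: signature incompatible with supertype "Base"; the hidden state
    # deliberately becomes a pair here, so the error is suppressed.
    def permute_hidden(  # type: ignore
        self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor]
    ) -> Tuple[Tensor, Tensor]:
        return hx

def unpack(pair: Tuple[Tensor, Tensor]) -> Tensor:
    # Writing `pair = pair[0]` would rebind `pair` from a tuple to a
    # Tensor, which mypy rejects; a fresh name side-steps that.
    first, _ = pair
    return first
```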

@@ -483,8 +490,7 @@ def check_forward_input(self, input):
                 "input has inconsistent input_size: got {}, expected {}".format(
                     input.size(1), self.input_size))
 
-    def check_forward_hidden(self, input, hx, hidden_label=''):
-        # type: (Tensor, Tensor, str) -> None
+    def check_forward_hidden(self, input: Tensor, hx: Tensor, hidden_label: str = '') -> None:
         if input.size(0) != hx.size(0):
             raise RuntimeError(
                 "Input batch size {} doesn't match hidden{} batch size {}".format(
@@ -518,6 +524,8 @@ def from_float(cls, mod):
         if dtype not in supported_scalar_types:
             raise RuntimeError('Unsupported dtype for dynamic RNN quantization: {}'.format(dtype))
 
+        qRNNCellBase: Union[LSTMCell, GRUCell, RNNCell]
+
         if type(mod) == torch.nn.LSTMCell:
             qRNNCellBase = LSTMCell(mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype)
         elif type(mod) == torch.nn.GRUCell:
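`qRNNCellBase` needs the up-front `Union` declaration because it is assigned a different cell class in each branch; without it, mypy pins the variable to the type of its first assignment (`LSTMCell`) and rejects the `GRUCell` and `RNNCell` branches. A reduced sketch with stand-in classes:

```python
from typing import Union

class LSTMCell: ...
class GRUCell: ...
class RNNCell: ...

def make_cell(kind: str) -> Union[LSTMCell, GRUCell, RNNCell]:
    cell: Union[LSTMCell, GRUCell, RNNCell]  # declared before any branch assigns it
    if kind == "lstm":
        cell = LSTMCell()
    elif kind == "gru":
        cell = GRUCell()
    else:
        cell = RNNCell()
    return cell
```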
@@ -561,7 +569,7 @@ def process_weights(weight, bias, dtype):
 
     def _weight_bias(self):
         # Returns a dict of weights and biases
-        weight_bias_dict = {'weight' : {}, 'bias' : {}}
+        weight_bias_dict: Dict[str, Dict] = {'weight' : {}, 'bias' : {}}
         w1, b1 = self._packed_weight_ih.__getstate__()[0]
         w2, b2 = self._packed_weight_hh.__getstate__()[0]
         weight_bias_dict['weight']['weight_ih'] = w1
@@ -614,8 +622,7 @@ def __init__(self, input_size, hidden_size, bias=True, nonlinearity="tanh", dtyp
     def _get_name(self):
         return 'DynamicQuantizedRNNCell'
 
-    def forward(self, input, hx=None):
-        # type: (Tensor, Optional[Tensor]) -> Tensor
+    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
         self.check_forward_input(input)
         if hx is None:
             hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
@@ -661,13 +668,12 @@ class LSTMCell(RNNCellBase):
     """
 
    def __init__(self, *args, **kwargs):
-        super(LSTMCell, self).__init__(*args, num_chunks=4, **kwargs)
+        super(LSTMCell, self).__init__(*args, num_chunks=4, **kwargs)  # type: ignore
 
     def _get_name(self):
         return 'DynamicQuantizedLSTMCell'
 
-    def forward(self, input, hx=None):
-        # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
+    def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
         self.check_forward_input(input)
         if hx is None:
             zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
@@ -707,8 +713,7 @@ def __init__(self, input_size, hidden_size, bias=True, dtype=torch.qint8):
     def _get_name(self):
         return 'DynamicQuantizedGRUCell'
 
-    def forward(self, input, hx=None):
-        # type: (Tensor, Optional[Tensor]) -> Tensor
+    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
         self.check_forward_input(input)
         if hx is None:
             hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
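For context, these are the classes that dynamic quantization swaps in for float RNN modules, so the new annotations cover a fairly common entry point. A quick usage sketch, assuming the `torch.quantization` API of this commit's vintage:

```python
import torch

float_lstm = torch.nn.LSTM(input_size=4, hidden_size=8)

# Swaps nn.LSTM for the DynamicQuantizedLSTM defined in rnn.py.
q_lstm = torch.quantization.quantize_dynamic(
    float_lstm, {torch.nn.LSTM}, dtype=torch.qint8
)

x = torch.randn(5, 3, 4)   # (seq_len, batch, input_size)
out, (h, c) = q_lstm(x)    # same calling convention as the float module
print(q_lstm)              # repr begins with 'DynamicQuantizedLSTM' via _get_name()
```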
