@@ -6,12 +6,11 @@
 import torch
 import torch.nn as nn
 from torch import Tensor  # noqa: F401
-from torch._jit_internal import Tuple, Optional, List  # noqa: F401
+from torch._jit_internal import Tuple, Optional, List, Union, Dict  # noqa: F401
 from torch.nn.utils.rnn import PackedSequence
 from torch.nn.quantized.modules.utils import _quantize_weight
 
-def apply_permutation(tensor, permutation, dim=1):
-    # type: (Tensor, Tensor, int) -> Tensor
+def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
     return tensor.index_select(dim, permutation)
 
 class PackedParameter(torch.nn.Module):
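
For readers skimming the hunk above: `apply_permutation` is a thin wrapper over `Tensor.index_select`, reordering one dimension (here the batch dimension of a hidden state) by an index tensor. A minimal sketch with illustrative shapes, not taken from the patch:

    import torch

    h = torch.arange(6.).reshape(1, 3, 2)    # (num_layers * num_directions, batch=3, hidden=2)
    perm = torch.tensor([2, 0, 1])           # new batch order
    # Equivalent to apply_permutation(h, perm, dim=1):
    print(h.index_select(1, perm))           # batch entries reordered to 2, 0, 1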
@@ -53,12 +52,14 @@ def __init__(self, mode, input_size, hidden_size,
         self.training = False
         num_directions = 2 if bidirectional else 1
 
-        if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \
-                isinstance(dropout, bool):
+        # "type: ignore" is required since ints and Numbers are not fully comparable
+        # https://github.com/python/mypy/issues/8566
+        if not isinstance(dropout, numbers.Number) \
+                or not 0 <= dropout <= 1 or isinstance(dropout, bool):  # type: ignore
             raise ValueError("dropout should be a number in range [0, 1] "
                              "representing the probability of an element being "
                              "zeroed")
-        if dropout > 0 and num_layers == 1:
+        if dropout > 0 and num_layers == 1:  # type: ignore
             warnings.warn("dropout option adds dropout after all but last "
                           "recurrent layer, so non-zero dropout expects "
                           "num_layers greater than 1, but got dropout={} and "
@@ -149,8 +150,7 @@ def __repr__(self):
         main_str += ')'
         return main_str
 
-    def check_input(self, input, batch_sizes):
-        # type: (Tensor, Optional[Tensor]) -> None
+    def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
         expected_input_dim = 2 if batch_sizes is not None else 3
         if input.dim() != expected_input_dim:
             raise RuntimeError(
@@ -161,33 +161,31 @@ def check_input(self, input, batch_sizes):
                 'input.size(-1) must be equal to input_size. Expected {}, got {}'.format(
                     self.input_size, input.size(-1)))
 
-    def get_expected_hidden_size(self, input, batch_sizes):
-        # type: (Tensor, Optional[Tensor]) -> Tuple[int, int, int]
+    def get_expected_hidden_size(self, input: Tensor, batch_sizes: Optional[Tensor]) -> Tuple[int, int, int]:
         if batch_sizes is not None:
-            mini_batch = batch_sizes[0]
-            mini_batch = int(mini_batch)
+            mini_batch = int(batch_sizes[0])
         else:
             mini_batch = input.size(0) if self.batch_first else input.size(1)
         num_directions = 2 if self.bidirectional else 1
         expected_hidden_size = (self.num_layers * num_directions,
                                 mini_batch, self.hidden_size)
         return expected_hidden_size
 
-    def check_hidden_size(self, hx, expected_hidden_size, msg='Expected hidden size {}, got {}'):
-        # type: (Tensor, Tuple[int, int, int], str) -> None
+    def check_hidden_size(
+        self, hx: Tensor, expected_hidden_size: Tuple[int, int, int],
+        msg: str = 'Expected hidden size {}, got {}'
+    ) -> None:
         if hx.size() != expected_hidden_size:
             raise RuntimeError(msg.format(
                 expected_hidden_size, list(hx.size())))
 
-    def check_forward_args(self, input, hidden, batch_sizes):
-        # type: (Tensor, Tensor, Optional[Tensor]) -> None
+    def check_forward_args(self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]) -> None:
         self.check_input(input, batch_sizes)
         expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
         self.check_hidden_size(hidden, expected_hidden_size,
                                msg='Expected hidden size {}, got {}')
 
-    def permute_hidden(self, hx, permutation):
-        # type: (Tensor, Optional[Tensor]) -> Tensor
+    def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor:
         if permutation is None:
             return hx
         return apply_permutation(hx, permutation)
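
As a sanity check on `get_expected_hidden_size`: the expected hidden shape is `(num_layers * num_directions, batch, hidden_size)`. A worked example with arbitrary values:

    num_layers, bidirectional, hidden_size = 2, True, 20
    batch = 3                                 # input.size(1) when batch_first=False
    num_directions = 2 if bidirectional else 1
    print((num_layers * num_directions, batch, hidden_size))  # (4, 3, 20)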
@@ -287,7 +285,7 @@ def quantize_and_pack(w, b):
 
     def _weight_bias(self):
         # Returns a dict of weights and biases
-        weight_bias_dict = {'weight': {}, 'bias': {}}
+        weight_bias_dict: Dict[str, Dict] = {'weight': {}, 'bias': {}}
         count = 0
         num_directions = 2 if self.bidirectional else 1
         for layer in range(self.num_layers):
@@ -337,8 +335,11 @@ def __init__(self, *args, **kwargs):
     def _get_name(self):
         return 'DynamicQuantizedLSTM'
 
-    def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
-        # type: (Tensor, Optional[Tuple[Tensor, Tensor]], Optional[Tensor], int, Optional[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
+    def forward_impl(
+        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]],
+        batch_sizes: Optional[Tensor], max_batch_size: int,
+        sorted_indices: Optional[Tensor]
+    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
         if hx is None:
             num_directions = 2 if self.bidirectional else 1
             zeros = torch.zeros(self.num_layers * num_directions,
@@ -367,8 +368,9 @@ def forward_impl(self, input, hx, batch_sizes, max_batch_size, sorted_indices):
         return output, hidden
 
     @torch.jit.export
-    def forward_tensor(self, input, hx=None):
-        # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]
+    def forward_tensor(
+        self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
+    ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
         batch_sizes = None
         max_batch_size = input.size(0) if self.batch_first else input.size(1)
         sorted_indices = None
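
`forward_tensor` is the eager path taken when the module is called on a plain `Tensor`, with `hx` defaulting to zeros. A minimal end-to-end sketch using the public `torch.quantization.quantize_dynamic` entry point; the `Tagger` container is hypothetical (quantize_dynamic swaps matching submodules, not the root module), and shapes are illustrative:

    import torch
    import torch.nn as nn

    class Tagger(nn.Module):
        def __init__(self):
            super().__init__()
            self.lstm = nn.LSTM(input_size=10, hidden_size=20, num_layers=2)

        def forward(self, x, hx=None):
            return self.lstm(x, hx)

    qmodel = torch.quantization.quantize_dynamic(Tagger(), {nn.LSTM}, dtype=torch.qint8)
    x = torch.randn(5, 3, 10)                 # (seq_len, batch, input_size)
    out, (h, c) = qmodel(x)                   # hx is None, so forward_tensor zero-fills it
    print(out.shape, h.shape, c.shape)        # (5, 3, 20), (2, 3, 20), (2, 3, 20)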
@@ -380,27 +382,32 @@ def forward_tensor(self, input, hx=None):
         return output, self.permute_hidden(hidden, unsorted_indices)
 
     @torch.jit.export
-    def forward_packed(self, input, hx=None):
-        # type: (PackedSequence, Optional[Tuple[Tensor, Tensor]]) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]  # noqa
-        input, batch_sizes, sorted_indices, unsorted_indices = input
+    def forward_packed(
+        self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
+    ) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]:  # noqa
+        input_, batch_sizes, sorted_indices, unsorted_indices = input
         max_batch_size = batch_sizes[0]
         max_batch_size = int(max_batch_size)
 
-        output, hidden = self.forward_impl(
-            input, hx, batch_sizes, max_batch_size, sorted_indices)
+        output_, hidden = self.forward_impl(
+            input_, hx, batch_sizes, max_batch_size, sorted_indices)
 
-        output = PackedSequence(output, batch_sizes,
+        output = PackedSequence(output_, batch_sizes,
                                 sorted_indices, unsorted_indices)
         return output, self.permute_hidden(hidden, unsorted_indices)
 
-    def permute_hidden(self, hx, permutation):
-        # type: (Tuple[Tensor, Tensor], Optional[Tensor]) -> Tuple[Tensor, Tensor]
+    # "type: ignore" is required due to issue #43072
+    def permute_hidden(  # type: ignore
+        self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor]
+    ) -> Tuple[Tensor, Tensor]:
         if permutation is None:
             return hx
         return apply_permutation(hx[0], permutation), apply_permutation(hx[1], permutation)
 
-    def check_forward_args(self, input, hidden, batch_sizes):
-        # type: (Tensor, Tuple[Tensor, Tensor], Optional[Tensor]) -> None
+    # "type: ignore" is required due to issue #43072
+    def check_forward_args(  # type: ignore
+        self, input: Tensor, hidden: Tuple[Tensor, Tensor], batch_sizes: Optional[Tensor]
+    ) -> None:
         self.check_input(input, batch_sizes)
         expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
@@ -483,8 +490,7 @@ def check_forward_input(self, input):
                 "input has inconsistent input_size: got {}, expected {}".format(
                     input.size(1), self.input_size))
 
-    def check_forward_hidden(self, input, hx, hidden_label=''):
-        # type: (Tensor, Tensor, str) -> None
+    def check_forward_hidden(self, input: Tensor, hx: Tensor, hidden_label: str = '') -> None:
         if input.size(0) != hx.size(0):
             raise RuntimeError(
                 "Input batch size {} doesn't match hidden{} batch size {}".format(
@@ -518,6 +524,8 @@ def from_float(cls, mod):
         if dtype not in supported_scalar_types:
             raise RuntimeError('Unsupported dtype for dynamic RNN quantization: {}'.format(dtype))
 
+        qRNNCellBase: Union[LSTMCell, GRUCell, RNNCell]
+
         if type(mod) == torch.nn.LSTMCell:
             qRNNCellBase = LSTMCell(mod.input_size, mod.hidden_size, bias=mod.bias, dtype=dtype)
         elif type(mod) == torch.nn.GRUCell:
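
Declaring `qRNNCellBase: Union[LSTMCell, GRUCell, RNNCell]` before the branches gives mypy a single type for a variable that is assigned a different concrete class per branch. The same pattern in isolation, with toy classes:

    from typing import Union

    class A: ...
    class B: ...

    def pick(flag: bool) -> Union[A, B]:
        obj: Union[A, B]                      # declare once, assign per branch
        obj = A() if flag else B()
        return obj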
@@ -561,7 +569,7 @@ def process_weights(weight, bias, dtype):
 
     def _weight_bias(self):
         # Returns a dict of weights and biases
-        weight_bias_dict = {'weight': {}, 'bias': {}}
+        weight_bias_dict: Dict[str, Dict] = {'weight': {}, 'bias': {}}
         w1, b1 = self._packed_weight_ih.__getstate__()[0]
         w2, b2 = self._packed_weight_hh.__getstate__()[0]
         weight_bias_dict['weight']['weight_ih'] = w1
@@ -614,8 +622,7 @@ def __init__(self, input_size, hidden_size, bias=True, nonlinearity="tanh", dtype=torch.qint8):
     def _get_name(self):
         return 'DynamicQuantizedRNNCell'
 
-    def forward(self, input, hx=None):
-        # type: (Tensor, Optional[Tensor]) -> Tensor
+    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
         self.check_forward_input(input)
         if hx is None:
             hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
@@ -661,13 +668,12 @@ class LSTMCell(RNNCellBase):
     """
 
     def __init__(self, *args, **kwargs):
-        super(LSTMCell, self).__init__(*args, num_chunks=4, **kwargs)
+        super(LSTMCell, self).__init__(*args, num_chunks=4, **kwargs)  # type: ignore
 
     def _get_name(self):
        return 'DynamicQuantizedLSTMCell'
 
-    def forward(self, input, hx=None):
-        # type: (Tensor, Optional[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, Tensor]
+    def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
         self.check_forward_input(input)
         if hx is None:
             zeros = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)
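
The cell variants step a single timestep and return per-step states rather than a sequence; `LSTMCell.forward` returns an `(h, c)` pair. A sketch assuming `nn.LSTMCell` is included in your PyTorch version's dynamic quantization mappings, with the same hypothetical wrapper pattern as before:

    import torch
    import torch.nn as nn

    class Step(nn.Module):
        def __init__(self):
            super().__init__()
            self.cell = nn.LSTMCell(10, 20)

        def forward(self, x, hx=None):
            return self.cell(x, hx)

    qstep = torch.quantization.quantize_dynamic(Step(), {nn.LSTMCell}, dtype=torch.qint8)
    x = torch.randn(3, 10)                    # (batch, input_size)
    h, c = qstep(x)                           # hx=None falls back to zeros, per forward above
    print(h.shape, c.shape)                   # (3, 20), (3, 20)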
@@ -707,8 +713,7 @@ def __init__(self, input_size, hidden_size, bias=True, dtype=torch.qint8):
     def _get_name(self):
         return 'DynamicQuantizedGRUCell'
 
-    def forward(self, input, hx=None):
-        # type: (Tensor, Optional[Tensor]) -> Tensor
+    def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
         self.check_forward_input(input)
         if hx is None:
             hx = torch.zeros(input.size(0), self.hidden_size, dtype=input.dtype, device=input.device)