@@ -7,6 +7,7 @@
 import warnings
 import weakref
 from torch._C import _add_docstr
+from typing import Any, Dict, Tuple, Union
 from numbers import Number
 import functools
 from typing import Optional
@@ -53,6 +54,8 @@ def __deepcopy__(self, memo):
             else:
                 new_storage = self.storage().__deepcopy__(memo)
                 if self.is_quantized:
+                    # quantizer_params can be different type based on torch attribute
+                    quantizer_params: Union[Tuple[torch.qscheme, float, int], Tuple[torch.qscheme, Tensor, Tensor, int]]
                     if self.qscheme() == torch.per_tensor_affine:
                         quantizer_params = self.qscheme(), self.q_scale(), self.q_zero_point()
                     elif self.qscheme() in (torch.per_channel_affine, torch.per_channel_affine_float_qparams):
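Aside on the annotation added in this hunk (a minimal standalone sketch, not part of the patch; all names below are invented for illustration): when a variable is assigned tuples of different shapes in different branches, mypy infers its type from the first assignment and rejects the later ones, so the Union is declared up front before the if/elif, exactly as done for quantizer_params above.

from typing import Tuple, Union

def pick_params(per_tensor: bool) -> str:
    # Without this explicit Union, mypy would infer Tuple[str, float, int] from
    # the first branch and flag the second assignment as incompatible.
    params: Union[Tuple[str, float, int], Tuple[str, bytes, bytes, int]]
    if per_tensor:
        params = ("per_tensor_affine", 0.1, 0)
    else:
        params = ("per_channel_affine", b"scales", b"zero_points", 0)
    return repr(params)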
@@ -72,7 +75,7 @@ def __deepcopy__(self, memo):
                         self._backward_hooks)
                 else:
                     new_tensor = self.new()
-                    new_tensor.set_(new_storage, self.storage_offset(), self.size(), self.stride())
+                    new_tensor.set_(new_storage, self.storage_offset(), self.size(), self.stride())
                     new_tensor.requires_grad = self.requires_grad
             memo[id(self)] = new_tensor
             return new_tensor
@@ -85,6 +88,7 @@ def __reduce_ex__(self, proto):
         check_serializing_named_tensor(self)
         # See Note [Don't serialize hooks]
         torch.utils.hooks.warn_if_has_hooks(self)
+        backward_hooks: Dict[Any, Any] = OrderedDict()
         # Note: Numpy array is chosen to be the rebuild component for XLA Tensor.
         # We considered a few options:
         # 1. CPU tensor can't be used here.
@@ -96,12 +100,14 @@ def __reduce_ex__(self, proto):
         #    `tolist()` converts every single element in the tensor into python objects
         #    and serialize them one by one.
         if self.device.type == 'xla':
-            args = (self.cpu().numpy(),
-                    self.dtype,
-                    str(self.device),
-                    self.requires_grad)
-            return (torch._utils._rebuild_xla_tensor, args)
+            arg_xla = (self.cpu().numpy(),
+                       self.dtype,
+                       str(self.device),
+                       self.requires_grad)
+            return (torch._utils._rebuild_xla_tensor, arg_xla)
         if self.is_quantized:
+            # quantizer_params can be different type based on torch attribute
+            quantizer_params: Union[Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int]]
             if self.qscheme() == torch.per_tensor_affine:
                 quantizer_params = (torch.per_tensor_affine,
                                     self.q_scale(),
@@ -116,31 +122,31 @@ def __reduce_ex__(self, proto):
                                     self.q_per_channel_axis())
             else:
                 raise RuntimeError(f"Serialization is not supported for tensors of type {self.qscheme()}")
-            args = (self.storage(),
-                    self.storage_offset(),
-                    tuple(self.size()),
-                    self.stride(),
-                    quantizer_params,
-                    self.requires_grad,
-                    OrderedDict())
-            return (torch._utils._rebuild_qtensor, args)
+            args_qtensor = (self.storage(),
+                            self.storage_offset(),
+                            tuple(self.size()),
+                            self.stride(),
+                            quantizer_params,
+                            self.requires_grad,
+                            backward_hooks)
+            return (torch._utils._rebuild_qtensor, args_qtensor)
         elif self.is_sparse:
             if self.layout == torch.sparse_coo:
-                args = (self.layout,
-                        (self._indices(),
-                         self._values(),
-                         self.size()))
+                args_sparse = (self.layout,
+                               (self._indices(),
+                                self._values(),
+                                self.size()))
             else:
                 raise NotImplementedError(
                     'sparse tensor __reduce_ex__ for layout `%s`' % (self.layout))
-            return (torch._utils._rebuild_sparse_tensor, args)
+            return (torch._utils._rebuild_sparse_tensor, args_sparse)
         else:
             args = (self.storage(),
                     self.storage_offset(),
                     tuple(self.size()),
                     self.stride(),
                     self.requires_grad,
-                    OrderedDict())  # previously was self._backward_hooks
+                    backward_hooks)  # previously was self._backward_hooks
             return (torch._utils._rebuild_tensor_v2, args)

     def __setstate__(self, state):
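A note on the renames in the two hunks above (sketch under assumed names, not PyTorch code): reusing a single args variable for tuples of different arity makes mypy pin its type to the first assignment and reject the rest, so each branch gets its own name (arg_xla, args_qtensor, args_sparse), and one annotated backward_hooks mapping replaces the repeated bare OrderedDict() calls, which would otherwise typically need their own annotations.

from collections import OrderedDict
from typing import Any, Dict, Tuple

def pack(kind: str) -> Tuple[Any, ...]:
    # Annotate the empty mapping once; a bare OrderedDict() usually draws a
    # "Need type annotation" error from mypy.
    hooks: Dict[Any, Any] = OrderedDict()
    if kind == "xla":
        # A distinct name per branch keeps the inferred tuple types from clashing.
        args_xla: Tuple[bytes, bool] = (b"ndarray", True)
        return args_xla
    args_dense: Tuple[str, int, Dict[Any, Any]] = ("storage", 0, hooks)
    return args_dense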
@@ -154,7 +160,7 @@ def __setstate__(self, state):
             raise RuntimeError('__setstate__ can be only called on leaf Tensors')
         if len(state) == 4:
             # legacy serialization of Tensor
-            self.set_(*state)
+            self.set_(*state)
             return
         elif len(state) == 5:
             # legacy serialization of Variable
@@ -528,7 +534,7 @@ def __format__(self, format_spec):
             return self.item().__format__(format_spec)
         return object.__format__(self, format_spec)

-    def __ipow__(self, other):
+    def __ipow__(self, other):  # type: ignore[misc]
         relevant_args = (self, other)
         from torch.overrides import has_torch_function, handle_torch_function
         if type(self) is not Tensor and type(other) is not Tensor and has_torch_function(relevant_args):
@@ -652,7 +658,8 @@ def __contains__(self, element):
         if type(self) is not Tensor and has_torch_function(relevant_args):
             return handle_torch_function(Tensor.__contains__, relevant_args, self, element)
         if isinstance(element, (torch.Tensor, Number)):
-            return (element == self).any().item()
+            # type hint doesn't understand the __contains__ result array
+            return (element == self).any().item()  # type: ignore[union-attr]

         raise RuntimeError(
             "Tensor.__contains__ only supports Tensor or scalar, but you passed in a %s." %
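Background on the union-attr ignore above (a generic mypy example, unrelated to torch internals): the error code fires when an attribute is looked up on a union whose members do not all provide it, which is what mypy infers for the (element == self) comparison here.

from typing import Union

def shout(value: Union[str, int]) -> str:
    # mypy: Item "int" of "Union[str, int]" has no attribute "upper"  [union-attr]
    # A targeted ignore silences only this line, as the diff does for .any().item().
    return value.upper()  # type: ignore[union-attr]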
@@ -669,7 +676,8 @@ def __cuda_array_interface__(self):
         relevant_args = (self,)
         from torch.overrides import has_torch_function, handle_torch_function
         if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.__cuda_array_interface__.__get__, relevant_args, self)
+            # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
+            return handle_torch_function(Tensor.__cuda_array_interface__.__get__, relevant_args, self)  # type: ignore[attr-defined]

         # raise AttributeError for unsupported tensors, so that
         # hasattr(cpu_tensor, "__cuda_array_interface__") is False.
@@ -936,7 +944,8 @@ def grad(self):
         relevant_args = (self,)
         from torch.overrides import has_torch_function, handle_torch_function
         if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.grad.__get__, relevant_args, self)
+            # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
+            return handle_torch_function(Tensor.grad.__get__, relevant_args, self)  # type: ignore[attr-defined]

         if self.requires_grad and not hasattr(self, "retains_grad") and not self.is_leaf and self._grad is None:
             warnings.warn("The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad "
@@ -951,15 +960,17 @@ def grad(self, new_grad):
         relevant_args = (self,)
         from torch.overrides import has_torch_function, handle_torch_function
         if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.grad.__set__, relevant_args, self, new_grad)
+            # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
+            return handle_torch_function(Tensor.grad.__set__, relevant_args, self, new_grad)  # type: ignore[attr-defined]
         self._grad = new_grad

     @grad.deleter
     def grad(self):
         relevant_args = (self,)
         from torch.overrides import has_torch_function, handle_torch_function
         if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.grad.__delete__, relevant_args, self)
+            # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
+            return handle_torch_function(Tensor.grad.__delete__, relevant_args, self)  # type: ignore[attr-defined]
         del self._grad

     @classmethod
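For reference on the repeated attr-defined ignores in the property hunks above (standalone sketch with invented names, not PyTorch code): at the time of this patch mypy resolves class-level access to a @property as the getter's return type rather than the property object (see python/mypy#6185), so reaching for .__get__/.__set__/.__delete__ on it, the way handle_torch_function is invoked here, is reported as a missing attribute.

from typing import Any

class Widget:
    @property
    def size(self) -> int:
        return 42

# At runtime Widget.size is the property object, so __get__ exists; mypy sees
# "int" instead and reports attr-defined, hence the targeted ignore.
getter: Any = Widget.size.__get__  # type: ignore[attr-defined]
value = getter(Widget())  # calls the getter explicitly -> 42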