Commit 8c82951

Author: Rong Rong (committed)
enable type checking on tensor.py
several ignores are put in due to limitations in mypy
1 parent 4a0aa69 commit 8c82951
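The ignores mentioned in the commit message are targeted inline suppressions rather than a blanket per-module exclusion. A minimal sketch of the pattern (the function below is hypothetical, not code from this commit):

def parse(value: str) -> int:
    return int(value)

# Without the trailing comment, mypy reports:
#   error: Argument 1 to "parse" has incompatible type "int"; expected "str"  [arg-type]
# Naming the error code suppresses only this diagnostic; every other check
# in the file stays active.
result = parse(42)  # type: ignore[arg-type]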

5 files changed (+65 −37 lines changed)

mypy.ini

Lines changed: 0 additions & 3 deletions
@@ -108,9 +108,6 @@ ignore_errors = True
 [mypy-torch.distributions.*]
 ignore_errors = True
 
-[mypy-torch.tensor]
-ignore_errors = True
-
 [mypy-torch._tensor_str]
 ignore_errors = True

tools/pyi/gen_pyi.py

Lines changed: 6 additions & 5 deletions
@@ -74,11 +74,7 @@
     # Somehow, these are defined in both _C and in functional. Ick!
     'broadcast_tensors',
     # Manually define named tensor type stubs in __init__.pyi.in
-    'rename',
-    'refine_names',
-    'align_to',
     'align_tensors',
-    'unflatten',
     'meshgrid',
     'cartesian_prod',
     'block_diag',
@@ -87,7 +83,6 @@
     'stft',
     'istft',
     'tensordot',
-    'norm',
     'split',
     'unique_consecutive',
     'atleast_1d',
@@ -536,6 +531,7 @@ def gen_pyi(declarations_path, out):
            'def __init__(self, other: Tensor) -> None: ...',
            'def __init__(self, size: {}, *, {}) -> None: ...'.format(type_to_python('IntArrayRef'), DEVICE_PARAM),
        ],
+       'as_subclass': ["def as_subclass(self, cls) -> Tensor: ..."],
        # clamp has no default values in the Declarations
        'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf,"
                  " *, out: Optional[Tensor]=None) -> Tensor: ..."],
@@ -546,6 +542,7 @@ def gen_pyi(declarations_path, out):
        'tolist': ['def tolist(self) -> List: ...'],
        'requires_grad_': ['def requires_grad_(self, mode: _bool=True) -> Tensor: ...'],
        'element_size': ['def element_size(self) -> _int: ...'],
+       'data_ptr': ['def data_ptr(self) -> _int: ...'],
        'dim': ['def dim(self) -> _int: ...'],
        'nonzero': ['def nonzero(self, *, as_tuple: _bool=...) -> Tensor: ...'],
        'numel': ['def numel(self) -> _int: ...'],
@@ -576,6 +573,10 @@ def gen_pyi(declarations_path, out):
        ],
        'item': ["def item(self) -> Number: ..."],
        'copy_': ["def copy_(self, src: Tensor, non_blocking: _bool=False) -> Tensor: ..."],
+       'set_': ['def set_(self, storage: Storage, offset: _int, size: Size, stride: Tuple[_int]) -> Tensor: ...',
+                'def set_(self, storage: Storage) -> Tensor: ...'],
+       'split': ['def split(self, split_size: _int, dim: _int=0) -> Union[Tuple[Tensor, ...], List[Tensor]]: ...',
+                 'def split(self, split_size: Tuple[_int], dim: _int=0) -> Union[Tuple[Tensor, ...], List[Tensor]]: ...'],
    })
    for binop in ['mul', 'div', 'true_divide', 'floor_divide']:
        for inplace in [False, True]:
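The new `set_` and `split` hints each carry two signatures. In a stub file, alternative signatures for one method are written with `@overload`; a rough hand-written sketch of the shape described by the `split` entries above, using plain builtin types instead of the stub aliases (`_int`, etc.) — simplified, not the exact generated output:

# Stub-style sketch, as it would appear in a .pyi file.
from typing import List, Tuple, Union, overload

class Tensor:
    @overload
    def split(self, split_size: int, dim: int = 0) -> Union[Tuple['Tensor', ...], List['Tensor']]: ...
    @overload
    def split(self, split_size: Tuple[int], dim: int = 0) -> Union[Tuple['Tensor', ...], List['Tensor']]: ...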

torch/_C/__init__.pyi.in

Lines changed: 7 additions & 0 deletions
@@ -87,6 +87,9 @@ ${dtype_class_hints}
 class layout:
     ...
 
+# Defined in torch/csrc/utils/disable_torch_function.cpp
+def DisableTorchFunction(): ...
+
 # Defined in torch/csrc/utils/tensor_layouts.cpp
 strided : layout = ...
 sparse_coo : layout = ...
@@ -105,6 +108,10 @@ class qscheme: ...
 
 # Defined in torch/csrc/utils/tensor_qschemes.cpp
 per_tensor_affine: qscheme = ...
+per_channel_affine: qscheme = ...
+per_tensor_symmetric: qscheme = ...
+per_channel_symmetric: qscheme = ...
+per_channel_affine_float_qparams: qscheme = ...
 
 # Defined in torch/csrc/autograd/python_function.cpp
 class _FunctionBase(object):
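The qscheme constants added here are referenced from torch/tensor.py (for example in __deepcopy__ and __reduce_ex__), so these stub entries let mypy resolve the attribute lookups on the torch module. A small illustration of the kind of code that relies on them (the helper is hypothetical, not from this commit):

import torch

def is_per_channel(t: torch.Tensor) -> bool:
    # per_channel_affine and friends are plain qscheme constants; without
    # stub entries, mypy would flag them as undefined attributes of torch.
    return t.is_quantized and t.qscheme() in (
        torch.per_channel_affine,
        torch.per_channel_symmetric,
    )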

torch/tensor.py

Lines changed: 38 additions & 27 deletions
@@ -7,6 +7,7 @@
 import warnings
 import weakref
 from torch._C import _add_docstr
+from typing import Any, Dict, Tuple, Union
 from numbers import Number
 import functools
 from typing import Optional
@@ -53,6 +54,8 @@ def __deepcopy__(self, memo):
             else:
                 new_storage = self.storage().__deepcopy__(memo)
                 if self.is_quantized:
+                    # quantizer_params can be different type based on torch attribute
+                    quantizer_params: Union[Tuple[torch.qscheme, float, int], Tuple[torch.qscheme, Tensor, Tensor, int]]
                     if self.qscheme() == torch.per_tensor_affine:
                         quantizer_params = self.qscheme(), self.q_scale(), self.q_zero_point()
                     elif self.qscheme() in (torch.per_channel_affine, torch.per_channel_affine_float_qparams):
@@ -72,7 +75,7 @@ def __deepcopy__(self, memo):
                                                       self._backward_hooks)
                 else:
                     new_tensor = self.new()
-                    new_tensor.set_(new_storage, self.storage_offset(), self.size(), self.stride())
+                    new_tensor.set_(new_storage, self.storage_offset(), self.size(), self.stride())
                 new_tensor.requires_grad = self.requires_grad
             memo[id(self)] = new_tensor
             return new_tensor
@@ -85,6 +88,7 @@ def __reduce_ex__(self, proto):
         check_serializing_named_tensor(self)
         # See Note [Don't serialize hooks]
         torch.utils.hooks.warn_if_has_hooks(self)
+        backward_hooks: Dict[Any, Any] = OrderedDict()
         # Note: Numpy array is chosen to be the rebuild component for XLA Tensor.
         # We considered a few options:
         # 1. CPU tensor can't be used here.
@@ -96,12 +100,14 @@ def __reduce_ex__(self, proto):
         # `tolist()` converts every single element in the tensor into python objects
         # and serialize them one by one.
         if self.device.type == 'xla':
-            args = (self.cpu().numpy(),
-                    self.dtype,
-                    str(self.device),
-                    self.requires_grad)
-            return (torch._utils._rebuild_xla_tensor, args)
+            arg_xla = (self.cpu().numpy(),
+                       self.dtype,
+                       str(self.device),
+                       self.requires_grad)
+            return (torch._utils._rebuild_xla_tensor, arg_xla)
         if self.is_quantized:
+            # quantizer_params can be different type based on torch attribute
+            quantizer_params: Union[Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int]]
             if self.qscheme() == torch.per_tensor_affine:
                 quantizer_params = (torch.per_tensor_affine,
                                     self.q_scale(),
@@ -116,31 +122,31 @@ def __reduce_ex__(self, proto):
                                     self.q_per_channel_axis())
             else:
                 raise RuntimeError(f"Serialization is not supported for tensors of type {self.qscheme()}")
-            args = (self.storage(),
-                    self.storage_offset(),
-                    tuple(self.size()),
-                    self.stride(),
-                    quantizer_params,
-                    self.requires_grad,
-                    OrderedDict())
-            return (torch._utils._rebuild_qtensor, args)
+            args_qtensor = (self.storage(),
+                            self.storage_offset(),
+                            tuple(self.size()),
+                            self.stride(),
+                            quantizer_params,
+                            self.requires_grad,
+                            backward_hooks)
+            return (torch._utils._rebuild_qtensor, args_qtensor)
         elif self.is_sparse:
             if self.layout == torch.sparse_coo:
-                args = (self.layout,
-                        (self._indices(),
-                         self._values(),
-                         self.size()))
+                args_sparse = (self.layout,
+                               (self._indices(),
+                                self._values(),
+                                self.size()))
             else:
                 raise NotImplementedError(
                     'sparse tensor __reduce_ex__ for layout `%s`' % (self.layout))
-            return (torch._utils._rebuild_sparse_tensor, args)
+            return (torch._utils._rebuild_sparse_tensor, args_sparse)
         else:
             args = (self.storage(),
                     self.storage_offset(),
                     tuple(self.size()),
                     self.stride(),
                     self.requires_grad,
-                    OrderedDict())  # previously was self._backward_hooks
+                    backward_hooks)  # previously was self._backward_hooks
             return (torch._utils._rebuild_tensor_v2, args)
 
     def __setstate__(self, state):
@@ -154,7 +160,7 @@ def __setstate__(self, state):
             raise RuntimeError('__setstate__ can be only called on leaf Tensors')
         if len(state) == 4:
             # legacy serialization of Tensor
-            self.set_(*state)
+            self.set_(*state)
             return
         elif len(state) == 5:
             # legacy serialization of Variable
@@ -528,7 +534,7 @@ def __format__(self, format_spec):
             return self.item().__format__(format_spec)
         return object.__format__(self, format_spec)
 
-    def __ipow__(self, other):
+    def __ipow__(self, other):  # type: ignore[misc]
         relevant_args = (self, other)
         from torch.overrides import has_torch_function, handle_torch_function
         if type(self) is not Tensor and type(other) is not Tensor and has_torch_function(relevant_args):
@@ -652,7 +658,8 @@ def __contains__(self, element):
         if type(self) is not Tensor and has_torch_function(relevant_args):
             return handle_torch_function(Tensor.__contains__, relevant_args, self, element)
         if isinstance(element, (torch.Tensor, Number)):
-            return (element == self).any().item()
+            # type hint doesn't understand the __contains__ result array
+            return (element == self).any().item()  # type: ignore[union-attr]
 
         raise RuntimeError(
             "Tensor.__contains__ only supports Tensor or scalar, but you passed in a %s." %
@@ -669,7 +676,8 @@ def __cuda_array_interface__(self):
         relevant_args = (self,)
         from torch.overrides import has_torch_function, handle_torch_function
         if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.__cuda_array_interface__.__get__, relevant_args, self)
+            # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
+            return handle_torch_function(Tensor.__cuda_array_interface__.__get__, relevant_args, self)  # type: ignore[attr-defined]
 
         # raise AttributeError for unsupported tensors, so that
         # hasattr(cpu_tensor, "__cuda_array_interface__") is False.
@@ -936,7 +944,8 @@ def grad(self):
         relevant_args = (self,)
         from torch.overrides import has_torch_function, handle_torch_function
         if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.grad.__get__, relevant_args, self)
+            # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
+            return handle_torch_function(Tensor.grad.__get__, relevant_args, self)  # type: ignore[attr-defined]
 
         if self.requires_grad and not hasattr(self, "retains_grad") and not self.is_leaf and self._grad is None:
             warnings.warn("The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad "
@@ -951,15 +960,17 @@ def grad(self, new_grad):
         relevant_args = (self,)
         from torch.overrides import has_torch_function, handle_torch_function
         if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.grad.__set__, relevant_args, self, new_grad)
+            # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
+            return handle_torch_function(Tensor.grad.__set__, relevant_args, self, new_grad)  # type: ignore[attr-defined]
         self._grad = new_grad
 
     @grad.deleter
     def grad(self):
         relevant_args = (self,)
         from torch.overrides import has_torch_function, handle_torch_function
         if type(self) is not Tensor and has_torch_function(relevant_args):
-            return handle_torch_function(Tensor.grad.__delete__, relevant_args, self)
+            # TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
+            return handle_torch_function(Tensor.grad.__delete__, relevant_args, self)  # type: ignore[attr-defined]
         del self._grad
 
     @classmethod
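Two patterns recur in this file. Variables such as quantizer_params are annotated with a Union before the branches that assign tuples of different shapes, and the single reused args variable in __reduce_ex__ is split into branch-specific names (arg_xla, args_qtensor, args_sparse) so each keeps one consistent type. A minimal sketch of the annotate-before-branching pattern (toy types, not PyTorch code):

from typing import Tuple, Union

def pack(per_channel: bool) -> Tuple[object, ...]:
    # Without the up-front annotation, mypy infers the type of `params`
    # from the first assignment and rejects the other branch as an
    # incompatible assignment; declaring the Union first allows both.
    params: Union[Tuple[float, int], Tuple[str, float, int]]
    if per_channel:
        params = ('channel', 0.5, 0)
    else:
        params = (0.5, 0)
    return params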

torch/types.py

Lines changed: 14 additions & 2 deletions
@@ -34,13 +34,25 @@
 class Storage(object):
     _cdata: int
 
+    def __deepcopy__(self, memo) -> 'Storage':
+        ...
+
+    def _new_shared(self, int) -> 'Storage':
+        ...
+
     def _write_file(self, f: Any, is_real_file: _bool, save_size: _bool) -> None:
         ...
 
-    def size(self) -> int:
+    def element_size(self) -> int:
         ...
 
-    def _new_shared(self, int) -> 'Storage':
+    def is_shared(self) -> bool:
+        ...
+
+    def share_memory_(self) -> 'Storage':
+        ...
+
+    def size(self) -> int:
         ...
 
     ...
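The extra Storage stub methods mirror operations that the now-checked torch/tensor.py performs on storages, such as the self.storage().__deepcopy__(memo) call visible in the __deepcopy__ hunk above. A quick runtime check of the corresponding attributes, assuming a standard PyTorch install:

import torch

t = torch.zeros(4)
s = t.storage()
# These calls correspond to the stubbed methods above.
print(s.size(), s.element_size(), s.is_shared())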
