3 changes: 0 additions & 3 deletions mypy.ini
@@ -108,9 +108,6 @@ ignore_errors = True
[mypy-torch.distributions.*]
ignore_errors = True

[mypy-torch.tensor]
ignore_errors = True

[mypy-torch._tensor_str]
ignore_errors = True

11 changes: 6 additions & 5 deletions tools/pyi/gen_pyi.py
@@ -74,11 +74,7 @@
# Somehow, these are defined in both _C and in functional. Ick!
'broadcast_tensors',
# Manually define named tensor type stubs in __init__.pyi.in
'rename',
'refine_names',
'align_to',
'align_tensors',
'unflatten',
'meshgrid',
'cartesian_prod',
'block_diag',
@@ -87,7 +83,6 @@
'stft',
'istft',
'tensordot',
'norm',
'split',
'unique_consecutive',
'atleast_1d',
@@ -536,6 +531,7 @@ def gen_pyi(declarations_path, out):
'def __init__(self, other: Tensor) -> None: ...',
'def __init__(self, size: {}, *, {}) -> None: ...'.format(type_to_python('IntArrayRef'), DEVICE_PARAM),
],
'as_subclass': ["def as_subclass(self, cls: Tensor) -> Tensor: ..."],
# clamp has no default values in the Declarations
'clamp': ["def clamp(self, min: _float=-inf, max: _float=inf,"
" *, out: Optional[Tensor]=None) -> Tensor: ..."],
@@ -546,6 +542,7 @@
'tolist': ['def tolist(self) -> List: ...'],
'requires_grad_': ['def requires_grad_(self, mode: _bool=True) -> Tensor: ...'],
'element_size': ['def element_size(self) -> _int: ...'],
'data_ptr': ['def data_ptr(self) -> _int: ...'],
'dim': ['def dim(self) -> _int: ...'],
'nonzero': ['def nonzero(self, *, as_tuple: _bool=...) -> Tensor: ...'],
'numel': ['def numel(self) -> _int: ...'],
@@ -576,6 +573,10 @@
],
'item': ["def item(self) -> Number: ..."],
'copy_': ["def copy_(self, src: Tensor, non_blocking: _bool=False) -> Tensor: ..."],
'set_': ['def set_(self, storage: Storage, offset: _int, size: _size, stride: _size) -> Tensor: ...',
'def set_(self, storage: Storage) -> Tensor: ...'],
'split': ['def split(self, split_size: _int, dim: _int=0) -> Sequence[Tensor]: ...',
'def split(self, split_size: Tuple[_int, ...], dim: _int=0) -> Sequence[Tensor]: ...'],
})
for binop in ['mul', 'div', 'true_divide', 'floor_divide']:
for inplace in [False, True]:
7 changes: 7 additions & 0 deletions torch/_C/__init__.pyi.in
@@ -87,6 +87,9 @@ ${dtype_class_hints}
class layout:
...

# Defined in torch/csrc/utils/disable_torch_function.cpp
def DisableTorchFunction(): ...

# Defined in torch/csrc/utils/tensor_layouts.cpp
strided : layout = ...
sparse_coo : layout = ...
@@ -105,6 +108,10 @@ class qscheme: ...

# Defined in torch/csrc/utils/tensor_qschemes.cpp
per_tensor_affine: qscheme = ...
per_channel_affine: qscheme = ...
per_tensor_symmetric: qscheme = ...
per_channel_symmetric: qscheme = ...
per_channel_affine_float_qparams: qscheme = ...

# Defined in torch/csrc/autograd/python_function.cpp
class _FunctionBase(object):
61 changes: 36 additions & 25 deletions torch/tensor.py
@@ -7,6 +7,7 @@
import warnings
import weakref
from torch._C import _add_docstr
from typing import Any, Dict, Tuple, Union
from numbers import Number
import functools
from typing import Optional
@@ -53,6 +54,8 @@ def __deepcopy__(self, memo):
else:
new_storage = self.storage().__deepcopy__(memo)
if self.is_quantized:
# quantizer_params can be of different types depending on the tensor's qscheme
quantizer_params: Union[Tuple[torch.qscheme, float, int], Tuple[torch.qscheme, Tensor, Tensor, int]]
if self.qscheme() == torch.per_tensor_affine:
quantizer_params = self.qscheme(), self.q_scale(), self.q_zero_point()
elif self.qscheme() in (torch.per_channel_affine, torch.per_channel_affine_float_qparams):
@@ -85,6 +88,7 @@ def __reduce_ex__(self, proto):
check_serializing_named_tensor(self)
# See Note [Don't serialize hooks]
torch.utils.hooks.warn_if_has_hooks(self)
backward_hooks: Dict[Any, Any] = OrderedDict()
# Note: Numpy array is chosen to be the rebuild component for XLA Tensor.
# We considered a few options:
# 1. CPU tensor can't be used here.
@@ -96,12 +100,14 @@ def __reduce_ex__(self, proto):
# `tolist()` converts every single element in the tensor into python objects
# and serializes them one by one.
if self.device.type == 'xla':
args = (self.cpu().numpy(),
self.dtype,
str(self.device),
self.requires_grad)
return (torch._utils._rebuild_xla_tensor, args)
arg_xla = (self.cpu().numpy(),
self.dtype,
str(self.device),
self.requires_grad)
return (torch._utils._rebuild_xla_tensor, arg_xla)
if self.is_quantized:
# quantizer_params can be of different types depending on the tensor's qscheme
quantizer_params: Union[Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int]]
if self.qscheme() == torch.per_tensor_affine:
quantizer_params = (torch.per_tensor_affine,
self.q_scale(),
@@ -116,31 +122,31 @@ def __reduce_ex__(self, proto):
self.q_per_channel_axis())
else:
raise RuntimeError(f"Serialization is not supported for tensors of type {self.qscheme()}")
args = (self.storage(),
self.storage_offset(),
tuple(self.size()),
self.stride(),
quantizer_params,
self.requires_grad,
OrderedDict())
return (torch._utils._rebuild_qtensor, args)
args_qtensor = (self.storage(),
self.storage_offset(),
tuple(self.size()),
self.stride(),
quantizer_params,
self.requires_grad,
backward_hooks)
return (torch._utils._rebuild_qtensor, args_qtensor)
elif self.is_sparse:
if self.layout == torch.sparse_coo:
args = (self.layout,
(self._indices(),
self._values(),
self.size()))
args_sparse = (self.layout,
(self._indices(),
self._values(),
self.size()))
else:
raise NotImplementedError(
'sparse tensor __reduce_ex__ for layout `%s`' % (self.layout))
return (torch._utils._rebuild_sparse_tensor, args)
return (torch._utils._rebuild_sparse_tensor, args_sparse)
else:
args = (self.storage(),
self.storage_offset(),
tuple(self.size()),
self.stride(),
self.requires_grad,
OrderedDict()) # previously was self._backward_hooks
backward_hooks) # previously was self._backward_hooks
return (torch._utils._rebuild_tensor_v2, args)

def __setstate__(self, state):
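As an aside, here is a minimal generic sketch of the annotation pattern used for quantizer_params above: declare the variable with a Union type up front so mypy accepts the differently shaped tuples assigned in each branch. All names below are illustrative, not PyTorch API.

from typing import Tuple, Union

def make_params(per_tensor: bool) -> Union[Tuple[str, float, int], Tuple[str, float, int, int]]:
    # Pre-declaring the union lets each branch assign a different tuple shape.
    params: Union[Tuple[str, float, int], Tuple[str, float, int, int]]
    if per_tensor:
        params = ("per_tensor_affine", 0.1, 0)
    else:
        params = ("per_channel_affine", 0.1, 0, 1)
    return params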
@@ -528,7 +534,7 @@ def __format__(self, format_spec):
return self.item().__format__(format_spec)
return object.__format__(self, format_spec)

def __ipow__(self, other):
def __ipow__(self, other): # type: ignore[misc]
Contributor Author:
This is probably the only issue where I don't know the right way to fix it: apparently __ipow__ is not defined in the generated _C/__init__.pyi, and on the previous line __pow__ = _C._TensorBase.pow is used instead of _C._TensorBase.__pow__ (which is defined, by the way). It got me a bit confused.

Contributor:
Reading the code here, it sounds like we don't implement __ipow__ (that's why we raise NotImplemented), so it isn't unexpected that it's not defined in _C. What's the problem?

Contributor Author:
Oops, I forgot to mention the problem: mypy expects identical signatures for __pow__ and __ipow__, since a **= b is equivalent to a = a ** b. Yes, I think we are expected not to implement this function, so I will just leave the ignore in place for now.

relevant_args = (self, other)
from torch.overrides import has_torch_function, handle_torch_function
if type(self) is not Tensor and type(other) is not Tensor and has_torch_function(relevant_args):
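For illustration only (not part of this PR), a minimal sketch of the mypy check that the ignore above suppresses; the class name is made up and the error text is approximate.

class Box:
    def __pow__(self, other: "Box") -> "Box":
        return self

    # mypy: Signatures of "__ipow__" and "__pow__" are incompatible  [misc],
    # because a **= b falls back to a = a ** b when no in-place form exists.
    def __ipow__(self, other: int) -> "Box":
        return self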
@@ -652,7 +658,8 @@ def __contains__(self, element):
if type(self) is not Tensor and has_torch_function(relevant_args):
return handle_torch_function(Tensor.__contains__, relevant_args, self, element)
if isinstance(element, (torch.Tensor, Number)):
return (element == self).any().item()
# the type stubs don't capture that (element == self) is a Tensor here
return (element == self).any().item() # type: ignore[union-attr]

raise RuntimeError(
"Tensor.__contains__ only supports Tensor or scalar, but you passed in a %s." %
@@ -669,7 +676,8 @@ def __cuda_array_interface__(self):
relevant_args = (self,)
from torch.overrides import has_torch_function, handle_torch_function
if type(self) is not Tensor and has_torch_function(relevant_args):
return handle_torch_function(Tensor.__cuda_array_interface__.__get__, relevant_args, self)
# TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
return handle_torch_function(Tensor.__cuda_array_interface__.__get__, relevant_args, self) # type: ignore[attr-defined]

# raise AttributeError for unsupported tensors, so that
# hasattr(cpu_tensor, "__cuda_array_interface__") is False.
@@ -936,7 +944,8 @@ def grad(self):
relevant_args = (self,)
from torch.overrides import has_torch_function, handle_torch_function
if type(self) is not Tensor and has_torch_function(relevant_args):
return handle_torch_function(Tensor.grad.__get__, relevant_args, self)
# TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
return handle_torch_function(Tensor.grad.__get__, relevant_args, self) # type: ignore[attr-defined]

if self.requires_grad and not hasattr(self, "retains_grad") and not self.is_leaf and self._grad is None:
warnings.warn("The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad "
@@ -951,15 +960,17 @@ def grad(self, new_grad):
relevant_args = (self,)
from torch.overrides import has_torch_function, handle_torch_function
if type(self) is not Tensor and has_torch_function(relevant_args):
return handle_torch_function(Tensor.grad.__set__, relevant_args, self, new_grad)
# TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
return handle_torch_function(Tensor.grad.__set__, relevant_args, self, new_grad) # type: ignore[attr-defined]
self._grad = new_grad

@grad.deleter
def grad(self):
relevant_args = (self,)
from torch.overrides import has_torch_function, handle_torch_function
if type(self) is not Tensor and has_torch_function(relevant_args):
return handle_torch_function(Tensor.grad.__delete__, relevant_args, self)
# TODO mypy doesn't support @property, see: https://github.com/python/mypy/issues/6185
return handle_torch_function(Tensor.grad.__delete__, relevant_args, self) # type: ignore[attr-defined]
del self._grad

@classmethod
16 changes: 14 additions & 2 deletions torch/types.py
@@ -34,13 +34,25 @@
class Storage(object):
_cdata: int

def __deepcopy__(self, memo) -> 'Storage':
...

def _new_shared(self, int) -> 'Storage':
...

def _write_file(self, f: Any, is_real_file: _bool, save_size: _bool) -> None:
...

def size(self) -> int:
def element_size(self) -> int:
...

def _new_shared(self, int) -> 'Storage':
Contributor:
Why are you deleting _new_shared() here?

Contributor Author:
I was reordering them into alphabetical order :-)

def is_shared(self) -> bool:
...

def share_memory_(self) -> 'Storage':
...

def size(self) -> int:
...

...
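For context only (nothing here is added by this PR), a rough sketch of the kind of calls from torch/tensor.py that these Storage stubs are meant to cover, assuming an ordinary CPU tensor:

import copy
import torch

t = torch.arange(4.0)
s = t.storage()              # a Storage object
s_copy = copy.deepcopy(s)    # exercises Storage.__deepcopy__
print(s.size(), s.element_size(), s.is_shared())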