Description
🚀 Feature
Add the ability to subclass torch.Tensor so that you can use isinstance(tensor, CustomTensor) to determine the "context" of the data stored in the subclass object. I think #9515 describes the need for dtype promotion logic, but this proposal is for Python type promotion logic.
The __array_wrap__ method already exists in numpy and ensures that the following ufunc calls preserve the custom tensor type:
custom_tensor = CustomTensor(torch.tensor([1, 2, 3, 4]))
new_tensor = np.add(custom_tensor, 1.0)  # -> CustomTensor
new_tensor = np.add(1.0, custom_tensor)  # -> CustomTensor (Hmmm...)
isinstance(new_tensor, CustomTensor)  # returns True

The equivalent __tensor_wrap__ method does not exist in PyTorch. As a result, the custom type information is easily lost:

custom_tensor = CustomTensor(torch.tensor([1, 2, 3, 4]))
new_tensor = custom_tensor.add(1.0)  # -> Tensor
isinstance(new_tensor, CustomTensor)  # returns False
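For reference, here is a minimal, self-contained sketch of that lost-subclass behaviour. It assumes Tensor._make_subclass (the same helper used in the full example further down) and reflects the PyTorch version this issue was written against:

import torch

class CustomTensor(torch.Tensor):
    pass

# Build a subclass instance the same way nn.Parameter does internally.
x = torch.Tensor._make_subclass(CustomTensor, torch.tensor([1.0, 2.0, 3.0, 4.0]))
y = x.add(1.0)

print(isinstance(x, CustomTensor))  # True
print(isinstance(y, CustomTensor))  # False -- the subclass information is dropped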
Motivation
The torch.Tensor class provides a dynamically typed wrapper over the statically typed Torch classes (FloatTensor, DoubleTensor, HalfTensor, etc.) using a numpy-style dtype. While this largely removes the need to subclass Tensor, it would still be nice to subclass Tensor in order to provide custom methods specific to the context of the data.
For example:
class TimeDomainTensor(torch.Tensor):
    def __repr__(self):
        return 'Time Information:\n' + super(TimeDomainTensor, self).__repr__()

class FrequencyDomainTensor(torch.Tensor):
    def __repr__(self):
        return 'Frequency Information:\n' + super(FrequencyDomainTensor, self).__repr__()

class SurfaceTensor(torch.Tensor):
    def __repr__(self):
        return 'Surface Information:\n' + super(SurfaceTensor, self).__repr__()

Pitch
Custom subclassing is already utilized by the torch.nn.Parameter class, so most of the heavy lifting is already done. Note that performing a math operation on a Parameter also downcasts the Parameter to a Tensor.
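As a quick illustration of that downcast (behaviour as of the PyTorch version this issue targets):

import torch
import torch.nn as nn

p = nn.Parameter(torch.ones(3))
print(isinstance(p, nn.Parameter))        # True
print(isinstance(p + 1.0, nn.Parameter))  # False -- the result is a plain Tensor
print(isinstance(p + 1.0, torch.Tensor))  # True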
Alternatives
Numpy provides the following methods for customizing the output behaviour of a ufunc.
- def __array_wrap__(self, array):
  A subclass can override what happens to the result after a numpy ufunc has been called.
- def __array_ufunc__(ufunc, method, *inputs, **kwargs):
  A subclass can override what happens both before and after a numpy ufunc is called. New in numpy version 1.13.
- def __array_finalize__(self, obj):
  The ability to insert any kind of custom attributes, events, io, etc. This may not work on the GPU because it would open up the ability to define highly customized data containers. It is invoked in three situations (see the sketch after this list):
  a. explicit constructor call (obj = MySubClass(params))
  b. view casting
  c. creating new from template
- Do Nothing:
  Revert to a functional implementation; simply replace all custom_tensor.method(*args, **kwargs) with method(custom_tensor, *args, **kwargs).
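For context, a minimal numpy-only sketch of __array_finalize__ covering the three situations listed above (the InfoArray name and the info attribute are illustrative only, not part of this proposal):

import numpy as np

class InfoArray(np.ndarray):
    def __new__(cls, input_array, info=None):
        obj = np.asarray(input_array).view(cls)
        obj.info = info
        return obj

    def __array_finalize__(self, obj):
        # Called for explicit construction, view casting, and new-from-template;
        # this is where custom attributes are propagated to the new instance.
        if obj is None:
            return
        self.info = getattr(obj, 'info', None)

a = InfoArray(np.arange(5), info='surface data')  # explicit constructor call
b = np.arange(5).view(InfoArray)                  # view casting
c = a[1:3]                                        # creating new from template
print(a.info, b.info, c.info)                     # surface data None surface data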
The __array_wrap__ method is sufficient for this problem because we want to cast the tensor after the function has been performed. The __array_ufunc__ method would also work but could result in compatibility problems with older versions of numpy. The __array_finalize__ method may add too much customization, which could create a wide range of gotchas and incompatibilities between CPU and GPU.
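To make the preferred mechanism concrete, here is a minimal numpy-only sketch of __array_wrap__ re-applying a subclass to ufunc output (CustomArray is an illustrative name, not part of this proposal):

import numpy as np

class CustomArray(np.ndarray):
    # __array_wrap__ receives the raw ufunc output; returning a view of it
    # keeps the subclass type on the result, regardless of operand order.
    def __array_wrap__(self, out_arr, *args, **kwargs):
        return np.asarray(out_arr).view(type(self))

a = np.array([1.0, 2.0, 3.0]).view(CustomArray)
print(type(np.add(a, 1.0)))  # <class '__main__.CustomArray'>
print(type(np.add(1.0, a)))  # <class '__main__.CustomArray'>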
Example
Sample code that partially works for add(custom_tensor, other):
from collections import OrderedDict

import torch as th


def _rebuild_custom_tensor(data, requires_grad, backward_hooks):
    param = CustomTensor(data, requires_grad)
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict. See Note [Don't serialize hooks]
    param._backward_hooks = backward_hooks
    return param


class CustomTensor(th.Tensor):
    r"""A kind of Tensor that has time, freq as its inner-most dimensions.

    Arguments:
        data (Tensor): parameter tensor.
        requires_grad (bool, optional): if the parameter requires gradient. See
            :ref:`excluding-subgraphs` for more details. Default: `True`
    """

    def __new__(cls, data=None, requires_grad=False):
        if data is None:
            data = th.Tensor()
        return th.Tensor._make_subclass(cls, data, requires_grad)

    def __deepcopy__(self, memo):
        if id(self) in memo:
            return memo[id(self)]
        else:
            result = type(self)(self.data.clone(), self.requires_grad)
            memo[id(self)] = result
            return result

    def __repr__(self):
        return 'Custom Message:\n' + super(CustomTensor, self).__repr__()

    def __reduce_ex__(self, proto):
        # See Note [Don't serialize hooks]
        return (
            _rebuild_custom_tensor,
            (self.data, self.requires_grad, OrderedDict())
        )

    def add(self, value):
        tensor = super().add(value)
        return self.__tensor_wrap__(tensor)

    def __add__(self, other):
        tensor = super().__add__(other)
        return self.__tensor_wrap__(tensor)

    # Wrap Numpy array again in a suitable tensor when done, to support e.g.
    # `numpy.sin(tensor) -> tensor` or `numpy.greater(tensor, 0) -> ByteTensor`
    def __array_wrap__(self, array):
        if array.dtype == bool:
            # Workaround, torch has no built-in bool tensor
            array = array.astype('uint8')
            return th.from_numpy(array)
        else:
            return th.Tensor._make_subclass(CustomTensor, th.from_numpy(array), self.requires_grad)

    # Same as __array_wrap__ but for Tensors
    def __tensor_wrap__(self, tensor):
        return th.Tensor._make_subclass(CustomTensor, tensor, self.requires_grad)

Sample Results:
import torch
from custom import CustomTensor
x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.double)
custom_x = CustomTensor(x)
custom_x + 1 # -> CustomTensor
1 + custom_x # -> **Tensor** :(
torch.add(1, custom_x)  # -> **Tensor** :(
torch.add(custom_x, 1)  # -> **Tensor** :(

Conclusions
The above example demonstrates how numpy goes beyond binary operator overloading to ensure that all operations return the most derived subclass of ndarray. PyTorch is not organized this way. I think #9515 describes the need for dtype promotion logic, whereas this proposal is about Python type promotion logic.