
Proposal: Add __tensor_wrap__ method similar to numpy __array_wrap__ #17249

@dylanbespalko

🚀 Feature

Add the ability to subclass torch.Tensor so that you can use isinstance(tensor, CustomTensor) to determine the "context" of the data stored in the subclass object. I think #9515 describes the need for dtype promotion logic, but this proposal is about Python type promotion logic.

The __array_wrap__ method already exists in numpy and ensures that a ufunc preserves the custom tensor type:

custom_tensor = CustomTensor(torch.tensor([1, 2, 3, 4]))
new_tensor = np.add(custom_tensor, 1.0)  # -> CustomTensor
new_tensor = np.add(1.0, custom_tensor)  # -> CustomTensor (Hmmm...)
isinstance(new_tensor, CustomTensor)     # returns True

The __tensor_wrap__ method does not exist in torch. As a result, the custom type information is easily lost:

custom_tensor = CustomTensor(torch.tensor([1, 2, 3, 4]))
new_tensor = custom_tensor.add(1.0)   # -> Tensor
isinstance(new_tensor, CustomTensor)  # returns False

Motivation

The torch.Tensor class provides a dynamically typed wrapper around the statically typed tensor classes (FloatTensor, DoubleTensor, HalfTensor, etc.) using a numpy-style dtype. While this should save us from ever wanting to subclass Tensor for storage reasons, it would be nice to use a subclass of Tensor to provide custom methods specific to the context of the data.

For example:

class TimeDomainTensor(torch.Tensor):

    def __repr__(self):
        return 'Time Information:\n' + super(TimeDomainTensor, self).__repr__()


class FrequencyDomainTensor(torch.Tensor):

    def __repr__(self):
        return 'Frequency Information:\n' + super(FrequencyDomainTensor, self).__repr__()


class SurfaceTensor(torch.Tensor):

    def __repr__(self):
        return 'Surface Information:\n' + super(SurfaceTensor, self).__repr__()
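
Downstream code could then dispatch on the Python type ("context") rather than on the dtype. A minimal sketch of that idea (plot_data is a hypothetical helper, not part of this proposal):

def plot_data(tensor):
    # Choose the axes based on the Python subclass, not the dtype.
    if isinstance(tensor, TimeDomainTensor):
        print('plotting against time...')
    elif isinstance(tensor, FrequencyDomainTensor):
        print('plotting against frequency...')
    else:
        print('plotting raw samples...')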

Pitch

Custom subclassing is already utilized by the torch.nn.Parameter class, so most of the heavy lifting is already done. Note, however, that the result of a math operation on a Parameter is downcast to a plain Tensor, as the sketch below shows.
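
A quick check of that existing behaviour (a minimal sketch; the results in the comments reflect current PyTorch):

import torch
from torch.nn import Parameter

p = Parameter(torch.ones(3))
isinstance(p, Parameter)      # True
q = p + 1.0                   # math ops return a plain Tensor
isinstance(q, Parameter)      # False
isinstance(q, torch.Tensor)   # True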

Alternatives

Numpy provides the following methods for customizing the output behaviour of a ufunc.

  1. def __array_wrap__(self, array):
    A subclass can override what happens after a numpy ufunc has been executed.

  2. def __array_ufunc__(ufunc, method, *inputs, **kwargs):
    A subclass can override what happens both before and after a numpy ufunc is executed. New in numpy 1.13.

  3. def __array_finalize__(self, obj):
    The ability to insert any kind of custom attributes, events, io, etc. This may not work well on GPU because it opens up the ability to define highly customized data containers. It is invoked on:
    a. explicit constructor call (obj = MySubClass(params))
    b. view casting
    c. creating new from template

  4. Do nothing:
    Revert to a functional implementation. Simply replace all custom_tensor.method(*args, **kwargs) calls with method(custom_tensor, *args, **kwargs).

The __array_wrap__ method is sufficient for this problem because we want to cast the tensor after the function has been performed. __array_ufunc__ would also work but could cause compatibility problems with older versions of numpy. __array_finalize__ may add too much customization, creating a wide range of gotchas and incompatibilities between CPU and GPU.
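
For reference, a minimal pure-numpy sketch of the __array_wrap__ pattern this proposal mirrors (the CustomArray name is illustrative only):

import numpy as np

class CustomArray(np.ndarray):
    # View-cast an existing array into the subclass.
    def __new__(cls, input_array):
        return np.asarray(input_array).view(cls)

    # numpy calls this after a ufunc so the result keeps our type.
    def __array_wrap__(self, out_arr, context=None):
        return np.asarray(out_arr).view(type(self))

a = CustomArray([1.0, 2.0, 3.0])
b = np.add(a, 1.0)
isinstance(b, CustomArray)  # True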

Example

Sample code that partially works for add(custom_tensor, other):

from collections import OrderedDict
import torch as th


def _rebuild_custom_tensor(data, requires_grad, backward_hooks):
    param = CustomTensor(data, requires_grad)
    # NB: This line exists only for backwards compatibility; the
    # general expectation is that backward_hooks is an empty
    # OrderedDict.  See Note [Don't serialize hooks]
    param._backward_hooks = backward_hooks

    return param

class CustomTensor(th.Tensor):
    r"""A kind of Tensor that has time, freq as its inner-most dimensions.

    Arguments:
        data (Tensor): parameter tensor.
        requires_grad (bool, optional): if the parameter requires gradient. See
            :ref:`excluding-subgraphs` for more details. Default: `True`
    """

    def __new__(cls, data=None, requires_grad=False):
        if data is None:
            data = th.Tensor()
        return th.Tensor._make_subclass(cls, data, requires_grad)

    def __deepcopy__(self, memo):
        if id(self) in memo:
            return memo[id(self)]
        else:
            result = type(self)(self.data.clone(), self.requires_grad)
            memo[id(self)] = result
            return result

    def __repr__(self):
        return 'Custom Message:\n' + super(CustomTensor, self).__repr__()

    def __reduce_ex__(self, proto):
        # See Note [Don't serialize hooks]
        return (
            _rebuild_custom_tensor,
            (self.data, self.requires_grad, OrderedDict())
        )

    def add(self, value):
        tensor = super().add(value)
        return self.__tensor_wrap__(tensor)

    def __add__(self, other):
        tensor = super().__add__(other)
        return self.__tensor_wrap__(tensor)

    # Wrap Numpy array again in a suitable tensor when done, to support e.g.
    # `numpy.sin(tensor) -> tensor` or `numpy.greater(tensor, 0) -> ByteTensor`
    def __array_wrap__(self, array):
        if array.dtype == bool:
            # Workaround, torch has no built-in bool tensor
            array = array.astype('uint8')
            return th.from_numpy(array)
        else:
            return th.Tensor._make_subclass(CustomTensor, th.from_numpy(array), self.requires_grad)

    # Same as __array_wrap__ but for Tensors
    def __tensor_wrap__(self, tensor):
        return th.Tensor._make_subclass(CustomTensor, tensor, self.requires_grad)

Sample Results:

import torch
from custom import CustomTensor

x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.double)
custom_x = CustomTensor(x)
custom_x + 1            # -> CustomTensor
1 + custom_x            # -> Tensor :(
torch.add(1, custom_x)  # -> Tensor :(
torch.add(custom_x, 1)  # -> Tensor :(

Conclusions

The above example demonstrates how numpy goes beyond binary operator overloading to ensure that all operations return the most derived subclass of ndarray. PyTorch is not currently organized this way. I think #9515 describes the need for dtype promotion logic, but this proposal is about Python type promotion logic.


Labels: module: internals (Related to internal abstractions in c10 and ATen), module: numpy (Related to numpy support, and also numpy compatibility of our operators), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
