torch.split is divergent from np.split #50012

@mruberry

Note this issue was previously discussed in #5212, too.

Despite having the same name, torch.split and np.split are different functions. The distinction is a little subtle when just reading the documentation, so let's look at some examples:

import numpy as np
import torch

a = np.arange(0, 9)
t = torch.arange(0, 9)

# np.array_split is used here because np.split(a, 2) raises an error (9 elements do not divide evenly into 2 chunks)
# otherwise array_split is essentially equivalent to split
np.array_split(a, 2)
: [array([0, 1, 2, 3, 4]), array([5, 6, 7, 8])]

# torch.split does not work like np.split (np.array_split)
torch.split(t, 2)
: (tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6, 7]), tensor([8]))

# torch.chunk works like np.split (np.array_split)
torch.chunk(t, 2)
: (tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8]))

np.split(a, (1, 3))
: [array([0]), array([1, 2]), array([3, 4, 5, 6, 7, 8])]

torch.split(t, (1, 3))
: RuntimeError: split_with_sizes expects split_sizes to sum exactly to 9 (input tensor's size at dimension 0), but got split_sizes=[1, 3]

torch.split(t, (1, 3, 5))
: (tensor([0]), tensor([1, 2, 3]), tensor([4, 5, 6, 7, 8]))

np.split(a, (1, 3, 5))
: [array([0]), array([1, 2]), array([3, 4]), array([5, 6, 7, 8])]

There are a few things going on here. First, np.split and torch.split both accept either an integer or an iterable argument describing how to split their input:

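  • if the argument is an integer n, np.split creates n chunks, while torch.split creates chunks of size n (the last chunk is smaller if the input's length is not a multiple of n)
  • if the argument is an iterable, np.split slices the input at the indices listed in the iterable, while torch.split creates chunks of the sizes listed in the iterable (the sizes must sum to the input's length along the split dimension)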
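
The two iterable conventions carry the same information, so one can be converted into the other. The following is a minimal sketch in plain Python, not part of either library; the helper names `indices_to_sizes` and `sizes_to_indices` are hypothetical, and the sketch assumes the indices are sorted, non-negative, and no larger than the input's length.

```python
from itertools import accumulate


def indices_to_sizes(indices, length):
    """Convert np.split-style split indices into torch.split-style chunk sizes.

    Assumes `indices` is sorted and every index lies in [0, length].
    """
    bounds = [0, *indices, length]
    # Each chunk spans from one boundary to the next.
    return [end - start for start, end in zip(bounds, bounds[1:])]


def sizes_to_indices(sizes):
    """Convert torch.split-style chunk sizes into np.split-style split indices."""
    # Running totals are the chunk end positions; the last one is the total
    # length, which is not a split point, so it is dropped.
    return list(accumulate(sizes))[:-1]


print(indices_to_sizes((1, 3), 9))   # sizes for the np.split(a, (1, 3)) example
print(sizes_to_indices((1, 3, 5)))   # indices for the torch.split(t, (1, 3, 5)) example
```

Under these assumptions, passing `indices_to_sizes((1, 3), 9)` to torch.split should give the same three pieces that `np.split(a, (1, 3))` returns in the examples above.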

Second, np.split and np.array_split are basically the same function. np.split just throws an error when given an integer that doesn't create evenly sized chunks. PyTorch's torch.tensor_split is equivalent to np.array_split.
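
To make the two integer conventions concrete, here is a small sketch in plain Python that computes only the chunk lengths each convention produces. The function names are hypothetical, and the sketch assumes a positive length and a positive integer argument; it does not call NumPy or PyTorch.

```python
def section_lengths(length, sections):
    """Chunk lengths when the integer means "number of chunks" (np.array_split-style)."""
    base, extra = divmod(length, sections)
    # The first `extra` chunks each receive one additional element.
    return [base + 1] * extra + [base] * (sections - extra)


def fixed_size_lengths(length, size):
    """Chunk lengths when the integer means "chunk size" (torch.split-style)."""
    full, remainder = divmod(length, size)
    # Any remainder becomes a final, smaller chunk.
    return [size] * full + ([remainder] if remainder else [])


def divides_evenly(length, sections):
    """True when the chunks are equal in size, the case np.split requires."""
    return length % sections == 0


print(section_lengths(9, 2))     # lengths of the np.array_split(a, 2) pieces
print(fixed_size_lengths(9, 2))  # lengths of the torch.split(t, 2) pieces
print(divides_evenly(9, 2))      # False: np.split(a, 2) would raise
```

The lengths match the 9-element examples earlier in this issue: two chunks of 5 and 4 elements for the "number of chunks" convention, and four chunks of 2 plus one of 1 for the "chunk size" convention.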

For additional background, the PR adding torch.split's iterable support is #3837. That PR models split's iterable behavior on tf.split, which does have the same iterable behavior. tf.split's integer behavior is the same as torch.chunk's or np.split's, however, and not like torch.split's.

All these functions can be a little confusing, so here's a table:

| Function | Integer Behavior | Iterable Behavior |
|---|---|---|
| torch.chunk | defines # chunks | unsupported |
| torch.split | defines chunk size | defines chunk size |
| torch.tensor_split | defines # chunks | defines split index |
| np.split | defines # chunks | defines split index |
| np.array_split | defines # chunks | defines split index |
| tf.split | defines # chunks | defines chunk size |

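As we can see from this table, torch.split is unique: it is the only function listed whose integer argument defines a chunk size. It is also the most conceptually consistent of these functions, since both its integer and iterable forms define chunks by size. It is in real use, too, including with iterable arguments.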

To deprecate torch.split, I believe we would need to do the following:

  • deprecate torch.split for a release in favor of an equivalent alternative (requires auditing and updating internal calls)
  • remove torch.split for a release (requires implementing an upgrader for serialized TorchScript models)
  • restore torch.split with NumPy-compatible behavior
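
The first step of such a plan is usually a deprecated alias that warns and forwards to the replacement. The sketch below illustrates that pattern on plain Python lists only; it is not PyTorch's deprecation machinery, and both `split_by_size` and the `split` alias are hypothetical names chosen for illustration.

```python
import warnings


def split_by_size(seq, size):
    """Hypothetical replacement: chunks of `size`, the last one possibly smaller."""
    return [seq[i:i + size] for i in range(0, len(seq), size)]


def split(seq, size):
    """Hypothetical deprecated alias that warns, then forwards to split_by_size."""
    warnings.warn(
        "split is deprecated; use split_by_size instead",
        DeprecationWarning,
        stacklevel=2,  # attribute the warning to the caller, not this wrapper
    )
    return split_by_size(seq, size)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    chunks = split(list(range(9)), 2)

print(chunks)
print([str(w.message) for w in caught])
```

Keeping the old name as a thin forwarding wrapper means existing callers see unchanged results during the deprecation release while being pointed at the new name.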

torch.chunk would have been a good name for that alternative. We could deprecate torch.chunk in favor of torch.tensor_split, which implements a superset of torch.chunk's functionality, but that would further extend the deprecation process by several releases.

While this would be a painful deprecation because torch.split is widely used, it is a great example of a case where users' NumPy experience or knowledge works against them when they come to PyTorch.

np.split is a top 200 NumPy function, and torch.split is popular with hundreds of uses within Facebook.

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @mruberry @rgommers @heitorschueroff

Labels: high priority, module: deprecation, module: numpy, triaged
