Note this issue was previously discussed in #5212, too.
Despite having the same name, torch.split and np.split are different functions. The distinction is a little subtle when just reading the documentation, so let's look at some examples:
```python
import numpy as np
import torch

a = np.arange(0, 9)
t = torch.arange(0, 9)

# ignore the use of array_split for the moment;
# array_split is essentially equivalent to split
np.array_split(a, 2)
# [array([0, 1, 2, 3, 4]), array([5, 6, 7, 8])]

# torch.split does not work like np.split (np.array_split)
torch.split(t, 2)
# (tensor([0, 1]), tensor([2, 3]), tensor([4, 5]), tensor([6, 7]), tensor([8]))

# torch.chunk works like np.split (np.array_split)
torch.chunk(t, 2)
# (tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8]))

np.split(a, (1, 3))
# [array([0]), array([1, 2]), array([3, 4, 5, 6, 7, 8])]

torch.split(t, (1, 3))
# RuntimeError: split_with_sizes expects split_sizes to sum exactly to 9
# (input tensor's size at dimension 0), but got split_sizes=[1, 3]

torch.split(t, (1, 3, 5))
# (tensor([0]), tensor([1, 2, 3]), tensor([4, 5, 6, 7, 8]))

np.split(a, (1, 3, 5))
# [array([0]), array([1, 2]), array([3, 4]), array([5, 6, 7, 8])]
```
There are a few things going on here. First, np.split and torch.split can accept integer or iterable arguments to describe how they split their input tensor:
- if the input is an integer n, np.split creates n chunks, while torch.split creates chunks of size n, as many as needed to cover the input (the final chunk may be smaller)
- if the input is an iterable, np.split slices the tensor at the indices described in the iterable, while torch.split creates chunks of the sizes specified in the iterable
Second, np.split and np.array_split are basically the same function: np.split just throws an error when given an integer that doesn't divide the dimension evenly. PyTorch's torch.tensor_split is equivalent to np.array_split.
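To make that relationship concrete, here's a short illustration (outputs and the error are shown as comments; the exact error text may vary by NumPy version):

```python
import numpy as np
import torch

a = np.arange(9)

# np.split errors when the integer doesn't divide the dimension evenly
np.split(a, 2)
# ValueError: array split does not result in an equal division

# np.array_split accepts the same argument and returns uneven chunks
np.array_split(a, 2)
# [array([0, 1, 2, 3, 4]), array([5, 6, 7, 8])]

# torch.tensor_split mirrors np.array_split
torch.tensor_split(torch.arange(9), 2)
# (tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8]))
```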
For additional background, the PR adding torch.split's iterable support is #3837. That PR models split's iterable behavior on tf.split, which has the same iterable behavior. tf.split's integer behavior, however, matches torch.chunk's and np.split's, not torch.split's.
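For reference, a minimal sketch of tf.split's behavior, assuming TensorFlow 2.x (outputs abbreviated as comments):

```python
import tensorflow as tf

# Integer argument: defines the number of chunks, like np.split, and
# (also like np.split) an integer that doesn't divide the dimension
# evenly is an error
tf.split(tf.range(8), 2)
# [<tf.Tensor: [0 1 2 3]>, <tf.Tensor: [4 5 6 7]>]

# Iterable argument: defines chunk sizes, like torch.split
tf.split(tf.range(9), [1, 3, 5])
# [<tf.Tensor: [0]>, <tf.Tensor: [1 2 3]>, <tf.Tensor: [4 5 6 7 8]>]
```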
All these functions can be a little confusing, so here's a table:
| Function | Integer Behavior | Iterable Behavior |
|---|---|---|
| torch.chunk | defines # chunks | unsupported |
| torch.split | defines chunk size | defines chunk size |
| torch.tensor_split | defines # chunks | defines split index |
| np.split | defines # chunks | defines split index |
| np.array_split | defines # chunks | defines split index |
| tf.split | defines # chunks | defines chunk size |
As the table shows, torch.split's behavior is unique. It is also, arguably, the most conceptually consistent of these functions, since it defines chunks by size whether given an integer or an iterable. And it is widely used, including with iterable arguments.
To deprecate torch.split, I believe we would need to do the following:
- deprecate torch.split for a release in favor of an equivalent alternative (requires auditing and updating internal calls; a sketch of a possible warning shim follows this list)
- remove torch.split for a release (requires implementing an upgrader for serialized torchscript models)
- restore torch.split with NumPy-compatible behavior
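The first step could be implemented as a simple warning wrapper. The sketch below is hypothetical: the wrapper, message, and warning category are illustrative, not an actual PyTorch plan.

```python
import warnings
import torch

_original_split = torch.split

def split(tensor, split_size_or_sections, dim=0):
    # Hypothetical shim for step one: warn on every call while
    # delegating to the current torch.split implementation.
    warnings.warn(
        "torch.split is deprecated in favor of an equivalent replacement "
        "and will adopt NumPy-compatible behavior in a future release.",
        DeprecationWarning,
        stacklevel=2,
    )
    return _original_split(tensor, split_size_or_sections, dim=dim)
```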
torch.chunk would have been a good name for that alternative. We could deprecate torch.chunk in favor of torch.tensor_split, which implements a superset of torch.chunk's functionality, but that would further extend the deprecation process by several releases.
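To illustrate the superset claim (outputs shown as comments):

```python
import torch

t = torch.arange(9)

# tensor_split accepts an integer number of chunks, like chunk...
torch.tensor_split(t, 2)
# (tensor([0, 1, 2, 3, 4]), tensor([5, 6, 7, 8]))

# ...and also split indices, like np.split, which chunk does not
# support (torch.chunk accepts only an integer)
torch.tensor_split(t, (1, 3))
# (tensor([0]), tensor([1, 2]), tensor([3, 4, 5, 6, 7, 8]))
```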
While this is a painful deprecation, since torch.split is widely used, it is also a clear example of how users' NumPy experience or knowledge can work against them when they come to PyTorch.
np.split is a top 200 NumPy function, and torch.split is popular with hundreds of uses within Facebook.
cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @mruberry @rgommers @heitorschueroff