Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions test/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4215,6 +4215,28 @@ def test_split(self):
self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, 0)
start = start + target_size[dim]

# Variable sections split
tensor = torch.randn(20, 10)
dim = 0
split_sizes = [5, 5, 10]
target_sizes = ([[5, 10], [5, 10], [10, 10]])
splits = tensor.split(split_sizes, dim)
start = 0
for target_size, split in zip(target_sizes, splits):
self.assertEqual(split.size(), target_size)
self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, 0)
start = start + target_size[dim]

split_sizes = [2, 2, 6]
target_sizes = ([20, 2], [20, 2], [20, 6])
dim = 1
splits = tensor.split(split_sizes, dim)
start = 0
for target_size, split in zip(target_sizes, splits):
self.assertEqual(split.size(), target_size)
self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, 0)
start = start + target_size[dim]

def test_chunk(self):
tensor = torch.rand(4, 7)
num_chunks = 3
Expand Down
44 changes: 31 additions & 13 deletions torch/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,45 @@
]


def split(tensor, split_size, dim=0):
r"""Splits the tensor into chunks all of size :attr:`split_size` (if possible).

def split(tensor, split_size_or_sections, dim=0):
"""Splits the tensor into chunks.
If ``split_size_or_sections`` is an integer type, then ``tensor`` will be
split into equally sized chunks (if possible).
Last chunk will be smaller if the tensor size along a given dimension
is not divisible by :attr`split_size`.
is not divisible by ``split_size``.
If ``split_size_or_sections`` is a list, then ``tensor`` will be split
into ``len(split_size_or_sections)`` chunks with sizes in ``dim`` according
to ``split_size_or_sections``.

Arguments:
tensor (Tensor): the tensor to split
split_size (int): size of a single chunk
dim (int): dimension along which to split the tensor
tensor (Tensor): tensor to split.
split_size_or_sections (int) or (list(int)): size of a single chunk or
list of sizes for each chunk
dim (int): dimension along which to split the tensor.
"""
if dim < 0:
dim += tensor.dim()
dim_size = tensor.size(dim)
num_splits = (dim_size + split_size - 1) // split_size
last_split_size = split_size - (split_size * num_splits - dim_size)

def get_split_size(i):
return split_size if i < num_splits - 1 else last_split_size
return tuple(tensor.narrow(int(dim), int(i * split_size), int(get_split_size(i))) for i
in _range(0, num_splits))
if isinstance(split_size_or_sections, int):
split_size = split_size_or_sections
num_splits = (dim_size + split_size - 1) // split_size
last_split_size = split_size - (split_size * num_splits - dim_size)

def get_split_size(i):
return split_size if i < num_splits - 1 else last_split_size
return tuple(tensor.narrow(int(dim), int(i * split_size), int(get_split_size(i))) for i
in _range(0, num_splits))

else:
if dim_size != sum(split_size_or_sections):
raise ValueError("Sum of split sizes exceeds tensor dim")
split_indices = [0] + split_size_or_sections
split_indices = torch.cumsum(torch.Tensor(split_indices), dim=0)

return tuple(
tensor.narrow(int(dim), int(start), int(length))
for start, length in zip(split_indices, split_size_or_sections))


def chunk(tensor, chunks, dim=0):
Expand Down