35 changes: 32 additions & 3 deletions test/test_torch.py
@@ -3873,19 +3873,48 @@ def test_expand(self):
self.assertEqual(torch.randn(()).expand(()), torch.randn(()))

def test_repeat(self):
result = torch.Tensor()
tensor = torch.rand(8, 4)

initial_shape = (8, 4)
tensor = torch.rand(*initial_shape)

size = (3, 1, 1)
torchSize = torch.Size(size)
target = [3, 8, 4]
self.assertEqual(tensor.repeat(*size).size(), target, 'Error in repeat')
self.assertEqual(tensor.repeat(torchSize).size(), target, 'Error in repeat using LongStorage')
self.assertEqual(tensor.repeat(torchSize).size(), target,
'Error in repeat using LongStorage')
result = tensor.repeat(*size)
self.assertEqual(result.size(), target, 'Error in repeat using result')
result = tensor.repeat(torchSize)
self.assertEqual(result.size(), target, 'Error in repeat using result and LongStorage')
self.assertEqual(result.mean(0).view(8, 4), tensor, 'Error in repeat (not equal)')

@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_repeat_tile(self):

initial_shape = (8, 4)

repeats = ((3, 1, 1),
(3, 3, 3),
(1, 2, 1),
(2, 2, 2, 2))

def _generate_noncontiguous_input():

out = np.broadcast_to(np.random.random((1, 4)),
initial_shape)

assert not (out.flags.c_contiguous or out.flags.f_contiguous)

return out

for repeat in repeats:
for tensor in (torch.from_numpy(np.random.random(initial_shape)),
torch.from_numpy(_generate_noncontiguous_input()),):

self.assertEqual(tensor.repeat(*repeat).numpy(),
np.tile(tensor.numpy(), repeat))

def test_is_same_size(self):
t1 = torch.Tensor(3, 4, 9, 10)
t2 = torch.Tensor(3, 4)
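The new test_repeat_tile case asserts that tensor.repeat matches numpy's np.tile semantics for both contiguous and non-contiguous inputs. A minimal standalone sketch of that equivalence (illustrative only, not part of the diff; shapes and values are arbitrary):

import numpy as np
import torch

a = np.arange(8, dtype=np.float64).reshape(2, 4)
t = torch.from_numpy(a)

# With more repeat dims than tensor dims, size-1 dims are prepended first,
# so a (2, 4) tensor repeated with (3, 1, 2) yields shape (3, 2, 8).
r = t.repeat(3, 1, 2)
assert r.size() == torch.Size([3, 2, 8])
assert np.array_equal(r.numpy(), np.tile(a, (3, 1, 2)))

# np.broadcast_to returns a non-contiguous view, which is how the test
# exercises the non-contiguous input path of repeat.
nc = np.broadcast_to(np.random.random((1, 4)), (8, 4))
assert not (nc.flags.c_contiguous or nc.flags.f_contiguous)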
31 changes: 16 additions & 15 deletions torch/tensor.py
@@ -266,30 +266,31 @@ def repeat(self, *sizes):
# If args == (torch.Size,), then we need to unpack the tuple
if len(sizes) == 1 and isinstance(sizes[0], torch.Size):
sizes = sizes[0]

repeats = list(sizes)
result = self.new()
src = self.contiguous()

if len(repeats) < src.dim():
if len(repeats) < self.dim():
raise ValueError('Number of dimensions of repeat dims can not be '
'smaller than number of dimensions of tensor')

xtensor = src.new().set_(src)

xsize = list(xtensor.size())
for i in _range(len(repeats) - src.dim()):
xsize = [1] + xsize
# Add new leading dimensions to the tensor if the
# number of target dimensions is larger than the
# number of source dimensions.
num_new_dimensions = len(repeats) - self.dim()
padded_size = [1] * num_new_dimensions + list(self.size())
target_size = torch.Size([a * b for a, b in zip(padded_size, repeats)])

xtensor = self.new().set_(self)
xtensor = xtensor.expand(padded_size)

size = torch.Size([a * b for a, b in zip(xsize, repeats)])
xtensor.resize_(torch.Size(xsize))
result.resize_(size)
result = self.new()
result.resize_(target_size)
urtensor = result.new(result)
for i in _range(xtensor.dim()):
urtensor = urtensor.unfold(i, xtensor.size(i), xtensor.size(i))
for i in _range(urtensor.dim() - xtensor.dim()):
xsize = [1] + xsize
xtensor.resize_(torch.Size(xsize))
xxtensor = xtensor.expand_as(urtensor)
urtensor.copy_(xxtensor)

urtensor.copy_(xtensor.expand_as(urtensor))

return result

def masked_copy_(self, *args, **kwargs):
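The refactored repeat first pads the source shape with leading 1s (one per missing dimension), then multiplies element-wise by the repeat counts to get the target size; expand plus unfold then fill the result without the repeated resize_ calls of the old version. A small pure-Python sketch of just that shape arithmetic (illustrative; repeat_target_size is a hypothetical helper, not part of the PR):

def repeat_target_size(src_shape, repeats):
    # Mirrors the padded_size / target_size computation in the new repeat().
    if len(repeats) < len(src_shape):
        raise ValueError('Number of dimensions of repeat dims can not be '
                         'smaller than number of dimensions of tensor')
    num_new_dimensions = len(repeats) - len(src_shape)
    padded_size = [1] * num_new_dimensions + list(src_shape)
    return [a * b for a, b in zip(padded_size, repeats)]

# The (8, 4) tensor from test_repeat repeated with (3, 1, 1) -> [3, 8, 4].
assert repeat_target_size((8, 4), (3, 1, 1)) == [3, 8, 4]
assert repeat_target_size((2, 4), (3, 3, 3)) == [3, 6, 12]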