Commit 0185d5a

maciejkula authored and soumith committed
Fix repeat non owning (#4084)
1 parent faea900 commit 0185d5a
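For context: Tensor.repeat tiles a tensor along each dimension, much like NumPy's np.tile, and "non owning" here refers to inputs that merely view storage they do not own (for example a tensor backed by a broadcast NumPy array). A minimal sketch of the intended semantics, not part of the commit, assuming torch and numpy are installed:

import numpy as np
import torch

# Repeating an (8, 4) tensor with counts (3, 1, 1) adds a leading
# dimension and tiles the data, matching np.tile.
t = torch.rand(8, 4)
r = t.repeat(3, 1, 1)
print(r.size())                             # torch.Size([3, 8, 4])
print(np.tile(t.numpy(), (3, 1, 1)).shape)  # (3, 8, 4)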

File tree

2 files changed: +48, -18 lines

test/test_torch.py

Lines changed: 32 additions & 3 deletions
@@ -3748,19 +3748,48 @@ def test_expand(self):
         self.assertEqual(torch.randn(()).expand(()), torch.randn(()))
 
     def test_repeat(self):
-        result = torch.Tensor()
-        tensor = torch.rand(8, 4)
+
+        initial_shape = (8, 4)
+        tensor = torch.rand(*initial_shape)
+
         size = (3, 1, 1)
         torchSize = torch.Size(size)
         target = [3, 8, 4]
         self.assertEqual(tensor.repeat(*size).size(), target, 'Error in repeat')
-        self.assertEqual(tensor.repeat(torchSize).size(), target, 'Error in repeat using LongStorage')
+        self.assertEqual(tensor.repeat(torchSize).size(), target,
+                         'Error in repeat using LongStorage')
         result = tensor.repeat(*size)
         self.assertEqual(result.size(), target, 'Error in repeat using result')
         result = tensor.repeat(torchSize)
         self.assertEqual(result.size(), target, 'Error in repeat using result and LongStorage')
         self.assertEqual(result.mean(0).view(8, 4), tensor, 'Error in repeat (not equal)')
 
+    @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
+    def test_repeat_tile(self):
+
+        initial_shape = (8, 4)
+
+        repeats = ((3, 1, 1),
+                   (3, 3, 3),
+                   (1, 2, 1),
+                   (2, 2, 2, 2))
+
+        def _generate_noncontiguous_input():
+
+            out = np.broadcast_to(np.random.random((1, 4)),
+                                  initial_shape)
+
+            assert not (out.flags.c_contiguous or out.flags.f_contiguous)
+
+            return out
+
+        for repeat in repeats:
+            for tensor in (torch.from_numpy(np.random.random(initial_shape)),
+                           torch.from_numpy(_generate_noncontiguous_input()),):
+
+                self.assertEqual(tensor.repeat(*repeat).numpy(),
+                                 np.tile(tensor.numpy(), repeat))
+
     def test_is_same_size(self):
         t1 = torch.Tensor(3, 4, 9, 10)
         t2 = torch.Tensor(3, 4)
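The new test_repeat_tile checks tensor.repeat against np.tile, including on a non-contiguous input built with np.broadcast_to, whose buffer the resulting tensor does not own. A standalone sketch of that scenario, mirroring the test and assuming torch with NumPy interop:

import numpy as np
import torch

# An (8, 4) view that merely broadcasts a (1, 4) buffer: it is neither
# C- nor F-contiguous and does not own its memory.
base = np.broadcast_to(np.random.random((1, 4)), (8, 4))
assert not (base.flags.c_contiguous or base.flags.f_contiguous)

t = torch.from_numpy(base)  # may warn that the array is read-only
# With the fix, repeat on this non-owning tensor agrees with np.tile.
assert np.allclose(t.repeat(3, 1, 1).numpy(), np.tile(base, (3, 1, 1)))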

torch/tensor.py

Lines changed: 16 additions & 15 deletions
@@ -269,30 +269,31 @@ def repeat(self, *sizes):
         # If args == (torch.Size,), then we need to unpack the tuple
         if len(sizes) == 1 and isinstance(sizes[0], torch.Size):
             sizes = sizes[0]
+
         repeats = list(sizes)
-        result = self.new()
-        src = self.contiguous()
 
-        if len(repeats) < src.dim():
+        if len(repeats) < self.dim():
             raise ValueError('Number of dimensions of repeat dims can not be '
                              'smaller than number of dimensions of tensor')
 
-        xtensor = src.new().set_(src)
-        xsize = list(xtensor.size())
-        for i in _range(len(repeats) - src.dim()):
-            xsize = [1] + xsize
+        # Add new leading dimensions to the tensor if the
+        # number of target dimensions is larger than the
+        # number of source dimensions.
+        num_new_dimensions = len(repeats) - self.dim()
+        padded_size = [1] * num_new_dimensions + list(self.size())
+        target_size = torch.Size([a * b for a, b in zip(padded_size, repeats)])
+
+        xtensor = self.new().set_(self)
+        xtensor = xtensor.expand(padded_size)
 
-        size = torch.Size([a * b for a, b in zip(xsize, repeats)])
-        xtensor.resize_(torch.Size(xsize))
-        result.resize_(size)
+        result = self.new()
+        result.resize_(target_size)
         urtensor = result.new(result)
         for i in _range(xtensor.dim()):
             urtensor = urtensor.unfold(i, xtensor.size(i), xtensor.size(i))
-        for i in _range(urtensor.dim() - xtensor.dim()):
-            xsize = [1] + xsize
-        xtensor.resize_(torch.Size(xsize))
-        xxtensor = xtensor.expand_as(urtensor)
-        urtensor.copy_(xxtensor)
+
+        urtensor.copy_(xtensor.expand_as(urtensor))
+
         return result
 
     def masked_copy_(self, *args, **kwargs):
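The rewritten body never resizes a tensor that shares self's storage: the source is left-padded to the target rank with expand (a non-copying view), a fresh result buffer is allocated, the result is unfolded into blocks the size of the padded source, and the expanded source is copied into every block at once. A rough standalone trace of the shape bookkeeping, outside the class, with src standing in for self and range for _range; it uses urtensor = result where the commit uses result.new(result), which likewise shares result's storage:

import torch

src = torch.rand(8, 4)
repeats = [3, 1, 1]

# Left-pad the source shape with ones so it has as many dims as repeats.
num_new_dimensions = len(repeats) - src.dim()                # 1
padded_size = [1] * num_new_dimensions + list(src.size())    # [1, 8, 4]
target_size = torch.Size([a * b for a, b in zip(padded_size, repeats)])  # [3, 8, 4]

# expand() only creates a view; nothing sharing src's storage is resized.
xtensor = src.new().set_(src).expand(padded_size)

result = src.new()
result.resize_(target_size)

# Unfold each dimension of the (3, 8, 4) result into blocks of the padded
# source size; the view ends up with shape (3, 1, 1, 1, 8, 4), i.e. one
# (1, 8, 4) block per repeat.
urtensor = result
for i in range(xtensor.dim()):
    urtensor = urtensor.unfold(i, xtensor.size(i), xtensor.size(i))

# Broadcasting the source over every block fills the result with copies.
urtensor.copy_(xtensor.expand_as(urtensor))

print(result.size())                # torch.Size([3, 8, 4])
print(torch.equal(result[1], src))  # True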
