Skip to content

Conversation

@maciejkula
Copy link
Contributor

Tentative solution for #4054.

I must admit that the actual details of what is happening in tensor.repeat remain somewhat opaque to me, so I opted for the minimal possible change.

I would be happy to add some more robust tests; have you looked into Hypothesis? It's probably the most advanced property-based testing library out there, and is extremely good at making it easier to write effective tests.

These test tensor.repeat when using numpy-shared storage. The new
test adds some additional test cases and compares against np.tile.
Clone the input tensor instead of resizing it in place.
raise ValueError('Number of dimensions of repeat dims can not be '
'smaller than number of dimensions of tensor')

xtensor = src.new().set_(src)

This comment was marked as off-topic.

This comment was marked as off-topic.

@maciejkula
Copy link
Contributor Author

Updated to use .view.

torch/tensor.py Outdated

size = torch.Size([a * b for a, b in zip(xsize, repeats)])
xtensor.resize_(torch.Size(xsize))
xtensor = xtensor.view(torch.Size(xsize))

This comment was marked as off-topic.

This comment was marked as off-topic.

I added some explanatory comments to the function itself and added a
test that repeats a non-contiguous input.
@maciejkula
Copy link
Contributor Author

@apaszke I moved to using an expand to add the new leading dimensions, and added a test for repeating non-contiguous inputs.

Copy link
Contributor

@apaszke apaszke left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. Some minor comments, and should be good to merge once they are resolved.

xxtensor = xtensor.expand_as(urtensor)
urtensor.copy_(xxtensor)

urtensor.copy_(xtensor.expand_as(urtensor))

This comment was marked as off-topic.

This comment was marked as off-topic.

result = tensor.repeat(torchSize)
self.assertEqual(result.size(), target, 'Error in repeat using result and LongStorage')
self.assertEqual(result.mean(0).view(8, 4), tensor, 'Error in repeat (not equal)')

This comment was marked as off-topic.

@apaszke apaszke merged commit d4d8698 into pytorch:master Dec 16, 2017
@apaszke
Copy link
Contributor

apaszke commented Dec 16, 2017

Thank you!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants