
Conversation

@nairbv (Collaborator) commented Oct 29, 2018

Adding a roll operator

// If the first dimension is zero, this is an empty tensor and rolls do nothing.
// Return a clone so the caller can safely modify the result, and avoid a
// division-by-zero error below.
if (self.size(0) == 0) {
  return self.clone();
}

Tensor roll_cpu(const Tensor& self, IntList shifts, IntList dims) {
// TODO: support rolling along no or multiple dimensions as in numpy.roll.
AT_CHECK(dims.size() == 1, "only single dimension roll currently supported");
AT_CHECK(shifts.size() == dims.size(), "shifts and dimensions must align");
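The single-dimension behavior that roll_cpu implements can be modeled in plain Python (a hypothetical sketch, not the ATen code): elements shift by `shift` positions along the dimension and wrap around, with negative and oversized shifts normalized by the length. An empty input rolls to an empty result, matching the early-return above.

```python
# Illustrative model of single-dimension roll semantics; `roll_1d` is a
# made-up name, not part of the PR.
def roll_1d(values, shift):
    n = len(values)
    if n == 0:
        return []          # empty input: rolling is a no-op
    shift %= n             # negative or oversized shifts wrap around
    return values[-shift:] + values[:-shift]

print(roll_1d([1, 2, 3, 4, 5], 2))   # [4, 5, 1, 2, 3]
```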

// Place each unbound slice at its rolled position, then stack the
// reordered slices back together along the rolled dimension.
vec[index++] = tensors[i];
}

auto stacked = at::stack(vec, dim);
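The split-and-restack strategy this fragment belongs to can be sketched in Python (illustrative only; `roll_rows` and its argument names are not from the PR): each slice along the rolled dimension lands at its shifted position, and the reordered slices are then reassembled.

```python
# Hedged sketch of rolling by reordering slices and restacking them.
def roll_rows(rows, shift):
    n = len(rows)
    if n == 0:
        return []
    shift %= n
    vec = [None] * n
    for i, row in enumerate(rows):
        vec[(i + shift) % n] = row   # slice i moves to its rolled position
    return vec

print(roll_rows([[1, 2], [3, 4], [5, 6]], 1))   # [[5, 6], [1, 2], [3, 4]]
```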

CPU: flip_cpu
CUDA: flip_cuda

- func: roll(Tensor self, IntList[1] shifts, IntList[1] dims) -> Tensor

// Behavior of % is different in C++ vs Python for negative numbers. This
// corrects the difference.
if (start < 0) start = start + size;
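The reason this correction exists can be demonstrated in Python (a sketch under the assumption that the surrounding code computes a start offset with `%`; `cpp_mod` and `wrapped_start` are illustrative names): C++'s `%` truncates toward zero and can return a negative remainder, while the index math needs a result in [0, size).

```python
# Emulate C++'s truncated remainder to show why the `start < 0` fix is needed.
def cpp_mod(a, b):
    # C++ remainder: truncated division, sign follows the dividend
    return a - int(a / b) * b

def wrapped_start(size, shift):
    start = cpp_mod(size - shift, size)
    if start < 0:            # the correction from the snippet above
        start = start + size
    return start

print(cpp_mod(-2, 5))        # -2 in C++ semantics (Python's -2 % 5 is 3)
print(wrapped_start(5, 7))   # a shift larger than size still lands in [0, 5)
```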

const int64_t block_size = 512;
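With a fixed block size like this, a CUDA launch typically covers N elements with one thread each by rounding the block count up. A hedged sketch of that ceiling division (the helper name is illustrative, not from the PR):

```python
# Typical CUDA grid-size computation for a 1-D elementwise kernel.
def grid_size(n, block_size=512):
    # ceil(n / block_size) using integer arithmetic
    return (n + block_size - 1) // block_size

print(grid_size(1025))   # 3 blocks of 512 threads cover 1025 elements
```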

return;
}
// roll dim idx is the index of linear_index along the rolling dimension.
int64_t roll_dim_idx = linear_index % (stride * size) / stride;
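The index arithmetic in this kernel fragment can be checked with a small Python model (the function name is illustrative): for a contiguous layout, the coordinate of a flat index along the rolled dimension with the given stride and size is `linear_index % (stride * size) // stride`.

```python
# Coordinate of a flat index along one dimension of a contiguous tensor.
def roll_dim_idx(linear_index, stride, size):
    return linear_index % (stride * size) // stride

# e.g. a 3x4 row-major tensor: dim 0 has stride 4, size 3
print(roll_dim_idx(9, 4, 3))   # flat index 9 sits in row 2 (element [2][1])
```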

add_docstr(torch.roll,
r"""
roll(input, shift, dims) -> Tensor

add_docstr_all('roll',
r"""
roll(shift, dims) -> Tensor
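The docstrings above describe the `roll(input, shift, dims)` signature. A hedged pure-Python model of that behavior for a 2-D nested list (the function is illustrative and assumes a non-empty matrix; it is not the ATen implementation):

```python
# Roll a 2-D nested list along rows (dim 0) or columns (dim 1).
def roll2d(mat, shift, dim):
    if dim == 0:
        k = shift % len(mat)
        return mat[-k:] + mat[:-k] if k else [row[:] for row in mat]
    k = shift % len(mat[0])
    return [row[-k:] + row[:-k] if k else row[:] for row in mat]

print(roll2d([[1, 2, 3], [4, 5, 6]], 1, 0))   # [[4, 5, 6], [1, 2, 3]]
print(roll2d([[1, 2, 3], [4, 5, 6]], 1, 1))   # [[3, 1, 2], [6, 4, 5]]
```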

@gchanan (Contributor) left a comment

I think this is good to go once the shift/shifts issue is dealt with.

@facebook-github-bot (Contributor) left a comment

@nairbv is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zdevito pushed a commit to zdevito/ATen that referenced this pull request Nov 5, 2018
Summary:
Adding a roll operator
Pull Request resolved: pytorch/pytorch#13261

Differential Revision: D12922575

Pulled By: nairbv

fbshipit-source-id: ff05c075d9c484a615011192b023debf47da4017
@ezyang added the merged label Jun 25, 2019