Conversation

@t-vi t-vi (Collaborator) commented Apr 18, 2018

This patch

  • adds Tensor.diagonal to complement torch.diagonal
  • implements diagonal natively in ATen
  • makes diagonal a view (similar to NumPy semantics; NumPy currently returns a read-only view to ease the transition from diagonal returning a copy)
  • implements taking arbitrary diagonals with numpy semantics
  • implements diagonal backward directly instead of deferring
    to the (more limited) diag

There is some discussion in #6479.
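As a quick illustration of the NumPy semantics the patch mirrors (shown with NumPy itself, since `torch.diagonal` follows it): the two chosen axes are removed and the diagonal is appended as the last axis, and NumPy's result is currently a read-only view.

```python
import numpy as np

# NumPy's diagonal semantics, which this patch mirrors for torch.diagonal:
# the two chosen axes are removed and the diagonal is appended as the last axis.
x = np.arange(24).reshape(2, 3, 4)
d = np.diagonal(x, offset=1, axis1=1, axis2=2)

print(d.shape)            # (2, 3): axes 1 and 2 collapse into a trailing diagonal axis
print(d[0])               # [ 1  6 11] == x[0, 0, 1], x[0, 1, 2], x[0, 2, 3]
print(d.flags.writeable)  # False: NumPy's diagonal view is read-only for now
```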

// negative-offset branch: the diagonal starts -offset rows down along dim1
diag_size = std::min(self.size(dim1) + offset, self.size(dim2));
storage_offset -= offset * self.stride(dim1);
}
AT_ASSERT(diag_size > 0, "invalid diagonal offset %zd", offset); // the diagonal offset was too large in magnitude
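The size/offset arithmetic in the excerpt above can be sketched in plain Python (a hypothetical mirror of the ATen logic; `shape` and `strides` stand in for `self.size` and `self.stride`):

```python
def diagonal_params(shape, strides, storage_offset, offset, dim1, dim2):
    """Sketch of the ATen size/offset computation for diagonal.

    Returns the diagonal length and the adjusted storage offset.
    """
    if offset >= 0:
        # diagonal starts `offset` columns to the right along dim2
        diag_size = min(shape[dim1], shape[dim2] - offset)
        storage_offset += offset * strides[dim2]
    else:
        # diagonal starts `-offset` rows down along dim1
        diag_size = min(shape[dim1] + offset, shape[dim2])
        storage_offset -= offset * strides[dim1]
    assert diag_size > 0, "invalid diagonal offset %d" % offset
    return diag_size, storage_offset
```

For a (5, 4) matrix with C-contiguous strides (4, 1), offset=-1 yields a diagonal of length 4 starting at storage offset 4.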


@unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
def test_diagonal_multidim(self):
    x = torch.randn(10, 11, 12, 13)

Tensor diagonal_backward(const Tensor & grad, IntList input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
  auto grad_input = at::zeros(grad.type(), input_sizes);
  auto diag = at::diagonal(grad_input, offset, dim1, dim2);
  diag.copy_(grad);
  return grad_input;
}
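The same backward idea can be sketched in NumPy (a hypothetical helper, not the ATen code): zero-fill an input-shaped gradient and scatter `grad` onto the chosen diagonal. Since NumPy's `diagonal` view is read-only, the sketch writes through index arrays rather than a `copy_` onto the view.

```python
import numpy as np

def diagonal_backward(grad, input_shape, offset=0, axis1=0, axis2=1):
    """Sketch of diagonal's backward: scatter grad onto the diagonal of zeros."""
    grad_input = np.zeros(input_shape, dtype=grad.dtype)
    k = grad.shape[-1]                      # diagonal length
    i = np.arange(k)
    rows = i - offset if offset < 0 else i  # indices along axis1
    cols = i + offset if offset > 0 else i  # indices along axis2
    # bring the two diagonal axes to the front, then fancy-index assign;
    # the moveaxis view writes through to grad_input
    gi = np.moveaxis(grad_input, (axis1, axis2), (0, 1))
    gi[rows, cols] = np.moveaxis(grad, -1, 0)
    return grad_input

# round trip: taking the diagonal of the result recovers grad
g = np.random.randn(2, 3)                   # grad for a (2, 4, 5) input, axes 1 and 2
out = diagonal_backward(g, (2, 4, 5), offset=-1, axis1=1, axis2=2)
```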

@t-vi t-vi force-pushed the diagonal_with_dim branch from 67d2734 to a7d0b1f Compare April 23, 2018 14:05
@ezyang ezyang (Contributor) commented Apr 23, 2018

@pytorchbot retest this please

@t-vi t-vi (Collaborator, Author) commented Apr 23, 2018

@ezyang is the bug in my patch or in the macOS CI? I didn't know that anything I did affected the Fourier tests...

@t-vi t-vi (Collaborator, Author) commented Apr 25, 2018

Hi, is there anything I can do to move this forward?
I'd like to look at adding diagonal support to einsum once the improved diagonal is available.

@ezyang ezyang (Contributor) commented Apr 26, 2018

@pytorchbot retest this please

I don't see how this PR could have triggered this. Let's try again.

@apaszke apaszke (Contributor) commented Apr 26, 2018

@pytorchbot retest this please

>>> x = torch.randn(2, 5, 4, 2)
>>> torch.diagonal(x, offset=-1, dim1=1, dim2=2)
(0 ,.,.) =
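The shape in that example can be checked against NumPy, whose semantics the patch follows (a sanity-check sketch, not part of the PR):

```python
import numpy as np

x = np.random.randn(2, 5, 4, 2)
d = np.diagonal(x, offset=-1, axis1=1, axis2=2)

# axes 1 and 2 are removed and a diagonal axis of length min(5-1, 4) = 4 is appended
print(d.shape)  # (2, 2, 4)
# with offset=-1, element i of the diagonal is taken at (axis1=i+1, axis2=i)
assert np.allclose(d[..., 0], x[:, 1, 0, :])
```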

@ezyang ezyang merged commit 2b44c42 into pytorch:master Apr 26, 2018
Jorghi12 pushed a commit to wsttiger/pytorch that referenced this pull request May 10, 2018
* Enhance diagonal

This patch
- adds Tensor.diagonal to complement torch.diagonal
- implements diagonal natively in ATen
- makes diagonal a view
- implements taking arbitrary diagonals
- implements diagonal backward instead of referring
  to the (more limited) diag

* add tests, copy diagonal code to backward for double differentiability

* improve tests and doc comment. Thank you, Adam!

* Mark diagonal as view function in gen_autograd.py, use simple backward.
weiyangfb pushed a commit to weiyangfb/pytorch that referenced this pull request Jun 11, 2018