
Commit 456c524

t-vi authored and weiyangfb committed
Enhance diagonal (fixes pytorch#6479) (pytorch#6718)
* Enhance diagonal

  This patch
  - adds Tensor.diagonal to complement torch.diagonal
  - implements diagonal natively in ATen
  - makes diagonal a view
  - implements taking arbitrary diagonals
  - implements diagonal backward instead of referring to the (more limited) diag

* add tests, copy diagonal code to backward for double differentiability

* improve tests and doc comment. Thank you, Adam!

* Mark diagonal as view function in gen_autograd.py, use simple backward.
1 parent 1f6bea3 commit 456c524
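
For orientation, a minimal usage sketch of the enhanced API (assuming the semantics described in the commit message and the diffs below; the tensors and shapes are illustrative, not part of the commit):

    # Minimal usage sketch of the enhanced API (illustrative values; not part of the commit).
    import torch

    x = torch.randn(2, 5, 4)

    # 2-D behaviour is unchanged: main diagonal of a matrix.
    d = torch.diagonal(x[0])                              # shape (4,)

    # New: choose arbitrary dimensions; the diagonal becomes the last dimension.
    b = torch.diagonal(x, offset=0, dim1=-2, dim2=-1)     # shape (2, 4)

    # New: Tensor.diagonal mirrors torch.diagonal.
    b2 = x.diagonal(0, -2, -1)

    # New: the result is a view, so in-place writes reach the base tensor.
    b.zero_()
    assert (x.diagonal(0, -2, -1) == 0).all()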

8 files changed (+113, -11 lines)

aten/src/ATen/native/TensorShape.cpp

Lines changed: 36 additions & 5 deletions
@@ -48,11 +48,42 @@ Tensor diagflat(const Tensor& self, int64_t offset) {
   return self.contiguous().view(-1).diag(offset);
 }
 
-Tensor diagonal(const Tensor& self, int64_t offset) {
-  if (self.dim() != 2) {
-    throw std::runtime_error("diagonal expects a 2-dimensional tensor");
-  }
-  return self.diag(offset);
+Tensor diagonal(const Tensor& self, int64_t offset, int64_t dim1_, int64_t dim2_) {
+  int64_t nDims = self.dim();
+  int64_t dim1 = maybe_wrap_dim(dim1_, nDims);
+  int64_t dim2 = maybe_wrap_dim(dim2_, nDims);
+  AT_ASSERT(dim1 != dim2, "diagonal dimensions cannot be identical %zd, %zd", dim1_, dim2_);
+  int64_t diag_size;
+  int64_t storage_offset = self.storage_offset();
+  // compute storage offset and size for the diagonal
+  // for positive values of offset (above the main diagonal)
+  // "leftmost columns" (along dim2) are dropped
+  // for negative values of offset (below the main diagonal)
+  // "topmost rows" (along dim1) are dropped.
+  // Note that we invert +/- in the second case to absorb the negative
+  // sign in the offset.
+  if (offset >= 0) {
+    diag_size = std::min(self.size(dim1), self.size(dim2) - offset);
+    storage_offset += offset * self.stride(dim2);
+  } else {
+    diag_size = std::min(self.size(dim1) + offset, self.size(dim2));
+    storage_offset -= offset * self.stride(dim1);
+  }
+  AT_ASSERT(diag_size > 0, "invalid diagonal offset %zd", offset);  // the diagonal offset was too large in magnitude
+
+  // construct new size and stride: we drop dim1 and dim2 (the larger index first, so the smaller index stays valid)
+  // the new ("joint") dimension is appended to the end of the shape / stride to match numpy semantics
+  auto sizes = std::vector<int64_t>(self.sizes());
+  auto strides = std::vector<int64_t>(self.strides());
+  sizes.erase(sizes.begin() + std::max(dim1, dim2));
+  strides.erase(strides.begin() + std::max(dim1, dim2));
+  sizes.erase(sizes.begin() + std::min(dim1, dim2));
+  strides.erase(strides.begin() + std::min(dim1, dim2));
+  sizes.push_back(diag_size);
+  strides.push_back(self.stride(dim1) + self.stride(dim2));
+
+  // return view with new parameters
+  return self.as_strided(sizes, strides, storage_offset);
 }
 
 Tensor expand(const Tensor& self, IntList size) {

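The same size/stride/storage-offset computation can be sketched in Python with as_strided; the helper diagonal_view below is an illustration of the view construction above, not code from the patch:

    # Python sketch of the view construction above (illustration only, not part of the patch):
    # build the diagonal as an as_strided view of the input.
    import torch

    def diagonal_view(x, offset=0, dim1=0, dim2=1):
        dim1 = dim1 % x.dim()
        dim2 = dim2 % x.dim()
        assert dim1 != dim2, "diagonal dimensions cannot be identical"
        if offset >= 0:
            diag_size = min(x.size(dim1), x.size(dim2) - offset)
            storage_offset = x.storage_offset() + offset * x.stride(dim2)
        else:
            diag_size = min(x.size(dim1) + offset, x.size(dim2))
            storage_offset = x.storage_offset() - offset * x.stride(dim1)
        assert diag_size > 0, "invalid diagonal offset"
        sizes = list(x.size())
        strides = list(x.stride())
        # drop dim1 and dim2 (larger index first), then append the joint diagonal dimension
        for d in sorted((dim1, dim2), reverse=True):
            del sizes[d]
            del strides[d]
        sizes.append(diag_size)
        strides.append(x.stride(dim1) + x.stride(dim2))
        return x.as_strided(sizes, strides, storage_offset)

    x = torch.randn(3, 5)
    assert torch.equal(diagonal_view(x, 1), torch.diagonal(x, 1))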
aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 2 deletions
@@ -277,8 +277,7 @@
 - func: diagflat(Tensor self, int64_t offset=0) -> Tensor
   variants: function
 
-- func: diagonal(Tensor self, int64_t offset=0) -> Tensor
-  variants: function
+- func: diagonal(Tensor self, int64_t offset=0, int64_t dim1=0, int64_t dim2=1) -> Tensor
 
 - func: dot(Tensor self, Tensor tensor) -> Tensor
 

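With the defaults dim1=0, dim2=1 the new signature reproduces the old 2-D behaviour, and dropping the explicit `variants: function` line restores the default variants, which is how the Tensor.diagonal method mentioned in the commit message becomes available. A quick sanity check (illustrative, not part of the diff):

    # Illustrative check (not part of the diff): defaults match the old 2-D behaviour,
    # and the method variant mirrors the function.
    import torch

    m = torch.randn(4, 4)
    assert torch.equal(torch.diagonal(m), torch.diag(m))      # same main diagonal
    assert torch.equal(m.diagonal(1), torch.diagonal(m, 1))   # Tensor.diagonal method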
test/test_autograd.py

Lines changed: 24 additions & 0 deletions
@@ -2145,6 +2145,18 @@ def test_mul_out_result_requires_grad(self):
         # we should throw an exception if the output requires grad
         self.assertRaisesRegex(RuntimeError, 'out=', lambda: torch.mul(a, b, out=x))
 
+    def test_diagonal_derivative_requires_grad(self):
+        # test that the backward requires grad
+        # we do this because diagonal_backward uses inplace
+        # operations and gradgradcheck does not catch whether
+        # they work as expected (it will succeed even if
+        # the gradient has requires_grad == False)
+        a = torch.randn(5, 6, requires_grad=True)
+        b = torch.diagonal(a)**2
+        c = b.sum()
+        d, = torch.autograd.grad(c, a, retain_graph=True, create_graph=True)
+        self.assertTrue(d.requires_grad)
+
 
 def index_variable(shape, max_indices):
     if not isinstance(shape, tuple):
@@ -2661,6 +2673,18 @@ class dont_convert(tuple):
     ('diag', (M,), NO_ARGS, '1d'),
     ('diag', (M, M), (1,), '2d_1'),
     ('diag', (M, M), (2,), '2d_2'),
+    ('diagonal', (M, M), NO_ARGS, '2d'),
+    ('diagonal', (3, 5), NO_ARGS, '2d_wide'),
+    ('diagonal', (3, 5), (2,), '2d_wide_pos'),
+    ('diagonal', (3, 5), (-2,), '2d_wide_neg'),
+    ('diagonal', (5, 3), NO_ARGS, '2d_tall'),
+    ('diagonal', (5, 3), (2,), '2d_tall_pos'),
+    ('diagonal', (5, 3), (-2,), '2d_tall_neg'),
+    ('diagonal', (M, M), (1,), '2d_1'),
+    ('diagonal', (M, M), (2,), '2d_2'),
+    ('diagonal', (M, M, M), (1, 1, 2), '3d_1'),
+    ('diagonal', (M, M, M), (2, 0, 1), '3d_2'),
+    ('diagonal', (M, M, M), (-2, 0, 1), '3d_3'),
     ('tril', (M, M), NO_ARGS),
     ('tril', (M, M), (2,), 'idx'),
     ('triu', (M, M), NO_ARGS),

test/test_torch.py

Lines changed: 19 additions & 0 deletions
@@ -2014,6 +2014,25 @@ def _test_diagonal(self, dtype, device):
     def test_diagonal(self):
         self._test_diagonal(self, dtype=torch.float32, device='cpu')
 
+    @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
+    def test_diagonal_multidim(self):
+        x = torch.randn(10, 11, 12, 13)
+        xn = x.numpy()
+        for args in [(2, 2, 3),
+                     (2,),
+                     (-2, 1, 2),
+                     (0, -2, -1)]:
+            result = torch.diagonal(x, *args)
+            expected = xn.diagonal(*args)
+            self.assertEqual(expected.shape, result.shape)
+            self.assertTrue(np.allclose(expected, result.numpy()))
+        # test non-contiguous
+        xp = x.permute(1, 2, 3, 0)
+        result = torch.diagonal(xp, 0, -2, -1)
+        expected = xp.numpy().diagonal(0, -2, -1)
+        self.assertEqual(expected.shape, result.shape)
+        self.assertTrue(np.allclose(expected, result.numpy()))
+
     @staticmethod
     def _test_diagflat(self, dtype, device):
         # Basic sanity test

tools/autograd/derivatives.yaml

Lines changed: 3 additions & 0 deletions
@@ -200,6 +200,9 @@
 - name: diag(Tensor self, int64_t diagonal)
   self: diag_backward(grad, self.sizes(), diagonal)
 
+- name: diagonal(Tensor self, int64_t offset, int64_t dim1, int64_t dim2)
+  self: diagonal_backward(grad, self.sizes(), offset, dim1, dim2)
+
 - name: dist(Tensor self, Tensor other, Scalar p)
   self: norm_backward(grad, self - other, p, result)
   other: -norm_backward(grad, self - other, p, result)

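This entry wires autograd to diagonal_backward (defined in Functions.cpp below). One way to exercise the registered derivative is through gradcheck/gradgradcheck; a small sketch, not part of the patch:

    # Illustrative gradient check of the registered derivative (not part of the patch).
    import torch
    from torch.autograd import gradcheck, gradgradcheck

    a = torch.randn(4, 6, dtype=torch.double, requires_grad=True)
    fn = lambda t: torch.diagonal(t, 1, 0, 1)   # offset=1, dim1=0, dim2=1
    assert gradcheck(fn, (a,))
    assert gradgradcheck(fn, (a,))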
tools/autograd/gen_autograd.py

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@
 deprecated_path = os.path.join(os.path.dirname(__file__), 'deprecated.yaml')
 
 VIEW_FUNCTIONS = {
-    'alias', 'as_strided', 'expand', 'narrow', 'permute', 'select', 'slice',
+    'alias', 'as_strided', 'diagonal', 'expand', 'narrow', 'permute', 'select', 'slice',
    'squeeze', 't', 'transpose', 'unfold', 'unsqueeze', 'view',
 }
 

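Listing diagonal in VIEW_FUNCTIONS tells the autograd code generator that the output shares storage with its input, so differentiation and in-place writes through the view are tracked against the base tensor. A small illustration of the intended view semantics (a sketch of assumed behaviour, not code from this diff):

    # Sketch of the view semantics enabled by this change (assumed behaviour, not part of the diff).
    import torch

    x = torch.zeros(3, 3, requires_grad=True)
    y = x * 2                      # non-leaf tensor we are allowed to modify in place
    y.diagonal().add_(1)           # in-place write through the diagonal view of y
    y.sum().backward()
    print(x.grad)                  # gradients still flow to every element of x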
tools/autograd/templates/Functions.cpp

Lines changed: 7 additions & 0 deletions
@@ -720,6 +720,13 @@ Tensor diag_backward(const Tensor & grad, IntList input_sizes, int64_t diagonal)
   return grad_input;
 }
 
+Tensor diagonal_backward(const Tensor & grad, IntList input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
+  auto grad_input = at::zeros(grad.type(), input_sizes);
+  auto diag = grad_input.diagonal(offset, dim1, dim2);
+  diag.copy_(grad);
+  return grad_input;
+}
+
 Tensor mse_loss_double_backward(const Tensor & grad, const Tensor & input, bool size_average, bool reduce) {
   auto grad_input = 2 * grad;
   if (size_average && reduce) {

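In Python terms, the backward scatters the incoming gradient onto the diagonal of a zero tensor by reusing the diagonal view itself; a rough equivalent, for illustration only (the helper name is mine, not the generated code):

    # Rough Python equivalent of diagonal_backward above (illustration only).
    import torch

    def diagonal_backward_sketch(grad, input_sizes, offset, dim1, dim2):
        grad_input = torch.zeros(input_sizes, dtype=grad.dtype)
        grad_input.diagonal(offset, dim1, dim2).copy_(grad)   # scatter onto the diagonal view
        return grad_input

    g = torch.ones(3)                                  # gradient w.r.t. a length-3 diagonal
    gi = diagonal_backward_sketch(g, (3, 5), 1, 0, 1)
    # gi is zero everywhere except ones at (0, 1), (1, 2) and (2, 3)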
torch/_torch_docs.py

Lines changed: 22 additions & 3 deletions
@@ -1026,9 +1026,11 @@ def parse_kwargs(desc):
 
 add_docstr(torch.diagonal,
            r"""
-diagonal(input, offset=0) -> Tensor
+diagonal(input, offset=0, dim1=0, dim2=1) -> Tensor
 
-Returns a 1-D tensor with the diagonal elements of :attr:`input`.
+Returns a partial view of :attr:`input` with its diagonal elements
+with respect to :attr:`dim1` and :attr:`dim2` appended as a dimension
+at the end of the shape.
 
 The argument :attr:`offset` controls which diagonal to consider:
@@ -1037,9 +1039,15 @@ def parse_kwargs(desc):
 - If :attr:`offset` < 0, it is below the main diagonal.
 
 Args:
-    input (Tensor): the input tensor. Must be 2-dimensional.
+    input (Tensor): the input tensor. Must be at least 2-dimensional.
     offset (int, optional): which diagonal to consider. Default: 0
         (main diagonal).
+    dim1 (int, optional): first dimension with respect to which to
+        take diagonal. Default: 0.
+    dim2 (int, optional): second dimension with respect to which to
+        take diagonal. Default: 1.
+
+.. note::  To take a batch diagonal, pass in dim1=-2, dim2=-1.
 
 Examples::
@@ -1058,6 +1066,17 @@ def parse_kwargs(desc):
     tensor([ 1.1431,  0.0360])
 
+    >>> x = torch.randn(2, 5, 4, 2)
+    >>> torch.diagonal(x, offset=-1, dim1=1, dim2=2)
+
+    (0 ,.,.) =
+     -0.6806 -0.0281 -0.6595 -0.4199
+      0.8741 -0.1793 -0.6997  0.6265
+
+    (1 ,.,.) =
+      0.6182  1.3069  1.6503  1.7627
+     -0.2122 -0.2250  0.0990 -2.6433
+    [torch.FloatTensor of size (2,2,4)]
 """)
 
 add_docstr(torch.dist,

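For the batch-diagonal case mentioned in the note, a shape-only illustration (not part of the documented example):

    # Batch diagonal per the note above: dim1=-2, dim2=-1 takes the diagonal of each
    # trailing matrix, and the diagonal becomes the last dimension.
    import torch

    batch = torch.randn(8, 4, 4)                   # a batch of 8 square matrices
    d = torch.diagonal(batch, dim1=-2, dim2=-1)    # shape (8, 4)
    assert d.shape == (8, 4)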