
Commit e4eaf67

zou3519 authored and soumith committed
Fix torch.diag backward with non-square matrix (#4538)
* Fix torch.diag backward with non-square matrix
* Addressed comments
1 parent 91efc30 commit e4eaf67
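For context, a minimal sketch (written against the current Python API, not part of this commit) of the case the change fixes: backpropagating through torch.diag when the input matrix is not square, so the gradient must come back with the input's rectangular shape.

import torch
from torch.autograd import gradcheck

x = torch.randn(3, 5, dtype=torch.double, requires_grad=True)
y = torch.diag(x, 1)      # 1-D diagonal taken from a 3x5 (non-square) matrix
y.sum().backward()
print(x.grad.shape)       # torch.Size([3, 5]) once backward handles non-square inputs

# The new method_tests entries below exercise the same wide/tall cases via gradcheck.
assert gradcheck(lambda t: torch.diag(t, 1), (x,))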

File tree: 3 files changed (+39, -1 lines)


test/test_autograd.py

Lines changed: 6 additions & 0 deletions
@@ -2433,6 +2433,12 @@ class dont_convert(tuple):
     ('dist', (S, S, S), ((), 4), 'scalar_4_broadcast_rhs'),
     ('dist', (), ((S, S, S), 4), 'scalar_4_broadcast_lhs'),
     ('diag', (M, M), NO_ARGS, '2d'),
+    ('diag', (3, 5), NO_ARGS, '2d_wide'),
+    ('diag', (3, 5), (2,), '2d_wide_pos'),
+    ('diag', (3, 5), (-2,), '2d_wide_neg'),
+    ('diag', (5, 3), NO_ARGS, '2d_tall'),
+    ('diag', (5, 3), (2,), '2d_tall_pos'),
+    ('diag', (5, 3), (-2,), '2d_tall_neg'),
     ('diag', (M,), NO_ARGS, '1d'),
     ('diag', (M, M), (1,), '2d_1'),
     ('diag', (M, M), (2,), '2d_2'),

tools/autograd/derivatives.yaml

Lines changed: 1 addition & 1 deletion
@@ -198,7 +198,7 @@
   self: _det_with_svd_backward(grads, self, result0, result1, result2, result3)

 - name: diag(Tensor self, int64_t diagonal)
-  self: grad.diag(diagonal)
+  self: diag_backward(grad, self.sizes(), diagonal)

 - name: dist(Tensor self, Tensor other, Scalar p)
   self: norm_backward(grad, self - other, p, result)
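Why the old formula broke: for a 2-D input, diag returns the selected diagonal as a 1-D tensor, so the incoming grad is 1-D, and grad.diag(diagonal) rebuilds a square matrix of side len(grad) + |diagonal|, which cannot match a non-square input. A quick shape check (illustrative Python, not part of the commit):

import torch

x = torch.randn(3, 5)
d = torch.diag(x, 1)                 # diagonal of a 3x5 matrix -> length 3
grad = torch.ones_like(d)
print(torch.diag(grad, 1).shape)     # torch.Size([4, 4]): square, not 3x5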

tools/autograd/templates/Functions.cpp

Lines changed: 32 additions & 0 deletions
@@ -709,6 +709,38 @@ Tensor smooth_l1_loss_double_backward_grad_output(const Tensor & grad, const Ten
   return (r * grad).sum().view({1});
 }

+static inline int64_t diag_size(int64_t height, int64_t width, int64_t diagonal) {
+  if (width > height) {
+    return diag_size(width, height, -diagonal);
+  }
+  // Assumes height >= width
+  auto longest_diag = width;
+  if (diagonal >= 0) {
+    return longest_diag - diagonal;
+  }
+  if (longest_diag < height + diagonal) {
+    return longest_diag;
+  }
+  return height + diagonal;
+}
+
+Tensor diag_backward(const Tensor & grad, IntList input_sizes, int64_t diagonal) {
+  auto ndimension = input_sizes.size();
+  TORCH_ASSERT(ndimension == 1 || ndimension == 2);
+
+  if (ndimension == 1 || input_sizes[0] == input_sizes[1]) {
+    return grad.diag(diagonal);
+  }
+
+  // Input was a matrix but was not square
+  auto grad_input = grad.type().zeros(input_sizes);
+  auto diagonal_size = diag_size(input_sizes[0], input_sizes[1], diagonal);
+  auto storage_offset = diagonal >= 0 ? diagonal : -diagonal * input_sizes[1];
+  auto diag = grad_input.as_strided({diagonal_size}, {input_sizes[1] + 1}, storage_offset);
+  diag.copy_(grad);
+  return grad_input;
+}
+
 Tensor max_pool2d_double_backward(const Tensor & grad, const Tensor & indices) {
   // fold the first two dims together and the last two together
   auto fold = [](const Tensor & t) -> Tensor {
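The non-square branch allocates a zero matrix of the input's shape and writes grad onto the chosen diagonal through a strided view: a stride of cols + 1 steps one row down and one column right per element, and the storage offset picks the diagonal's starting element. A rough Python restatement of that branch (the name diag_backward_py is hypothetical, not part of the commit):

import torch

def diag_backward_py(grad, input_sizes, diagonal):
    # Mirrors only the 2-D non-square branch added above.
    rows, cols = input_sizes
    grad_input = torch.zeros(rows, cols, dtype=grad.dtype)
    # Start at column `diagonal` for diagonals above the main one,
    # or `-diagonal` rows down for diagonals below it.
    offset = diagonal if diagonal >= 0 else -diagonal * cols
    # Stepping cols + 1 elements in storage moves one row down and one column right.
    diag_view = grad_input.as_strided((grad.numel(),), (cols + 1,), offset)
    diag_view.copy_(grad)
    return grad_input

x = torch.randn(5, 3)
g = torch.ones_like(torch.diag(x, -2))
print(diag_backward_py(g, (5, 3), -2))   # zeros except ones on the -2 diagonal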
