Commit 35f90b2

Addressed comments

1 parent a810a93 commit 35f90b2
3 files changed: +31 -16 lines changed

test/test_autograd.py

Lines changed: 4 additions & 0 deletions
@@ -2394,7 +2394,11 @@ class dont_convert(tuple):
     ('dist', (), ((S, S, S), 4), 'scalar_4_broadcast_lhs'),
     ('diag', (M, M), NO_ARGS, '2d'),
     ('diag', (3, 5), NO_ARGS, '2d_wide'),
+    ('diag', (3, 5), (2,), '2d_wide_pos'),
+    ('diag', (3, 5), (-2,), '2d_wide_neg'),
     ('diag', (5, 3), NO_ARGS, '2d_tall'),
+    ('diag', (5, 3), (2,), '2d_tall_pos'),
+    ('diag', (5, 3), (-2,), '2d_tall_neg'),
     ('diag', (M,), NO_ARGS, '1d'),
     ('diag', (M, M), (1,), '2d_1'),
     ('diag', (M, M), (2,), '2d_2'),

tools/autograd/derivatives.yaml

Lines changed: 1 addition & 1 deletion
@@ -192,7 +192,7 @@
   self: _det_with_svd_backward(grads, self, result0, result1, result2, result3)
 
 - name: diag(Tensor self, int64_t diagonal)
-  self: diag_backward(grad, self, diagonal)
+  self: diag_backward(grad, self.sizes(), diagonal)
 
 - name: dist(Tensor self, Tensor other, Scalar p)
   self: norm_backward(grad, self - other, p, result)

tools/autograd/templates/Functions.cpp

Lines changed: 26 additions & 15 deletions
@@ -711,25 +711,36 @@ Tensor smooth_l1_loss_double_backward_grad_output(const Tensor & grad, const Ten
   return (r * grad).sum().view({1});
 }
 
-Tensor diag_backward(const Tensor & grad, const Tensor & self, int64_t diagonal) {
-  auto ndimension = self.ndimension();
+static inline int64_t diag_size(int64_t height, int64_t width, int64_t diagonal) {
+  if (width > height) {
+    return diag_size(width, height, -diagonal);
+  }
+  // Assumes height >= width
+  auto longest_diag = width;
+  if (diagonal >= 0) {
+    return longest_diag - diagonal;
+  }
+  if (longest_diag < height + diagonal) {
+    return longest_diag;
+  }
+  return height + diagonal;
+}
+
+Tensor diag_backward(const Tensor & grad, IntList input_sizes, int64_t diagonal) {
+  auto ndimension = input_sizes.size();
   TORCH_ASSERT(ndimension == 1 || ndimension == 2);
 
-  auto grad_input = grad.diag(diagonal);
-  if (ndimension == 1 || self.size(0) == self.size(1)) {
-    return grad_input;
+  if (ndimension == 1 || input_sizes[0] == input_sizes[1]) {
+    return grad.diag(diagonal);
   }
 
-  // cat rows or cols to grad_input so that it matches self's shape.
-  auto length = grad_input.size(0);
-  auto self_nrows = self.size(0);
-  auto self_ncols = self.size(1);
-  if (self_nrows == length) {
-    auto extra_cols = grad_input.type().zeros({self_nrows, self_ncols - length});
-    return at::cat({grad_input, extra_cols}, 1);
-  }
-  auto extra_rows = grad_input.type().zeros({self_nrows - length, self_ncols});
-  return at::cat({grad_input, extra_rows});
+  // Input was a matrix but was not square
+  auto grad_input = grad.type().zeros(input_sizes);
+  auto diagonal_size = diag_size(input_sizes[0], input_sizes[1], diagonal);
+  auto storage_offset = diagonal >= 0 ? diagonal : -diagonal * input_sizes[1];
+  auto diag = grad_input.as_strided({diagonal_size}, {input_sizes[1] + 1}, storage_offset);
+  diag.copy_(grad);
+  return grad_input;
 }
 
 Tensor max_pool2d_double_backward(const Tensor & grad, const Tensor & indices) {
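For reference, the scatter-onto-a-diagonal logic above can be sketched with public Python torch APIs. This is only an illustration of the C++ change, not part of the commit; diag_size_py and diag_backward_py are hypothetical helper names used for the sketch.

import torch

def diag_size_py(height, width, diagonal):
    # Length of the `diagonal`-th diagonal of a height x width matrix.
    if width > height:
        return diag_size_py(width, height, -diagonal)
    longest_diag = width
    if diagonal >= 0:
        return longest_diag - diagonal
    return min(longest_diag, height + diagonal)

def diag_backward_py(grad, input_sizes, diagonal):
    # Scatter `grad` back onto the chosen diagonal of a zero tensor
    # shaped like the original (possibly non-square) matrix input.
    if len(input_sizes) == 1 or input_sizes[0] == input_sizes[1]:
        return grad.diag(diagonal)
    grad_input = grad.new_zeros(input_sizes)
    size = diag_size_py(input_sizes[0], input_sizes[1], diagonal)
    offset = diagonal if diagonal >= 0 else -diagonal * input_sizes[1]
    grad_input.as_strided((size,), (input_sizes[1] + 1,), offset).copy_(grad)
    return grad_input

# Compare against autograd on a tall, non-square input with a negative diagonal,
# matching the new '2d_tall_neg' test case above.
x = torch.randn(5, 3, requires_grad=True)
x.diag(-2).backward(torch.ones(diag_size_py(5, 3, -2)))
manual = diag_backward_py(torch.ones(diag_size_py(5, 3, -2)), (5, 3), -2)
print(torch.equal(x.grad, manual))  # expected: True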
