6 changes: 6 additions & 0 deletions test/test_autograd.py
@@ -2393,6 +2393,12 @@ class dont_convert(tuple):
('dist', (S, S, S), ((), 4), 'scalar_4_broadcast_rhs'),
('dist', (), ((S, S, S), 4), 'scalar_4_broadcast_lhs'),
('diag', (M, M), NO_ARGS, '2d'),
('diag', (3, 5), NO_ARGS, '2d_wide'),
('diag', (3, 5), (2,), '2d_wide_pos'),
('diag', (3, 5), (-2,), '2d_wide_neg'),
('diag', (5, 3), NO_ARGS, '2d_tall'),
('diag', (5, 3), (2,), '2d_tall_pos'),
('diag', (5, 3), (-2,), '2d_tall_neg'),
('diag', (M,), NO_ARGS, '1d'),
('diag', (M, M), (1,), '2d_1'),
('diag', (M, M), (2,), '2d_2'),
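The new entries exercise diag on non-square (wide and tall) matrices with positive, zero, and negative offsets, using literal shapes instead of the symbolic S/M sizes. As a rough standalone illustration (not the PR's own harness, which drives these tuples through its test generator, and using the modern PyTorch API), the '2d_wide_pos' and '2d_tall_neg' cases correspond to checks like:

    import torch
    from torch.autograd import gradcheck

    # (3, 5) "wide" input with diagonal = 2, as in '2d_wide_pos'
    x = torch.randn(3, 5, dtype=torch.double, requires_grad=True)
    assert gradcheck(lambda t: torch.diag(t, 2), (x,))

    # (5, 3) "tall" input with diagonal = -2, as in '2d_tall_neg'
    y = torch.randn(5, 3, dtype=torch.double, requires_grad=True)
    assert gradcheck(lambda t: torch.diag(t, -2), (y,))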
2 changes: 1 addition & 1 deletion tools/autograd/derivatives.yaml
@@ -192,7 +192,7 @@
self: _det_with_svd_backward(grads, self, result0, result1, result2, result3)

- name: diag(Tensor self, int64_t diagonal)
self: grad.diag(diagonal)
self: diag_backward(grad, self.sizes(), diagonal)

- name: dist(Tensor self, Tensor other, Scalar p)
self: norm_backward(grad, self - other, p, result)
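The old formula grad.diag(diagonal) is only the right adjoint when the input is 1-D or a square matrix; for a non-square input it produces a gradient with the wrong shape, which is why the backward now also takes self.sizes(). A rough illustration in Python (shapes chosen for the example, assuming current torch.diag semantics):

    import torch

    # Forward: the k=2 diagonal of a 3x5 input has length 3
    x = torch.ones(3, 5)
    print(torch.diag(x, 2).shape)     # torch.Size([3])

    # Old backward formula: diag() on the 1-D grad builds a *square* matrix,
    # so the result is 5x5 instead of the required 3x5 input shape.
    grad = torch.ones(3)
    print(torch.diag(grad, 2).shape)  # torch.Size([5, 5])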
32 changes: 32 additions & 0 deletions tools/autograd/templates/Functions.cpp
@@ -711,6 +711,38 @@ Tensor smooth_l1_loss_double_backward_grad_output(const Tensor & grad, const Ten
return (r * grad).sum().view({1});
}

static inline int64_t diag_size(int64_t height, int64_t width, int64_t diagonal) {
if (width > height) {
return diag_size(width, height, -diagonal);
}
// Assumes height >= width
auto longest_diag = width;
if (diagonal >= 0) {
return longest_diag - diagonal;
}
if (longest_diag < height + diagonal) {
return longest_diag;
}
return height + diagonal;
}
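diag_size computes the length of the diagonal-th diagonal of a height x width matrix, i.e. min(width - diagonal, height + diagonal, width, height), after first normalizing to height >= width. A quick sanity check of the same formula against NumPy's np.diagonal (a standalone sketch, not part of the PR):

    import numpy as np

    def diag_size(height, width, diagonal):
        # Same logic as the C++ helper above
        if width > height:
            return diag_size(width, height, -diagonal)
        longest_diag = width
        if diagonal >= 0:
            return longest_diag - diagonal
        return min(longest_diag, height + diagonal)

    for h, w in [(3, 5), (5, 3), (4, 4)]:
        for d in range(-(h - 1), w):
            assert diag_size(h, w, d) == len(np.diagonal(np.zeros((h, w)), d))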

Tensor diag_backward(const Tensor & grad, IntList input_sizes, int64_t diagonal) {
auto ndimension = input_sizes.size();
TORCH_ASSERT(ndimension == 1 || ndimension == 2);

if (ndimension == 1 || input_sizes[0] == input_sizes[1]) {
return grad.diag(diagonal);
}

// Input was a matrix but was not square
auto grad_input = grad.type().zeros(input_sizes);
auto diagonal_size = diag_size(input_sizes[0], input_sizes[1], diagonal);
auto storage_offset = diagonal >= 0 ? diagonal : -diagonal * input_sizes[1];
auto diag = grad_input.as_strided({diagonal_size}, {input_sizes[1] + 1}, storage_offset);
diag.copy_(grad);
return grad_input;
}
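The as_strided call views the diagonal-th diagonal of the zero-filled gradient as a 1-D tensor: a stride of width + 1 steps one row down and one column right per element, and the storage offset is diagonal for upper diagonals or -diagonal * width for lower ones. A rough NumPy analogue of the same write (illustration only, assuming contiguous row-major storage as in the ATen tensor created by zeros):

    import numpy as np

    def diag_backward_numpy(grad, input_sizes, diagonal):
        # Scatter `grad` onto the `diagonal`-th diagonal of a zero matrix
        # of shape `input_sizes`, mirroring the strided write in the C++ above.
        height, width = input_sizes
        grad_input = np.zeros(input_sizes)
        offset = diagonal if diagonal >= 0 else -diagonal * width
        flat = grad_input.reshape(-1)  # view into the same storage
        flat[offset:offset + len(grad) * (width + 1):width + 1] = grad
        return grad_input

    # Example: gradient for a (3, 5) input of diag(x, 2); the output had length 3
    print(diag_backward_numpy(np.array([1., 2., 3.]), (3, 5), 2))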

Tensor max_pool2d_double_backward(const Tensor & grad, const Tensor & indices) {
// fold the first two dims together and the last two together
auto fold = [](const Tensor & t) -> Tensor {