@@ -717,9 +717,36 @@ Tensor diag_backward(const Tensor & grad, IntList input_sizes, int64_t diagonal)
717717 return grad_input;
718718}
719719
720- Tensor diagonal_backward (const Tensor & grad, IntList input_sizes, int64_t offset, int64_t dim1 , int64_t dim2 ) {
720+ Tensor diagonal_backward (const Tensor & grad, IntList input_sizes, int64_t offset, int64_t dim1_ , int64_t dim2_ ) {
721721 auto grad_input = at::zeros (grad.type (), input_sizes);
722- auto diag = at::diagonal (grad_input, offset, dim1, dim2);
722+ // the following until the assignment of auto diag
723+ // copies the diagonal code in aten/src/ATen/native/TensorShape.cpp
724+ // that would be equivalent to
725+ // auto diag = grad_input.diagonal(offset, dim1, dim2);
726+ // when using diagonal, the output is not differentiable twice
727+ // while this works
728+ int64_t nDims = input_sizes.size ();
729+ int64_t dim1 = at::maybe_wrap_dim (dim1_, nDims);
730+ int64_t dim2 = at::maybe_wrap_dim (dim2_, nDims);
731+ int64_t diag_size;
732+ int64_t storage_offset = grad_input.storage_offset ();
733+ if (offset >= 0 ) {
734+ diag_size = std::min (grad_input.size (dim1), grad_input.size (dim2)-offset);
735+ storage_offset += offset * grad_input.stride (dim2);
736+ } else {
737+ diag_size = std::min (grad_input.size (dim1)+offset, grad_input.size (dim2));
738+ storage_offset -= offset * grad_input.stride (dim1);
739+ }
740+ auto sizes = std::vector<int64_t >(grad_input.sizes ());
741+ auto strides = std::vector<int64_t >(grad_input.strides ());
742+ sizes.erase (sizes.begin () + std::max (dim1, dim2));
743+ strides.erase (strides.begin () + std::max (dim1, dim2));
744+ sizes.erase (sizes.begin () + std::min (dim1, dim2));
745+ strides.erase (strides.begin () + std::min (dim1, dim2));
746+ sizes.push_back (diag_size);
747+ strides.push_back (grad_input.stride (dim1)+grad_input.stride (dim2));
748+ auto diag = grad_input.as_strided (sizes, strides, storage_offset);
749+
723750 diag.copy_ (grad);
724751 return grad_input;
725752}
0 commit comments