@@ -711,25 +711,36 @@ Tensor smooth_l1_loss_double_backward_grad_output(const Tensor & grad, const Ten
   return (r * grad).sum().view({1});
 }
 
-Tensor diag_backward(const Tensor & grad, const Tensor & self, int64_t diagonal) {
-  auto ndimension = self.ndimension();
+static inline int64_t diag_size(int64_t height, int64_t width, int64_t diagonal) {
+  if (width > height) {
+    return diag_size(width, height, -diagonal);
+  }
+  // Assumes height >= width
+  auto longest_diag = width;
+  if (diagonal >= 0) {
+    return longest_diag - diagonal;
+  }
+  if (longest_diag < height + diagonal) {
+    return longest_diag;
+  }
+  return height + diagonal;
+}
+
+Tensor diag_backward(const Tensor & grad, IntList input_sizes, int64_t diagonal) {
+  auto ndimension = input_sizes.size();
   TORCH_ASSERT(ndimension == 1 || ndimension == 2);
 
-  auto grad_input = grad.diag(diagonal);
-  if (ndimension == 1 || self.size(0) == self.size(1)) {
-    return grad_input;
+  if (ndimension == 1 || input_sizes[0] == input_sizes[1]) {
+    return grad.diag(diagonal);
   }
 
-  // cat rows or cols to grad_input so that it matches self's shape.
-  auto length = grad_input.size(0);
-  auto self_nrows = self.size(0);
-  auto self_ncols = self.size(1);
-  if (self_nrows == length) {
-    auto extra_cols = grad_input.type().zeros({self_nrows, self_ncols - length});
-    return at::cat({grad_input, extra_cols}, 1);
-  }
-  auto extra_rows = grad_input.type().zeros({self_nrows - length, self_ncols});
-  return at::cat({grad_input, extra_rows});
+  // Input was a matrix but was not square
+  auto grad_input = grad.type().zeros(input_sizes);
+  auto diagonal_size = diag_size(input_sizes[0], input_sizes[1], diagonal);
+  auto storage_offset = diagonal >= 0 ? diagonal : -diagonal * input_sizes[1];
+  auto diag = grad_input.as_strided({diagonal_size}, {input_sizes[1] + 1}, storage_offset);
+  diag.copy_(grad);
+  return grad_input;
 }
 
 Tensor max_pool2d_double_backward(const Tensor & grad, const Tensor & indices) {
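
For readers checking the arithmetic, here is a minimal standalone sketch (illustrative only, not part of the patch) that re-implements the diag_size helper above and mimics the as_strided offset/stride math on a plain row-major buffer:

// Standalone sanity check (illustrative only, not part of the patch).
// It re-implements the diag_size helper from the diff and mimics the
// as_strided offset/stride math on a plain row-major buffer.
#include <cassert>
#include <cstdint>
#include <vector>

static int64_t diag_size(int64_t height, int64_t width, int64_t diagonal) {
  if (width > height) {
    // Reduce the wide case to the tall case by transposing, which
    // negates the diagonal offset.
    return diag_size(width, height, -diagonal);
  }
  // Assumes height >= width
  auto longest_diag = width;
  if (diagonal >= 0) {
    return longest_diag - diagonal;
  }
  if (longest_diag < height + diagonal) {
    return longest_diag;
  }
  return height + diagonal;
}

int main() {
  // Diagonal lengths of a 2x3 matrix.
  assert(diag_size(2, 3, 0) == 2);   // (0,0), (1,1)
  assert(diag_size(2, 3, 1) == 2);   // (0,1), (1,2)
  assert(diag_size(2, 3, -1) == 1);  // (1,0)

  // The strided view used in the patch: for an H x W row-major buffer,
  // diagonal d starts at storage offset d (for d >= 0) or -d * W (for
  // d < 0, i.e. -d rows down), and consecutive diagonal elements are
  // W + 1 storage slots apart (one row down, one column right).
  const int64_t W = 3;
  const int64_t d = 1;
  std::vector<int64_t> m = {0, 1, 2,
                            3, 4, 5};  // 2x3, value == storage index
  int64_t offset = d >= 0 ? d : -d * W;
  assert(m[offset + 0 * (W + 1)] == 1);  // element (0, 1)
  assert(m[offset + 1 * (W + 1)] == 5);  // element (1, 2)
  return 0;
}

Writing the gradient through a strided view of a zero-filled tensor replaces the old approach of taking grad.diag(diagonal) and concatenating extra zero rows or columns, and it also removes the need to keep the original self tensor around: the input sizes alone are enough.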