test/common_methods_invocations.py (6 additions, 0 deletions)

@@ -672,6 +672,12 @@ def method_tests():
          'tall_all', (), NO_ARGS, [skipIfNoLapack], lambda usv: (usv[0][:, :(S - 2)], usv[1], usv[2])),
         ('svd', lambda: random_fullrank_matrix_distinct_singular_value(M), NO_ARGS,
          'large', (), NO_ARGS, [skipIfNoLapack]),
+        ('qr', (S, S), (False,), 'square_single', (), NO_ARGS, [skipIfNoLapack]),
+        ('qr', (S, S - 2), (True,), 'tall_single', (), NO_ARGS, [skipIfNoLapack]),
+        ('qr', (3, S, S), (False,), 'square_batched', (), NO_ARGS, [skipIfNoLapack]),
+        ('qr', (3, S, S - 2), (True,), 'tall_batched', (), NO_ARGS, [skipIfNoLapack]),
+        ('qr', (3, 2, S, S), (False,), 'square_many_batched', (), NO_ARGS, [skipIfNoLapack]),
+        ('qr', (3, 2, S, S - 2), (True,), 'tall_many_batched', (), NO_ARGS, [skipIfNoLapack]),
         ('solve', (S, S), (random_fullrank_matrix_distinct_singular_value(
             S, silent=True),), '', (), NO_ARGS, [skipIfNoLapack]),
         ('solve', (S, S, S), (random_fullrank_matrix_distinct_singular_value(S, S, silent=True),),
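Each tuple above roughly follows the method_tests() convention: (method name, input size or constructor, method args, test-name suffix, ...). As a sketch of what these entries exercise, the harness builds a double-precision input of the given size and runs gradcheck through the new qr_backward; something along these lines (a paraphrase, not the suite's actual driver code):

import torch
from torch.autograd import gradcheck

S = 5  # assumption: the suite's small-dimension constant

# Square input with some=False, as in the 'square_single' entry above.
# gradcheck compares the analytic qr_backward result against numerical
# finite differences.
A = torch.randn(S, S, dtype=torch.float64, requires_grad=True)
assert gradcheck(lambda x: torch.qr(x, some=False), (A,))

# Tall input with some=True ('tall_single'): the reduced QR keeps R
# square, so the derivative is defined.
B = torch.randn(S, S - 2, dtype=torch.float64, requires_grad=True)
assert gradcheck(lambda x: torch.qr(x, some=True), (B,))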
tools/autograd/derivatives.yaml (1 addition, 1 deletion)

@@ -695,7 +695,7 @@
   source: grad.take(index)
 
 - name: qr(Tensor self, bool some)
-  self: not_implemented("qr")
+  self: qr_backward(grads, self, some, Q, R)
 
 - name: random_(Tensor self, int64_t from, int64_t to, Generator generator)
   self: zeros_like(grad)
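This entry routes the output gradients for (Q, R) into qr_backward. When only one of the two outputs participates in the loss, autograd passes an undefined Tensor as the other output's gradient, which is what the .defined() checks in the implementation below handle. A small illustration (my example, not part of the PR):

import torch

A = torch.randn(5, 5, dtype=torch.float64, requires_grad=True)
q, r = torch.qr(A)

# Only R feeds the loss, so qr_backward receives grads[0] (the gradient
# for Q) as an undefined tensor and treats the Q term as zero.
loss = r.pow(2).sum()
loss.backward()
print(A.grad.shape)  # torch.Size([5, 5])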
tools/autograd/templates/Functions.cpp (64 additions, 0 deletions)

@@ -1737,6 +1737,70 @@ Tensor symeig_backward(const std::vector<torch::autograd::Variable> &grads, cons
   return result;
 }
 
+// The derivative for the QR decomposition is adapted from Eq. 42 of
+// Walter, S.F. and Lehmann, L., "Algorithmic Differentiation of Linear
+// Algebra Functions with Application in Optimum Experimental Design
+// (Extended Version)".
+Tensor qr_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
+                   bool some, const Tensor& Q, const Tensor& R) {
+  auto grad_Q = grads[0];
+  auto grad_R = grads[1];
+  TORCH_CHECK(R.size(-2) == R.size(-1),
+              "The derivative when R is non-square is not implemented.");
+
+  // Compute R grad_R^{T}
+  Tensor R_term;
+  if (grad_R.defined()) {
+    R_term = at::matmul(R, grad_R.transpose(-2, -1));
+  } else {
+    // R and grad_R are ... x N x N, so the product is ... x N x N
+    R_term = at::zeros_like(R);
+  }
+
+  // Compute Q^{T} grad_Q
+  Tensor Q_term;
+  if (grad_Q.defined()) {
+    Q_term = at::matmul(Q.transpose(-2, -1), grad_Q);
+  } else {
+    // Q and grad_Q are ... x M x N, so Q^{T} grad_Q is ... x N x N,
+    // the same shape as R
+    Q_term = at::zeros_like(R);
+  }
+
+  // We want to compute: rhs_solve_1 . R^{-T}
+  // Note that rhs_solve_1 . R^{-T} = (R^{-1} . rhs_solve_1^{T})^{T}
+  // Since R is upper triangular, we can do this using
+  // triangular_solve(rhs_solve_1^{T}, R)^{T}
+  auto rhs_solve_1 = R_term - R_term.transpose(-2, -1) + Q_term - Q_term.transpose(-2, -1);
+  rhs_solve_1 = at::tril(rhs_solve_1, /*k=*/-1);
+  Tensor solve_soln_1;
+  std::tie(solve_soln_1, std::ignore) = at::triangular_solve(rhs_solve_1.transpose(-2, -1), R,
+                                                             /*upper=*/true, /*transpose=*/false,
+                                                             /*unitriangular=*/false);
+  Tensor grad_A;
+  if (grad_R.defined()) {
+    grad_A = at::matmul(Q, solve_soln_1.transpose(-2, -1) + grad_R);
+  } else {
+    grad_A = at::matmul(Q, solve_soln_1.transpose(-2, -1));
+  }
+
+  // The remaining term involves Q Q^{T}, which is the identity when A is
+  // square, so it contributes only when A is rectangular
+  if (self.size(-1) != self.size(-2)) {
+    Tensor rhs_solve_2;
+    // We use the same triangular_solve trick as above for this computation
+    if (grad_Q.defined()) {
+      rhs_solve_2 = grad_Q - at::matmul(Q, Q_term);
+    } else {
+      rhs_solve_2 = -at::matmul(Q, Q_term);
+    }
+    Tensor solve_soln_2;
+    std::tie(solve_soln_2, std::ignore) = at::triangular_solve(rhs_solve_2.transpose(-2, -1), R,
+                                                               /*upper=*/true, /*transpose=*/false,
+                                                               /*unitriangular=*/false);
+    grad_A.add_(solve_soln_2.transpose(-2, -1));
+  }
+  return grad_A;
+}
+
 // Invertible case is derived from Jacobi's formula, and also can be found at:
 // http://eprints.maths.ox.ac.uk/1079/1/NA-08-01.pdf
 Tensor det_backward(const Tensor & grad, const Tensor& self, const Tensor& det) {
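For reference, here is the formula the qr_backward addition above computes, written in adjoint notation (a bar denotes the gradient of the loss with respect to that quantity, and undefined gradients are treated as zero). This is a transcription of the code's computation, not a verbatim quote of Eq. 42 from the paper. For the square case:

$$\bar{A} = Q\left[\bar{R} + \operatorname{tril}\!\left(R\bar{R}^{\top} - \bar{R}R^{\top} + Q^{\top}\bar{Q} - \bar{Q}^{\top}Q,\ -1\right)R^{-\top}\right]$$

and for the tall case (reduced QR with M > N), the second branch adds the correction term

$$\left(I - QQ^{\top}\right)\bar{Q}\,R^{-\top},$$

which vanishes when A is square because there $QQ^{\top} = I$. In both branches the multiplication by $R^{-\top}$ is performed via $XR^{-\top} = \left(R^{-1}X^{\top}\right)^{\top}$ using triangular_solve, which is cheaper and more stable than forming the inverse of the upper-triangular R explicitly.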