Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 16 additions & 10 deletions test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from torch.utils.checkpoint import checkpoint
from torch.testing._internal.common_utils import (TEST_MKL, TEST_WITH_ROCM, TestCase, run_tests, skipIfNoLapack,
suppress_warnings, slowTest,
load_tests, random_symmetric_pd_matrix, random_symmetric_matrix,
load_tests, random_symmetric_matrix,
IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck)
from torch.autograd import Variable, Function, detect_anomaly
from torch.autograd.function import InplaceFunction
Expand Down Expand Up @@ -2501,22 +2501,28 @@ def test_var_mean_differentiable(self):
@skipIfNoLapack
def test_cholesky(self):
def func(root, upper):
x = torch.matmul(root, root.transpose(-1, -2)) + 1e-05
x = 0.5 * (root + root.transpose(-1, -2).conj())
return torch.cholesky(x, upper)

def run_test(upper, dims):
root = torch.rand(*dims, requires_grad=True)
def run_test(upper, dims, dtype):
root = torch.rand(*dims, dtype=dtype, requires_grad=True)
root = root + torch.eye(dims[-1])

gradcheck(func, [root, upper])
gradgradcheck(func, [root, upper])
# TODO: gradgradcheck does not work correctly yet for complex
if not dtype.is_complex:
gradgradcheck(func, [root, upper])

root = random_symmetric_pd_matrix(dims[-1], *dims[:-2]).requires_grad_()
root = torch.rand(*dims, dtype=dtype)
root = torch.matmul(root, root.transpose(-1, -2).conj())
root.requires_grad_()
chol = root.cholesky().sum().backward()
self.assertEqual(root.grad, root.grad.transpose(-1, -2)) # Check the gradient is symmetric
self.assertEqual(root.grad, root.grad.transpose(-1, -2).conj()) # Check the gradient is hermitian

for upper, dims in product([True, False], [(3, 3), (4, 3, 2, 2)]):
run_test(upper, dims)
run_test(upper, dims)
for upper, dims, dtype in product([True, False],
[(3, 3), (4, 3, 2, 2)],
[torch.double, torch.cdouble]):
run_test(upper, dims, dtype)

@skipIfNoLapack
def test_cholesky_solve(self):
Expand Down
10 changes: 5 additions & 5 deletions torch/csrc/autograd/FunctionsManual.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -711,15 +711,15 @@ Tensor cholesky_backward(Tensor grad, bool upper, Tensor L) {
// leads to stable gradient updates, and retains symmetry of the updated matrix if it
// were updated by a gradient based algorithm.
if (upper) {
L = L.transpose(-1, -2);
grad = grad.transpose(-1, -2);
L = L.transpose(-1, -2).conj();
grad = grad.transpose(-1, -2).conj();
}
auto L_inverse = std::get<0>(at::triangular_solve(at::eye(L.size(-1), L.options()), L, /*upper=*/false));
auto phi = at::matmul(L.transpose(-1, -2), grad);
auto phi = at::matmul(L.transpose(-1, -2).conj(), grad);
phi.tril_().diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).mul_(0.5);

auto grad_input = at::matmul(at::matmul(L_inverse.transpose(-1, -2), phi), L_inverse);
return grad_input.add(grad_input.transpose(-1, -2)).mul_(0.5); // Symmetrizing the gradient
auto grad_input = at::matmul(at::matmul(L_inverse.transpose(-1, -2).conj(), phi), L_inverse);
return grad_input.add(grad_input.transpose(-1, -2).conj()).mul_(0.5); // Symmetrizing the gradient
}

Tensor cholesky_inverse_backward(Tensor grad, Tensor L, bool upper, Tensor inverse) {
Expand Down