
Commit 75051fa

Merged real and complex cholesky tests
1 parent 76b75cc commit 75051fa

File tree

1 file changed: +16 -34 lines changed


test/test_autograd.py

Lines changed: 16 additions & 34 deletions
@@ -31,7 +31,7 @@
 from torch.utils.checkpoint import checkpoint
 from torch.testing._internal.common_utils import (TEST_MKL, TEST_WITH_ROCM, TestCase, run_tests, skipIfNoLapack,
                                                    suppress_warnings, slowTest,
-                                                   load_tests, random_symmetric_pd_matrix, random_symmetric_matrix,
+                                                   load_tests, random_symmetric_matrix,
                                                    IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck)
 from torch.autograd import Variable, Function, detect_anomaly
 from torch.autograd.function import InplaceFunction
@@ -2501,22 +2501,28 @@ def test_var_mean_differentiable(self):
     @skipIfNoLapack
     def test_cholesky(self):
         def func(root, upper):
-            x = torch.matmul(root, root.transpose(-1, -2)) + 1e-05
+            x = 0.5 * (root + root.transpose(-1, -2).conj())
             return torch.cholesky(x, upper)
 
-        def run_test(upper, dims):
-            root = torch.rand(*dims, requires_grad=True)
+        def run_test(upper, dims, dtype):
+            root = torch.rand(*dims, dtype=dtype, requires_grad=True)
+            root = root + torch.eye(dims[-1])
 
             gradcheck(func, [root, upper])
-            gradgradcheck(func, [root, upper])
+            # TODO: gradgradcheck does not work correctly yet for complex
+            if not dtype.is_complex:
+                gradgradcheck(func, [root, upper])
 
-            root = random_symmetric_pd_matrix(dims[-1], *dims[:-2]).requires_grad_()
+            root = torch.rand(*dims, dtype=dtype)
+            root = torch.matmul(root, root.transpose(-1, -2).conj())
+            root.requires_grad_()
             chol = root.cholesky().sum().backward()
-            self.assertEqual(root.grad, root.grad.transpose(-1, -2))  # Check the gradient is symmetric
+            self.assertEqual(root.grad, root.grad.transpose(-1, -2).conj())  # Check the gradient is hermitian
 
-        for upper, dims in product([True, False], [(3, 3), (4, 3, 2, 2)]):
-            run_test(upper, dims)
-            run_test(upper, dims)
+        for upper, dims, dtype in product([True, False],
+                                          [(3, 3), (4, 3, 2, 2)],
+                                          [torch.double, torch.cdouble]):
+            run_test(upper, dims, dtype)
 
     @skipIfNoLapack
     def test_cholesky_solve(self):
@@ -4922,30 +4928,6 @@ def fn(*inputs):
     setattr(TestAutogradDeviceType, test_name, do_test)
 
 class TestAutogradComplex(TestCase):
-    @skipIfNoLapack
-    def test_complex_cholesky(self):
-        def func(x, upper):
-            x = 0.5 * (x + x.transpose(-1, -2).conj())
-            return torch.cholesky(x, upper)
-
-        def run_test(upper, dims):
-            x = torch.rand(*dims, requires_grad=True, dtype=torch.cdouble)
-            x = x + torch.eye(dims[-1])
-
-            gradcheck(func, [x, upper])
-            # TODO: gradgradcheck does not work
-            # gradgradcheck(func, [x, upper])
-
-            x = torch.rand(*dims, dtype=torch.cdouble)
-            x = torch.matmul(x, x.transpose(-1, -2).conj())
-            x.requires_grad_()
-            chol = x.cholesky().sum().backward()
-            self.assertEqual(x.grad, x.grad.transpose(-1, -2).conj())  # Check the gradient is hermitian
-
-        for upper, dims in product([True, False], [(3, 3), (4, 3, 2, 2)]):
-            run_test(upper, dims)
-            run_test(upper, dims)
-
     def test_view_func_for_complex_views(self):
         # case 1: both parent and child have view_func
         x = torch.randn(2, 2, 2, dtype=torch.double, requires_grad=True)
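
For reference, below is a minimal, self-contained sketch of what the merged test exercises; it is not part of the commit. The helper name cholesky_of_hermitian and the fixed 3x3 size are made up for illustration, the snippet assumes a PyTorch build in which gradcheck and torch.cholesky support complex autograd (torch.cholesky has since been deprecated in favour of torch.linalg.cholesky), and it takes the real part of the loss before backward() so that an implicit scalar gradient exists for complex dtypes:

    import torch
    from itertools import product
    from torch.autograd import gradcheck

    def cholesky_of_hermitian(root, upper):
        # Project onto the Hermitian part so the Cholesky factorization is well defined.
        x = 0.5 * (root + root.transpose(-1, -2).conj())
        return torch.cholesky(x, upper)

    for upper, dtype in product([True, False], [torch.double, torch.cdouble]):
        # Adding the identity to a random matrix keeps its Hermitian part comfortably
        # positive definite, so the factorization does not fail during the
        # finite-difference perturbations that gradcheck performs.
        root = torch.rand(3, 3, dtype=dtype, requires_grad=True)
        gradcheck(cholesky_of_hermitian, [root + torch.eye(3), upper])  # raises on mismatch

        # Backward through cholesky() of an exactly Hermitian positive-definite matrix:
        # the gradient with respect to that matrix should itself be Hermitian.
        a = torch.rand(3, 3, dtype=dtype)
        a = a @ a.transpose(-1, -2).conj() + torch.eye(3)
        a.requires_grad_()
        loss = a.cholesky().sum()
        # For complex dtypes take the real part so backward() can create an
        # implicit gradient; for real dtypes the sum is already a real scalar.
        (loss.real if loss.is_complex() else loss).backward()
        assert torch.allclose(a.grad, a.grad.transpose(-1, -2).conj())

Both input constructions (a random matrix plus the identity, and a @ a^H) are simply cheap ways to obtain the Hermitian positive-definite matrices that a Cholesky factorization requires.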
