31 | 31 | from torch.utils.checkpoint import checkpoint |
32 | 32 | from torch.testing._internal.common_utils import (TEST_MKL, TEST_WITH_ROCM, TestCase, run_tests, skipIfNoLapack, |
33 | 33 | suppress_warnings, slowTest, |
34 | | - load_tests, random_symmetric_pd_matrix, random_symmetric_matrix, |
| 34 | + load_tests, random_symmetric_matrix, |
35 | 35 | IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck) |
36 | 36 | from torch.autograd import Variable, Function, detect_anomaly |
37 | 37 | from torch.autograd.function import InplaceFunction |
@@ -2501,22 +2501,28 @@ def test_var_mean_differentiable(self): |
2501 | 2501 | @skipIfNoLapack |
2502 | 2502 | def test_cholesky(self): |
2503 | 2503 | def func(root, upper): |
2504 | | - x = torch.matmul(root, root.transpose(-1, -2)) + 1e-05 |
| 2504 | + x = 0.5 * (root + root.transpose(-1, -2).conj()) |
2505 | 2505 | return torch.cholesky(x, upper) |
2506 | 2506 |
2507 | | - def run_test(upper, dims): |
2508 | | - root = torch.rand(*dims, requires_grad=True) |
| 2507 | + def run_test(upper, dims, dtype): |
| 2508 | + root = torch.rand(*dims, dtype=dtype, requires_grad=True) |
| 2509 | + root = root + torch.eye(dims[-1]) |
2509 | 2510 |
2510 | 2511 | gradcheck(func, [root, upper]) |
2511 | | - gradgradcheck(func, [root, upper]) |
| 2512 | + # TODO: gradgradcheck does not work correctly yet for complex |
| 2513 | + if not dtype.is_complex: |
| 2514 | + gradgradcheck(func, [root, upper]) |
2512 | 2515 |
2513 | | - root = random_symmetric_pd_matrix(dims[-1], *dims[:-2]).requires_grad_() |
| 2516 | + root = torch.rand(*dims, dtype=dtype) |
| 2517 | + root = torch.matmul(root, root.transpose(-1, -2).conj()) |
| 2518 | + root.requires_grad_() |
2514 | 2519 | chol = root.cholesky().sum().backward() |
2515 | | - self.assertEqual(root.grad, root.grad.transpose(-1, -2)) # Check the gradient is symmetric |
| 2520 | + self.assertEqual(root.grad, root.grad.transpose(-1, -2).conj()) # Check the gradient is hermitian |
2516 | 2521 |
2517 | | - for upper, dims in product([True, False], [(3, 3), (4, 3, 2, 2)]): |
2518 | | - run_test(upper, dims) |
2519 | | - run_test(upper, dims) |
| 2522 | + for upper, dims, dtype in product([True, False], |
| 2523 | + [(3, 3), (4, 3, 2, 2)], |
| 2524 | + [torch.double, torch.cdouble]): |
| 2525 | + run_test(upper, dims, dtype) |
2520 | 2526 |
2521 | 2527 | @skipIfNoLapack |
2522 | 2528 | def test_cholesky_solve(self): |
@@ -4922,30 +4928,6 @@ def fn(*inputs): |
4922 | 4928 | setattr(TestAutogradDeviceType, test_name, do_test) |
4923 | 4929 |
4924 | 4930 | class TestAutogradComplex(TestCase): |
4925 | | - @skipIfNoLapack |
4926 | | - def test_complex_cholesky(self): |
4927 | | - def func(x, upper): |
4928 | | - x = 0.5 * (x + x.transpose(-1, -2).conj()) |
4929 | | - return torch.cholesky(x, upper) |
4930 | | - |
4931 | | - def run_test(upper, dims): |
4932 | | - x = torch.rand(*dims, requires_grad=True, dtype=torch.cdouble) |
4933 | | - x = x + torch.eye(dims[-1]) |
4934 | | - |
4935 | | - gradcheck(func, [x, upper]) |
4936 | | - # TODO: gradgradcheck does not work |
4937 | | - # gradgradcheck(func, [x, upper]) |
4938 | | - |
4939 | | - x = torch.rand(*dims, dtype=torch.cdouble) |
4940 | | - x = torch.matmul(x, x.transpose(-1, -2).conj()) |
4941 | | - x.requires_grad_() |
4942 | | - chol = x.cholesky().sum().backward() |
4943 | | - self.assertEqual(x.grad, x.grad.transpose(-1, -2).conj()) # Check the gradient is hermitian |
4944 | | - |
4945 | | - for upper, dims in product([True, False], [(3, 3), (4, 3, 2, 2)]): |
4946 | | - run_test(upper, dims) |
4947 | | - run_test(upper, dims) |
4948 | | - |
4949 | 4931 | def test_view_func_for_complex_views(self): |
4950 | 4932 | # case 1: both parent and child have view_func |
4951 | 4933 | x = torch.randn(2, 2, 2, dtype=torch.double, requires_grad=True) |
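Note on the change above: the diff folds the old complex-only `test_complex_cholesky` (removed from `TestAutogradComplex`) into `test_cholesky` by parametrizing over `torch.double` and `torch.cdouble`, and drops the `random_symmetric_pd_matrix` import in favor of building a Hermitian positive-definite input inline. The following is a minimal standalone sketch of the pattern the merged test relies on, assuming a PyTorch build from this era where `gradcheck` accepts complex inputs and `torch.cholesky` is still available; `cholesky_of_hermitian` is an illustrative name that mirrors the inner `func` in the diff, not a helper in the PyTorch test suite.

import torch
from torch.autograd import gradcheck

def cholesky_of_hermitian(root, upper):
    # Symmetrize inside the function under test so every perturbed input that
    # gradcheck feeds in stays Hermitian and is a valid Cholesky argument.
    x = 0.5 * (root + root.transpose(-1, -2).conj())
    return torch.cholesky(x, upper)

root = torch.rand(3, 3, dtype=torch.cdouble, requires_grad=True)
root = root + torch.eye(3)  # diagonal shift, as in the test above, to keep the Hermitian part positive definite

# gradcheck differentiates with respect to the tensor input only; the boolean
# `upper` flag is passed through untouched.
assert gradcheck(cholesky_of_hermitian, (root, False))

Symmetrizing inside the function, rather than passing an already-Hermitian matrix, matters because gradcheck's finite-difference perturbations would otherwise push the input off the Hermitian manifold and make `torch.cholesky` ill-defined at the perturbed points.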