
Commit 528158a

IvanYashchuk authored and facebook-github-bot committed
Updated derivatives for complex mm, mv, ger, bmm, triangular_solve (#45737)
Summary: This PR updates the derivatives of a few functions so that `gradgradcheck` for `torch.cholesky` passes ([ref](#45267 (comment))). Some tests (those that call into `bmm_cuda`) fail with `RuntimeError: _th_bmm_out not supported on CUDAType for ComplexDouble` until PR #42553 is merged.

Ref. #33152

Pull Request resolved: #45737
Reviewed By: bdhirsh
Differential Revision: D24279917
Pulled By: anjali411
fbshipit-source-id: 7b696d2cfc2ef714332c2e3e5d207e257be67744
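For context, here is a minimal illustrative sketch (not taken from this commit; the test below builds its positive-definite input slightly differently) of the second-order check that this change enables for complex inputs:

```python
import torch
from torch.autograd import gradcheck, gradgradcheck

# Build a Hermitian positive-definite matrix from `root` inside the function,
# so that the Cholesky factor is differentiable with respect to `root`.
dims = (3, 3)
root = torch.rand(*dims, dtype=torch.cdouble, requires_grad=True)

def func(root):
    x = root @ root.conj().transpose(-1, -2) + torch.eye(dims[-1], dtype=torch.cdouble)
    return torch.cholesky(x, upper=False)

gradcheck(func, [root])
gradgradcheck(func, [root])  # second derivatives; previously skipped for complex dtypes
```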
1 parent 7f458e1 commit 528158a

File tree

4 files changed: +33, -24 lines

test/test_autograd.py

Lines changed: 19 additions & 11 deletions
@@ -2509,9 +2509,7 @@ def run_test(upper, dims, dtype):
  root = root + torch.eye(dims[-1])

  gradcheck(func, [root, upper])
- # TODO: gradgradcheck does not work correctly yet for complex
- if not dtype.is_complex:
-     gradgradcheck(func, [root, upper])
+ gradgradcheck(func, [root, upper])

  root = torch.rand(*dims, dtype=dtype)
  root = torch.matmul(root, root.transpose(-1, -2).conj())
@@ -2684,9 +2682,9 @@ def func(A, upper):

  @skipIfNoLapack
  def test_triangular_solve(self):
-     def _test_with_size(A_dims, B_dims):
-         A = torch.rand(*A_dims).requires_grad_()
-         b = torch.rand(*B_dims).requires_grad_()
+     def run_test(A_dims, B_dims, dtype):
+         A = torch.rand(*A_dims, dtype=dtype).requires_grad_()
+         b = torch.rand(*B_dims, dtype=dtype).requires_grad_()

          for upper, transpose, unitriangular in product((True, False), repeat=3):
              def func(A, b):
@@ -2695,10 +2693,11 @@ def func(A, b):
              gradcheck(func, [A, b])
              gradgradcheck(func, [A, b])

-     _test_with_size((3, 3), (3, 4))
-     _test_with_size((3, 3), (3, 2))
-     _test_with_size((2, 3, 3), (2, 3, 4))
-     _test_with_size((2, 3, 3), (2, 3, 2))
+     for dtype in (torch.double, torch.cdouble):
+         run_test((3, 3), (3, 4), dtype)
+         run_test((3, 3), (3, 2), dtype)
+         run_test((2, 3, 3), (2, 3, 4), dtype)
+         run_test((2, 3, 3), (2, 3, 2), dtype)

  @unittest.skipIf(not TEST_MKL, "PyTorch is built without MKL support")
  def test_fft_ifft_rfft_irfft(self):
@@ -4833,7 +4832,11 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
  'permute', 'squeeze', 'unsqueeze', 'resize', 'resize_as', 'tril', 'triu',
  'chunk', 'split', 'split_with_sizes', 'repeat', 'expand', 'zero_',
  'eq_', 'ne_', 'add', '__radd__', 'sum', 'conj', 'sin', 'cos', 'mul', 'sinh',
- 'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot', 'tensor_split'] + separate_complex_tests
+ 'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot', 'tensor_split',
+ 'matmul', 'bmm', 'mv', 'ger', 'diagonal', ] + separate_complex_tests
+
+ # this list corresponds to cases that are not currently implemented
+ skip_cuda_list = ['bmm_complex', 'matmul_4d_4d_complex']

  # TODO(@anjali411): add tests for 'sub', 'div
  # TODO(@anjali411): add the commented tests back after updating the formula based on tensorflow definition - @anjali411
@@ -5019,6 +5022,11 @@ def fn(*inputs):
  for skip in skipTestIf:
      do_test = skip(do_test)

+ # TODO: remove this once tests from skip_cuda_list work
+ do_test = skipCUDAIf(
+     any(skip_test in test_name for skip_test in skip_cuda_list),
+     "not implemented for CUDA yet")(do_test)
+
  setattr(TestAutogradDeviceType, test_name, do_test)

  class TestAutogradComplex(TestCase):

tools/autograd/derivatives.yaml

Lines changed: 6 additions & 6 deletions
@@ -275,8 +275,8 @@
    self: zeros_like(grad)

  - name: bmm(Tensor self, Tensor mat2) -> Tensor
-     self: grad.bmm(mat2.transpose(1, 2))
-     mat2: self.transpose(1, 2).bmm(grad)
+     self: grad.bmm(mat2.transpose(1, 2).conj())
+     mat2: self.transpose(1, 2).conj().bmm(grad)

  - name: _bmm(Tensor self, Tensor mat2, *, bool deterministic=False) -> Tensor
    self: at::_bmm(grad, mat2.transpose(1, 2), deterministic)
@@ -498,8 +498,8 @@
    self: not_implemented("geqrf")

  - name: ger(Tensor self, Tensor vec2) -> Tensor
-     self: grad.mv(vec2)
-     vec2: grad.t().mv(self)
+     self: grad.mv(vec2.conj())
+     vec2: grad.t().mv(self.conj())

  - name: indices(Tensor(a) self) -> Tensor(a)
    output_differentiability: [False]
@@ -749,8 +749,8 @@
    self: mul_tensor_backward(grad, at::scalar_to_tensor(other), self.scalar_type())

  - name: mv(Tensor self, Tensor vec) -> Tensor
-     self: grad.ger(vec)
-     vec: self.t().mv(grad)
+     self: grad.ger(vec.conj())
+     vec: self.conj().t().mv(grad)

  - name: mvlgamma(Tensor self, int p) -> Tensor
    self: mvlgamma_backward(grad, self, p)
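For reference, a short derivation (a sketch of mine, not part of the patch) of why these formulas pick up a `conj()`: under PyTorch's convention that the gradient stored for a complex tensor $z$ and a real-valued loss $L$ is the conjugate Wirtinger derivative $\partial L/\partial z^*$, write $G_X = \partial L/\partial X^*$. For $C = AB$:

```latex
\[
dL = 2\,\mathrm{Re}\,\operatorname{tr}\!\big(G_C^{H}\,dC\big)
   = 2\,\mathrm{Re}\,\operatorname{tr}\!\big(B\,G_C^{H}\,dA\big)
   + 2\,\mathrm{Re}\,\operatorname{tr}\!\big(G_C^{H}\,A\,dB\big)
\quad\Longrightarrow\quad
G_A = G_C\,B^{H}, \qquad G_B = A^{H}\,G_C .
\]
```

This is exactly `grad.bmm(mat2.transpose(1, 2).conj())` / `self.transpose(1, 2).conj().bmm(grad)` above; the `mv` and `ger` entries are the matrix-vector and outer-product special cases of the same identity.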

tools/autograd/gen_variable_type.py

Lines changed: 2 additions & 1 deletion
@@ -164,7 +164,8 @@
  'cosh', '__rmul__', 'sgn', 'asin', 'acos', 'sub', 'div', 'cat', 'view_as_complex',
  'neg', 'complex', 'select', '_s_where', 'as_strided', 'slice', 'constant_pad_nd',
  'unbind', 'split', 'split_with_sizes', 'unsafe_split', 'split_with_sizes_backward',
- 'dot', 'vdot', 'cholesky'
+ 'dot', 'vdot', 'cholesky', 'triangular_solve', 'mm', '_unsafe_view', 'mv', 'ger',
+ 'bmm', 'diagonal'
  }

  # Some operators invalidate the grad_accumulator. Let's reset it.

torch/csrc/autograd/FunctionsManual.cpp

Lines changed: 6 additions & 6 deletions
@@ -530,9 +530,9 @@ Tensor mm_mat1_backward(const Tensor & grad, const Tensor & mat2, const Tensor &
  at::IntArrayRef sizes = mat1.sizes();
  at::IntArrayRef strides = mat1.strides();
  if (strides[0] == 1 && strides[1] == sizes[0]) {
-     return maybe_multiply(mat2.mm(grad.t()).t(), alpha);
+     return maybe_multiply(mat2.conj().mm(grad.t()).t(), alpha);
  } else {
-     return maybe_multiply(grad.mm(mat2.t()), alpha);
+     return maybe_multiply(grad.mm(mat2.t().conj()), alpha);
  }
  }

@@ -550,9 +550,9 @@ Tensor mm_mat2_backward(const Tensor & grad, const Tensor & mat1, IntArrayRef si
    at::addmm_out(r, t, mat1.t(), grad, alpha, 1);
    return r;
  }
-   return maybe_multiply(grad.t().mm(mat1).t(), alpha);
+   return maybe_multiply(grad.t().mm(mat1.conj()).t(), alpha);
  } else {
-   return maybe_multiply(mat1.t().mm(grad), alpha);
+   return maybe_multiply(mat1.t().conj().mm(grad), alpha);
  }
  }

@@ -2123,9 +2123,9 @@ std::tuple<Tensor, Tensor> triangular_solve_backward(
  Tensor grad_b, grad_a;
  if (grad_x.defined() || grad_m.defined()) {
    if (grad_x.defined()) {
-     grad_b = std::get<0>(grad_x.triangular_solve(a, upper, !transpose, unitriangular));
+     grad_b = std::get<0>(grad_x.triangular_solve(a.conj(), upper, !transpose, unitriangular));
      if (output_mask[1]) {
-       grad_a = transpose ? -x.matmul(grad_b.transpose(-1, -2)) : -grad_b.matmul(x.transpose(-1, -2));
+       grad_a = transpose ? -x.conj().matmul(grad_b.transpose(-1, -2)) : -grad_b.matmul(x.transpose(-1, -2).conj());
        if (upper) {
          grad_a = grad_a.triu((int) unitriangular);
        } else {

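As a quick sanity check of the conjugated `mm` backward above (again a sketch, not code from this PR), the closed-form gradients can be compared against autograd on complex inputs:

```python
import torch

torch.manual_seed(0)
A = torch.rand(4, 3, dtype=torch.cdouble, requires_grad=True)
B = torch.rand(3, 5, dtype=torch.cdouble, requires_grad=True)
grad_out = torch.rand(4, 5, dtype=torch.cdouble)

A.mm(B).backward(grad_out)

# Matches mm_mat1_backward / mm_mat2_backward with the added conj():
assert torch.allclose(A.grad, grad_out.mm(B.t().conj()))
assert torch.allclose(B.grad, A.t().conj().mm(grad_out))
```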