
Commit 3a08a16

Complex autograd logic
ghstack-source-id: 6a2e586
Pull Request resolved: #43208
1 parent 55443c4 commit 3a08a16
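
The commit message is terse; for context, the change is building toward numerical gradient checking of functions with complex-dtype inputs, where finite differences have to perturb along both the real and the imaginary axis. The snippet below is only an illustration of the kind of call this targets: the choice of torch.abs, the cdouble input, and the eps value are assumptions, not part of the commit, and the check only passes on builds with complex autograd support.

import torch
from torch.autograd import gradcheck

# Double precision is required by gradcheck; cdouble is its complex counterpart.
z = torch.randn(3, dtype=torch.cdouble, requires_grad=True)

# torch.abs maps a complex input to a real output, so its numerical Jacobian
# has both a d/dx (real-axis) and a d/dy (imaginary-axis) component.
print(gradcheck(torch.abs, (z,), eps=1e-6))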

1 file changed: +64 -25 lines changed

torch/autograd/gradcheck.py

Lines changed: 64 additions & 25 deletions
@@ -42,6 +42,15 @@ def iter_tensors(x, only_requiring_grad=False):
             for result in iter_tensors(elem, only_requiring_grad):
                 yield result

+def complex_iter_tensors(x, only_requiring_grad=False):
+    if isinstance(x, torch.Tensor):
+        if (x.requires_grad or not only_requiring_grad) and x.is_complex():
+            yield x
+    elif isinstance(x, container_abcs.Iterable) and not isinstance(x, str):
+        for elem in x:
+            for result in iter_tensors(elem, only_requiring_grad):
+                yield result
+
 def get_numerical_jacobian(fn, input, target=None, eps=1e-3):
     """
     input: input to `fn`
@@ -53,37 +62,44 @@ def get_numerical_jacobian(fn, input, target=None, eps=1e-3):
     if target is None:
         target = input
     output_size = fn(input).numel()
-    jacobian = make_jacobian(target, output_size)
-
+    real_jacobian = make_jacobian(target, output_size)
     # It's much easier to iterate over flattened lists of tensors.
     # These are reference to the same objects in jacobian, so any changes
     # will be reflected in it as well.
     x_tensors = iter_tensors(target, True)
-    j_tensors = iter_tensors(jacobian)
+    real_j_tensors = iter_tensors(real_jacobian)

-    def compute_gradient(x, idx, is_mkldnn=False):
+    imag_jacobian = make_jacobian(target, output_size)
+    imag_j_tensors = complex_iter_tensors(imag_jacobian)
+
+    contains_complex_input = imag_j_tensors is not None
+
+    def compute_gradient(x_tensor, x_idx, is_mkldnn=False, imag_delta=False):

         def fn_out(is_mkldnn=False):
             if not is_mkldnn:
                 return fn(input).clone()
             else:
-                return fn([x.to_mkldnn()])
+                return fn([x_tensor.to_mkldnn()])

-        orig = x[idx].item()
-        x[idx] = orig - eps
+        if imag_delta:
+            delta = eps * 1j
+        else:
+            delta = eps
+        orig = x_tensor[x_idx].item()
+        x_tensor[x_idx] = orig - delta
         outa = fn_out(is_mkldnn)
-        x[idx] = orig + eps
+        x_tensor[x_idx] = orig + delta
         outb = fn_out(is_mkldnn)
         if not is_mkldnn:
-            x[idx] = orig
+            x_tensor[x_idx] = orig
         r = (outb - outa) / (2 * eps)
         return r.detach().reshape(-1)

     # TODO: compare structure
-    for x_tensor, d_tensor in zip(x_tensors, j_tensors):
-        is_complex = x_tensor.dtype.is_complex
-        if is_complex:
-            eps *= (1 + 1j)
+    for x_tensor, real_d_tensor, imag_d_tensor in zip(x_tensors, real_j_tensors, imag_j_tensors):
+        update_imag_d_tensor = contains_complex_input and x_tensor.dtype.is_complex
+
         if x_tensor.is_sparse:
             def get_stride(size):
                 dim = len(size)
@@ -108,7 +124,9 @@ def get_stride(size):
                 for x_idx in product(*[range(m) for m in x_values.size()[1:]]):
                     indices = x_indices[i].tolist() + list(x_idx)
                     d_idx = sum(indices[k] * x_stride[k] for k in range(len(x_size)))
-                    d_tensor[d_idx] = compute_gradient(x_value, x_idx)
+                    real_d_tensor[d_idx] = compute_gradient(x_value, x_idx)
+                    if update_imag_d_tensor:
+                        imag_d_tensor[d_idx] = compute_gradient(x_value, x_idx, imag_delta=True)
         elif x_tensor.layout == torch._mkldnn:
             # Use .data here to get around the version check
             x_tensor = x_tensor.data
@@ -119,14 +137,21 @@ def get_stride(size):
                 # this is really inefficient, but without indexing implemented, there's
                 # not really a better way than converting back and forth
                 x_tensor_dense = x_tensor.to_dense()
-                d_tensor[d_idx] = compute_gradient(x_tensor_dense, x_idx, is_mkldnn=True)
+                real_d_tensor[d_idx] = compute_gradient(x_tensor_dense, x_idx, is_mkldnn=True)
+                if update_imag_d_tensor:
+                    imag_d_tensor[d_idx] = compute_gradient(x_tensor_dense, x_idx, imag_delta=True, is_mkldnn=True)
         else:
             # Use .data here to get around the version check
             x_tensor = x_tensor.data
             for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
-                d_tensor[d_idx] = compute_gradient(x_tensor, x_idx)
+                real_d_tensor[d_idx] = compute_gradient(x_tensor, x_idx)
+                if update_imag_d_tensor:
+                    imag_d_tensor[d_idx] = compute_gradient(x_tensor, x_idx, imag_delta=True)

-    return jacobian
+    if contains_complex_input:
+        return real_jacobian, imag_jacobian
+    else:
+        return real_jacobian, None


 def get_analytical_jacobian(input, output, nondet_tol=0.0):
@@ -285,7 +310,8 @@ def fail_test(msg):
         for i, o in enumerate(func_out):
             def fn(input):
                 return _as_tuple(func(*input))[i]
-            numerical = get_numerical_jacobian(fn, tupled_inputs, eps=eps)
+            # TODO: update this to also include check for ds_dy
+            numerical = get_numerical_jacobian(fn, tupled_inputs, eps=eps)[0]
             for n in numerical:
                 if torch.ne(n, 0).sum() > 0:
                     return fail_test('Numerical gradient for function expected to be zero')
@@ -299,17 +325,30 @@ def fn(input):
             return _as_tuple(func(*input))[i]

         analytical, reentrant, correct_grad_sizes = get_analytical_jacobian(tupled_inputs, o, nondet_tol=nondet_tol)
-        numerical = get_numerical_jacobian(fn, tupled_inputs, eps=eps)
+        numerical_real, numerical_imag = get_numerical_jacobian(fn, tupled_inputs, eps=eps)
+        print(numerical_real, numerical_imag)

         if not correct_grad_sizes:
             return fail_test('Analytical gradient has incorrect size')

-        for j, (a, n) in enumerate(zip(analytical, numerical)):
-            if a.numel() != 0 or n.numel() != 0:
-                if not torch.allclose(a, n, rtol, atol):
-                    return fail_test('Jacobian mismatch for output %d with respect to input %d,\n'
-                                     'numerical:%s\nanalytical:%s\n' % (i, j, n, a))
-
+        if numerical_imag is None:
+            for j, (a, n) in enumerate(zip(analytical, numerical_real)):
+                if a.numel() != 0 or n.numel() != 0:
+                    if not torch.allclose(a, n, rtol, atol):
+                        return fail_test('Jacobian mismatch for output %d with respect to input %d,\n'
+                                         'numerical:%s\nanalytical:%s\n' % (i, j, n, a))
+        else:
+            for j, (a, n_re, n_im, inp) in enumerate(zip(analytical, numerical_real, numerical_imag, tupled_inputs)):
+                if a.numel() != 0 or n_re.numel() != 0:
+                    if inp.is_complex():
+                        ds_dz = 0.5 * (n_re + 1j * n_im)
+                        ds_dz_conj = 0.5 * (n_re - 1j * n_im).conj()
+                        dL_dz_conj = ds_dz + ds_dz_conj
+                    else:
+                        dL_dz_conj = n_re
+                    if not torch.allclose(a, dL_dz_conj, rtol, atol):
+                        return fail_test('Jacobian mismatch for output %d with respect to input %d,\n'
+                                         'numerical:%s\nanalytical:%s\n' % (i, j, dL_dz_conj, a))
         if not reentrant:
             return fail_test('Backward is not reentrant, i.e., running backward with same '
                              'input and grad_output multiple times gives different values, '
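
For readers following the math in the last hunk, here is a minimal standalone sketch (not part of the commit; the helper name numerical_wirtinger_grad and the test function are made up) of the scheme the patched comparison loop implements: central differences of a real-valued function along the real and imaginary axes of a complex input, combined into the quantity that is compared against the analytical Jacobian.

def numerical_wirtinger_grad(f, z0, eps=1e-6):
    # ds/dx: central difference along the real axis (roughly what an entry of numerical_real holds)
    n_re = (f(z0 + eps) - f(z0 - eps)) / (2 * eps)
    # ds/dy: central difference along the imaginary axis (roughly what an entry of numerical_imag holds)
    n_im = (f(z0 + eps * 1j) - f(z0 - eps * 1j)) / (2 * eps)
    # Combined the same way the new comparison loop in gradcheck does it.
    ds_dz = 0.5 * (n_re + 1j * n_im)
    ds_dz_conj = 0.5 * (n_re - 1j * n_im).conjugate()
    return ds_dz + ds_dz_conj

# For f(z) = |z|^2 at z0 = 1 + 2j this evaluates to roughly 2 + 4j, i.e. 2 * z0.
print(numerical_wirtinger_grad(lambda z: (z * z.conjugate()).real, 1.0 + 2.0j))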
