Skip to content

Commit 865aebf

Browse files
committed
Complex autograd logic
ghstack-source-id: 2b25ff1 Pull Request resolved: #43208
1 parent 20607f2 commit 865aebf

File tree

3 files changed

+58
-46
lines changed

3 files changed

+58
-46
lines changed

test/test_autograd.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4579,18 +4579,19 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
45794579
# the tests for these ops which do not have 'complex' in variant should not run for complex
45804580
# and only run for floating point
45814581

4582-
separate_complex_tests = ['log', 'log10', 'log1p', 'log2', 'reciprocal', 'tan']
4582+
separate_complex_tests = [] #['log', 'log10', 'log1p', 'log2', 'reciprocal', 'tan']
45834583

45844584
# NOTE: Some non-holomorphic are separately tested in TestAutogradComplex until gradcheck works properly
45854585
# for non-holomorphic functions
45864586

45874587
# allow list for complex
4588-
complex_list = ['t', 'view', 'reshape', 'reshape_as', 'view_as', 'zero_', 'clone',
4589-
'tril', 'triu', 'fill_', 'eq_', 'ne_', 'permute', 'squeeze', 'unsqueeze',
4590-
'chunk', 'split', 'split_with_sizes', 'resize', 'resize_as', 'sin', 'cos',
4591-
'__rmul__', '__rdiv__', 'sum', 'transpose', 'round', 'add', 'roll',
4592-
'__radd__', 'repeat', 'expand', 'mul', 'tanh', 'flip', 'fliplr', 'flipud',
4593-
'rot90'] + separate_complex_tests
4588+
complex_list = []
4589+
# complex_list = ['t', 'view', 'reshape', 'reshape_as', 'view_as', 'zero_', 'clone',
4590+
# 'tril', 'triu', 'fill_', 'eq_', 'ne_', 'permute', 'squeeze', 'unsqueeze',
4591+
# 'chunk', 'split', 'split_with_sizes', 'resize', 'resize_as', 'sin', 'cos',
4592+
# '__rmul__', '__rdiv__', 'sum', 'transpose', 'round', 'add', 'roll',
4593+
# '__radd__', 'repeat', 'expand', 'mul', 'tanh', 'flip', 'fliplr', 'flipud',
4594+
# 'rot90'] + separate_complex_tests
45944595

45954596
def add_test(
45964597
name,

torch/autograd/gradcheck.py

Lines changed: 48 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -53,38 +53,48 @@ def get_numerical_jacobian(fn, input, target=None, eps=1e-3):
5353
if target is None:
5454
target = input
5555
output_size = fn(input).numel()
56-
jacobian = make_jacobian(target, output_size)
57-
56+
w_jacobian = make_jacobian(target, output_size)
57+
conj_w_jacobian = make_jacobian(target, output_size)
5858
# It's much easier to iterate over flattened lists of tensors.
5959
# These are reference to the same objects in jacobian, so any changes
6060
# will be reflected in it as well.
6161
x_tensors = iter_tensors(target, True)
62-
j_tensors = iter_tensors(jacobian)
63-
64-
def compute_gradient(x, idx, is_mkldnn=False):
65-
66-
def fn_out():
67-
if not is_mkldnn:
68-
# x is a view into input and so this works
69-
return fn(input).clone()
70-
else:
71-
# convert the dense tensor back to have mkldnn layout
72-
return fn([x.to_mkldnn()])
73-
74-
orig = x[idx].item()
75-
x[idx] = orig - eps
76-
outa = fn_out()
77-
x[idx] = orig + eps
78-
outb = fn_out()
79-
x[idx] = orig
80-
r = (outb - outa) / (2 * eps)
81-
return r.detach().reshape(-1)
62+
w_j_tensors = iter_tensors(w_jacobian)
63+
conj_w_j_tensors = iter_tensors(conj_w_jacobian)
64+
65+
def update_jacobians(x, idx, d, conj_d, d_idx, is_mkldnn=False):
66+
67+
def compute_gradient(delta=eps):
68+
def fn_out():
69+
if not is_mkldnn:
70+
# x is a view into input and so this works
71+
return fn(input).clone()
72+
else:
73+
# convert the dense tensor back to have mkldnn layout
74+
return fn([x.to_mkldnn()])
75+
76+
orig = x[idx].item()
77+
x[idx] = orig - delta
78+
outa = fn_out()
79+
x[idx] = orig + delta
80+
outb = fn_out()
81+
x[idx] = orig
82+
r = (outb - outa) / (2 * eps)
83+
return r.detach().reshape(-1)
84+
85+
ds_dx = compute_gradient(delta=eps)
86+
if x.is_complex():
87+
ds_dy = compute_gradient(delta=(eps * 1j))
88+
d[d_idx] = 0.5 * (ds_dx - ds_dy * 1j)
89+
conj_d[d_idx] = 0.5 * (ds_dx + ds_dy * 1j)
90+
else:
91+
d[d_idx] = 0.5 * ds_dx
92+
conj_d[d_idx] = 0.5 * ds_dx
8293

8394
# TODO: compare structure
84-
for x_tensor, d_tensor in zip(x_tensors, j_tensors):
95+
for x_tensor, d_tensor, conj_d_tensor in zip(x_tensors, w_j_tensors, conj_w_j_tensors):
8596
is_complex = x_tensor.dtype.is_complex
86-
if is_complex:
87-
eps *= (1 + 1j)
97+
8898
if x_tensor.is_sparse:
8999
def get_stride(size):
90100
dim = len(size)
@@ -109,7 +119,7 @@ def get_stride(size):
109119
for x_idx in product(*[range(m) for m in x_values.size()[1:]]):
110120
indices = x_indices[i].tolist() + list(x_idx)
111121
d_idx = sum(indices[k] * x_stride[k] for k in range(len(x_size)))
112-
d_tensor[d_idx] = compute_gradient(x_value, x_idx)
122+
update_jacobians(x_value, x_idx, d_tensor, conj_d_tensor, d_idx)
113123
elif x_tensor.layout == torch._mkldnn:
114124
# Use .data here to get around the version check
115125
x_tensor = x_tensor.data
@@ -120,14 +130,14 @@ def get_stride(size):
120130
# this is really inefficient, but without indexing implemented, there's
121131
# not really a better way than converting back and forth
122132
x_tensor_dense = x_tensor.to_dense()
123-
d_tensor[d_idx] = compute_gradient(x_tensor_dense, x_idx, is_mkldnn=True)
133+
update_jacobians(x_tensor_dense, x_idx, d_tensor, conj_d_tensor, d_idx, is_mkldnn=True)
124134
else:
125135
# Use .data here to get around the version check
126136
x_tensor = x_tensor.data
127137
for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
128-
d_tensor[d_idx] = compute_gradient(x_tensor, x_idx)
138+
update_jacobians(x_tensor, x_idx, d_tensor, conj_d_tensor, d_idx)
129139

130-
return jacobian
140+
return w_jacobian, conj_w_jacobian
131141

132142

133143
def get_analytical_jacobian(input, output, nondet_tol=0.0):
@@ -286,8 +296,9 @@ def fail_test(msg):
286296
for i, o in enumerate(func_out):
287297
def fn(input):
288298
return _as_tuple(func(*input))[i]
289-
numerical = get_numerical_jacobian(fn, tupled_inputs, eps=eps)
290-
for n in numerical:
299+
numerical_w, numerical_conj_w = get_numerical_jacobian(fn, tupled_inputs, eps=eps)[0]
300+
# TODO: update this to also include check for numerical_conj_w
301+
for n in numerical_w:
291302
if torch.ne(n, 0).sum() > 0:
292303
return fail_test('Numerical gradient for function expected to be zero')
293304
return True
@@ -300,17 +311,17 @@ def fn(input):
300311
return _as_tuple(func(*input))[i]
301312

302313
analytical, reentrant, correct_grad_sizes = get_analytical_jacobian(tupled_inputs, o, nondet_tol=nondet_tol)
303-
numerical = get_numerical_jacobian(fn, tupled_inputs, eps=eps)
314+
numerical_w, numerical_conj_w = get_numerical_jacobian(fn, tupled_inputs, eps=eps)
304315

305316
if not correct_grad_sizes:
306317
return fail_test('Analytical gradient has incorrect size')
307318

308-
for j, (a, n) in enumerate(zip(analytical, numerical)):
309-
if a.numel() != 0 or n.numel() != 0:
310-
if not torch.allclose(a, n, rtol, atol):
319+
for j, (a, n_w, n_conj_w, inp) in enumerate(zip(analytical, numerical_w, numerical_conj_w, tupled_inputs)):
320+
if a.numel() != 0 or n_w.numel() != 0:
321+
dL_dz_conj = n_conj_w + n_w.conj()
322+
if not torch.allclose(a, dL_dz_conj, rtol, atol):
311323
return fail_test('Jacobian mismatch for output %d with respect to input %d,\n'
312-
'numerical:%s\nanalytical:%s\n' % (i, j, n, a))
313-
324+
'numerical:%s\nanalytical:%s\n' % (i, j, dL_dz_conj, a))
314325
if not reentrant:
315326
return fail_test('Backward is not reentrant, i.e., running backward with same '
316327
'input and grad_output multiple times gives different values, '

torch/testing/_internal/common_nn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4481,10 +4481,10 @@ def fw(input):
44814481

44824482
res = tuple()
44834483
if jacobian_input:
4484-
res += get_numerical_jacobian(fw, input, eps=1e-6),
4484+
res += get_numerical_jacobian(fw, input, eps=1e-6)[0],
44854485
if jacobian_parameters:
44864486
param, _ = self._get_parameters(module)
4487-
res += torch.cat([get_numerical_jacobian(fw, input, p, eps=1e-6) for p in param], 0),
4487+
res += torch.cat([get_numerical_jacobian(fw, input, p, eps=1e-6)[0] for p in param], 0),
44884488
return res
44894489

44904490
def check_jacobian(self, module, input, jacobian_input=True):

0 commit comments

Comments
 (0)