@@ -42,6 +42,15 @@ def iter_tensors(x, only_requiring_grad=False):
             for result in iter_tensors(elem, only_requiring_grad):
                 yield result

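+# Like iter_tensors, but yields only the complex tensors found in x; used in
+# get_numerical_jacobian below to detect whether any input is complex.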
+def complex_iter_tensors(x, only_requiring_grad=False):
+    if isinstance(x, torch.Tensor):
+        if (x.requires_grad or not only_requiring_grad) and x.is_complex():
+            yield x
+    elif isinstance(x, container_abcs.Iterable) and not isinstance(x, str):
+        for elem in x:
+            # Recurse with complex_iter_tensors (not iter_tensors), so nested
+            # containers also yield only complex tensors.
+            for result in complex_iter_tensors(elem, only_requiring_grad):
+                yield result
+
 def get_numerical_jacobian(fn, input, target=None, eps=1e-3):
     """
     input: input to `fn`
@@ -53,37 +62,44 @@ def get_numerical_jacobian(fn, input, target=None, eps=1e-3):
     if target is None:
         target = input
     output_size = fn(input).numel()
-    jacobian = make_jacobian(target, output_size)
-
+    real_jacobian = make_jacobian(target, output_size)
     # It's much easier to iterate over flattened lists of tensors.
     # These are references to the same objects in the jacobian, so any changes
     # will be reflected in it as well.
     x_tensors = iter_tensors(target, True)
-    j_tensors = iter_tensors(jacobian)
+    real_j_tensors = iter_tensors(real_jacobian)

-    def compute_gradient(x, idx, is_mkldnn=False):
+    imag_jacobian = make_jacobian(target, output_size)
+    # Iterate with plain iter_tensors so the zip below stays aligned with
+    # x_tensors and real_j_tensors even when inputs mix real and complex.
+    imag_j_tensors = iter_tensors(imag_jacobian)
+
+    # A generator object is always truthy, so probe complex_iter_tensors for
+    # a first element instead of comparing the generator to None.
+    contains_complex_input = next(complex_iter_tensors(target), None) is not None
+
+    def compute_gradient(x_tensor, x_idx, is_mkldnn=False, imag_delta=False):

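+        # For MKL-DNN inputs, compute_gradient receives a dense copy; fn_out
+        # converts the perturbed copy back to MKL-DNN before calling fn, since
+        # MKL-DNN tensors cannot be indexed directly (see the branch below).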
         def fn_out(is_mkldnn=False):
             if not is_mkldnn:
                 return fn(input).clone()
             else:
-                return fn([x.to_mkldnn()])
+                return fn([x_tensor.to_mkldnn()])

-        orig = x[idx].item()
-        x[idx] = orig - eps
+        # Perturb along the imaginary axis when imag_delta is set, which
+        # yields the derivative with respect to the imaginary part of x.
+        if imag_delta:
+            delta = eps * 1j
+        else:
+            delta = eps
+        orig = x_tensor[x_idx].item()
+        x_tensor[x_idx] = orig - delta
         outa = fn_out(is_mkldnn)
-        x[idx] = orig + eps
+        x_tensor[x_idx] = orig + delta
         outb = fn_out(is_mkldnn)
         if not is_mkldnn:
-            x[idx] = orig
+            x_tensor[x_idx] = orig
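+        # |delta| equals eps for both the real and the imaginary perturbation,
+        # so the central difference below is divided by 2 * eps in either case.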
         r = (outb - outa) / (2 * eps)
         return r.detach().reshape(-1)

     # TODO: compare structure
-    for x_tensor, d_tensor in zip(x_tensors, j_tensors):
-        is_complex = x_tensor.dtype.is_complex
-        if is_complex:
-            eps *= (1 + 1j)
+    for x_tensor, real_d_tensor, imag_d_tensor in zip(x_tensors, real_j_tensors, imag_j_tensors):
+        update_imag_d_tensor = contains_complex_input and x_tensor.dtype.is_complex
+
         if x_tensor.is_sparse:
             def get_stride(size):
                 dim = len(size)
@@ -108,7 +124,9 @@ def get_stride(size):
                 for x_idx in product(*[range(m) for m in x_values.size()[1:]]):
                     indices = x_indices[i].tolist() + list(x_idx)
                     d_idx = sum(indices[k] * x_stride[k] for k in range(len(x_size)))
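+                    # d_idx is the row-major flattened position of this sparse
+                    # entry in the dense input; it selects the Jacobian row to fill.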
-                    d_tensor[d_idx] = compute_gradient(x_value, x_idx)
+                    real_d_tensor[d_idx] = compute_gradient(x_value, x_idx)
+                    if update_imag_d_tensor:
+                        imag_d_tensor[d_idx] = compute_gradient(x_value, x_idx, imag_delta=True)
         elif x_tensor.layout == torch._mkldnn:
             # Use .data here to get around the version check
             x_tensor = x_tensor.data
@@ -119,14 +137,21 @@ def get_stride(size):
                 # this is really inefficient, but without indexing implemented, there's
                 # not really a better way than converting back and forth
                 x_tensor_dense = x_tensor.to_dense()
-                d_tensor[d_idx] = compute_gradient(x_tensor_dense, x_idx, is_mkldnn=True)
+                real_d_tensor[d_idx] = compute_gradient(x_tensor_dense, x_idx, is_mkldnn=True)
+                if update_imag_d_tensor:
+                    imag_d_tensor[d_idx] = compute_gradient(x_tensor_dense, x_idx, imag_delta=True, is_mkldnn=True)
         else:
             # Use .data here to get around the version check
             x_tensor = x_tensor.data
             for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
-                d_tensor[d_idx] = compute_gradient(x_tensor, x_idx)
+                real_d_tensor[d_idx] = compute_gradient(x_tensor, x_idx)
+                if update_imag_d_tensor:
+                    imag_d_tensor[d_idx] = compute_gradient(x_tensor, x_idx, imag_delta=True)

-    return jacobian
+    # Return a pair so callers can tell purely real inputs (imaginary part is
+    # None) apart from complex ones.
+    if contains_complex_input:
+        return real_jacobian, imag_jacobian
+    else:
+        return real_jacobian, None


 def get_analytical_jacobian(input, output, nondet_tol=0.0):
@@ -285,7 +310,8 @@ def fail_test(msg):
     for i, o in enumerate(func_out):
         def fn(input):
             return _as_tuple(func(*input))[i]
-        numerical = get_numerical_jacobian(fn, tupled_inputs, eps=eps)
+        # TODO: update this to also include check for ds_dy
+        numerical = get_numerical_jacobian(fn, tupled_inputs, eps=eps)[0]
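+        # Only the real-part Jacobian is checked for the zero-gradient case;
+        # the imaginary part is dropped here (see the TODO above).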
         for n in numerical:
             if torch.ne(n, 0).sum() > 0:
                 return fail_test('Numerical gradient for function expected to be zero')
@@ -299,17 +325,30 @@ def fn(input):
             return _as_tuple(func(*input))[i]

         analytical, reentrant, correct_grad_sizes = get_analytical_jacobian(tupled_inputs, o, nondet_tol=nondet_tol)
-        numerical = get_numerical_jacobian(fn, tupled_inputs, eps=eps)
+        numerical_real, numerical_imag = get_numerical_jacobian(fn, tupled_inputs, eps=eps)

         if not correct_grad_sizes:
             return fail_test('Analytical gradient has incorrect size')

-        for j, (a, n) in enumerate(zip(analytical, numerical)):
-            if a.numel() != 0 or n.numel() != 0:
-                if not torch.allclose(a, n, rtol, atol):
-                    return fail_test('Jacobian mismatch for output %d with respect to input %d,\n'
-                                     'numerical:%s\nanalytical:%s\n' % (i, j, n, a))
-
+        if numerical_imag is None:
+            for j, (a, n) in enumerate(zip(analytical, numerical_real)):
+                if a.numel() != 0 or n.numel() != 0:
+                    if not torch.allclose(a, n, rtol, atol):
+                        return fail_test('Jacobian mismatch for output %d with respect to input %d,\n'
+                                         'numerical:%s\nanalytical:%s\n' % (i, j, n, a))
+        else:
+            for j, (a, n_re, n_im, inp) in enumerate(zip(analytical, numerical_real, numerical_imag, tupled_inputs)):
+                if a.numel() != 0 or n_re.numel() != 0:
+                    if inp.is_complex():
+                        # Wirtinger calculus: for z = x + iy and a real-valued
+                        # loss s,
+                        #   ds/dz      = 0.5 * (ds/dx - 1j * ds/dy)
+                        #   ds/dz_conj = 0.5 * (ds/dx + 1j * ds/dy)
+                        # and autograd reports (ds/dz).conj() + ds/dz_conj,
+                        # i.e. ds/dx + 1j * ds/dy. Example: for s = |z|^2 this
+                        # gives 2x + 2jy = 2z, matching autograd's gradient.
+                        ds_dz = 0.5 * (n_re - 1j * n_im)
+                        ds_dz_conj = 0.5 * (n_re + 1j * n_im)
+                        dL_dz_conj = ds_dz.conj() + ds_dz_conj
+                    else:
+                        dL_dz_conj = n_re
+                    if not torch.allclose(a, dL_dz_conj, rtol, atol):
+                        return fail_test('Jacobian mismatch for output %d with respect to input %d,\n'
+                                         'numerical:%s\nanalytical:%s\n' % (i, j, dL_dz_conj, a))
+
         if not reentrant:
             return fail_test('Backward is not reentrant, i.e., running backward with same '
                              'input and grad_output multiple times gives different values, '