@@ -53,38 +53,48 @@ def get_numerical_jacobian(fn, input, target=None, eps=1e-3):
     if target is None:
         target = input
     output_size = fn(input).numel()
-    jacobian = make_jacobian(target, output_size)
-
+    w_jacobian = make_jacobian(target, output_size)
+    conj_w_jacobian = make_jacobian(target, output_size)
     # It's much easier to iterate over flattened lists of tensors.
     # These are reference to the same objects in jacobian, so any changes
     # will be reflected in it as well.
     x_tensors = iter_tensors(target, True)
-    j_tensors = iter_tensors(jacobian)
-
-    def compute_gradient(x, idx, is_mkldnn=False):
-
-        def fn_out():
-            if not is_mkldnn:
-                # x is a view into input and so this works
-                return fn(input).clone()
-            else:
-                # convert the dense tensor back to have mkldnn layout
-                return fn([x.to_mkldnn()])
-
-        orig = x[idx].item()
-        x[idx] = orig - eps
-        outa = fn_out()
-        x[idx] = orig + eps
-        outb = fn_out()
-        x[idx] = orig
-        r = (outb - outa) / (2 * eps)
-        return r.detach().reshape(-1)
+    w_j_tensors = iter_tensors(w_jacobian)
+    conj_w_j_tensors = iter_tensors(conj_w_jacobian)
+
+    def update_jacobians(x, idx, d, conj_d, d_idx, is_mkldnn=False):
+
+        def compute_gradient(delta=eps):
+            def fn_out():
+                if not is_mkldnn:
+                    # x is a view into input and so this works
+                    return fn(input).clone()
+                else:
+                    # convert the dense tensor back to have mkldnn layout
+                    return fn([x.to_mkldnn()])
+
+            orig = x[idx].item()
+            x[idx] = orig - delta
+            outa = fn_out()
+            x[idx] = orig + delta
+            outb = fn_out()
+            x[idx] = orig
+            r = (outb - outa) / (2 * eps)
+            return r.detach().reshape(-1)
+
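+        # ds_dx / ds_dy below are central-difference estimates of the partial
+        # derivatives of the output s w.r.t. the real and imaginary parts of
+        # x[idx]; compute_gradient always divides by the real step size eps,
+        # so a purely imaginary delta of eps * 1j yields ds_dy directly.
+        # They combine into the Wirtinger derivatives
+        #   ds/dz  = 0.5 * (ds/dx - 1j * ds/dy)
+        #   ds/dz* = 0.5 * (ds/dx + 1j * ds/dy)
+        # For real inputs both jacobians receive half of the real gradient.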
+        ds_dx = compute_gradient(delta=eps)
+        if x.is_complex():
+            ds_dy = compute_gradient(delta=(eps * 1j))
+            d[d_idx] = 0.5 * (ds_dx - ds_dy * 1j)
+            conj_d[d_idx] = 0.5 * (ds_dx + ds_dy * 1j)
+        else:
+            d[d_idx] = 0.5 * ds_dx
+            conj_d[d_idx] = 0.5 * ds_dx
 
     # TODO: compare structure
-    for x_tensor, d_tensor in zip(x_tensors, j_tensors):
+    for x_tensor, d_tensor, conj_d_tensor in zip(x_tensors, w_j_tensors, conj_w_j_tensors):
         is_complex = x_tensor.dtype.is_complex
-        if is_complex:
-            eps *= (1 + 1j)
+
         if x_tensor.is_sparse:
             def get_stride(size):
                 dim = len(size)
@@ -109,7 +119,7 @@ def get_stride(size):
                 for x_idx in product(*[range(m) for m in x_values.size()[1:]]):
                     indices = x_indices[i].tolist() + list(x_idx)
                     d_idx = sum(indices[k] * x_stride[k] for k in range(len(x_size)))
-                    d_tensor[d_idx] = compute_gradient(x_value, x_idx)
+                    update_jacobians(x_value, x_idx, d_tensor, conj_d_tensor, d_idx)
         elif x_tensor.layout == torch._mkldnn:
             # Use .data here to get around the version check
             x_tensor = x_tensor.data
@@ -120,14 +130,14 @@ def get_stride(size):
                 # this is really inefficient, but without indexing implemented, there's
                 # not really a better way than converting back and forth
                 x_tensor_dense = x_tensor.to_dense()
-                d_tensor[d_idx] = compute_gradient(x_tensor_dense, x_idx, is_mkldnn=True)
+                update_jacobians(x_tensor_dense, x_idx, d_tensor, conj_d_tensor, d_idx, is_mkldnn=True)
         else:
             # Use .data here to get around the version check
             x_tensor = x_tensor.data
             for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
-                d_tensor[d_idx] = compute_gradient(x_tensor, x_idx)
+                update_jacobians(x_tensor, x_idx, d_tensor, conj_d_tensor, d_idx)
 
-    return jacobian
+    return w_jacobian, conj_w_jacobian
 
 
 def get_analytical_jacobian(input, output, nondet_tol=0.0):
@@ -286,8 +296,9 @@ def fail_test(msg):
         for i, o in enumerate(func_out):
             def fn(input):
                 return _as_tuple(func(*input))[i]
-            numerical = get_numerical_jacobian(fn, tupled_inputs, eps=eps)
-            for n in numerical:
+            numerical_w, numerical_conj_w = get_numerical_jacobian(fn, tupled_inputs, eps=eps)
+            # TODO: update this to also include a check for numerical_conj_w
+            for n in numerical_w:
                 if torch.ne(n, 0).sum() > 0:
                     return fail_test('Numerical gradient for function expected to be zero')
         return True
@@ -300,17 +311,17 @@ def fn(input):
             return _as_tuple(func(*input))[i]
 
         analytical, reentrant, correct_grad_sizes = get_analytical_jacobian(tupled_inputs, o, nondet_tol=nondet_tol)
-        numerical = get_numerical_jacobian(fn, tupled_inputs, eps=eps)
+        numerical_w, numerical_conj_w = get_numerical_jacobian(fn, tupled_inputs, eps=eps)
 
         if not correct_grad_sizes:
            return fail_test('Analytical gradient has incorrect size')
 
-        for j, (a, n) in enumerate(zip(analytical, numerical)):
-            if a.numel() != 0 or n.numel() != 0:
-                if not torch.allclose(a, n, rtol, atol):
+        for j, (a, n_w, n_conj_w, inp) in enumerate(zip(analytical, numerical_w, numerical_conj_w, tupled_inputs)):
+            if a.numel() != 0 or n_w.numel() != 0:
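+                # recombine the two one-sided numerical jacobians; for real inputs
+                # this recovers the full real gradient (0.5 * g + 0.5 * g)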
+                dL_dz_conj = n_conj_w + n_w.conj()
+                if not torch.allclose(a, dL_dz_conj, rtol, atol):
                     return fail_test('Jacobian mismatch for output %d with respect to input %d,\n'
-                                     'numerical:%s\nanalytical:%s\n' % (i, j, n, a))
-
+                                     'numerical:%s\nanalytical:%s\n' % (i, j, dL_dz_conj, a))
         if not reentrant:
             return fail_test('Backward is not reentrant, i.e., running backward with same '
                              'input and grad_output multiple times gives different values, '
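Aside, not part of the patch: the sketch below works through the same central-difference Wirtinger computation that the new update_jacobians performs, on the test function f(z) = conj(z), whose exact derivatives are df/dz = 0 and df/dz* = 1. The names f and central_diff are illustrative only, and it assumes a PyTorch build with complex tensor support; everything else mirrors the scheme in the patch (perturb by +/- delta, divide by the real step size 2 * eps).

import torch

eps = 1e-3
z = torch.tensor(1.0 + 2.0j)

def f(z):
    return z.conj()

def central_diff(delta):
    # same scheme as compute_gradient: perturb by +/- delta, divide by 2 * eps
    return (f(z + delta) - f(z - delta)) / (2 * eps)

ds_dx = central_diff(eps)       # partial derivative along the real axis, ~1
ds_dy = central_diff(eps * 1j)  # partial derivative along the imaginary axis, ~-1j

d = 0.5 * (ds_dx - ds_dy * 1j)       # df/dz,  expected ~0
conj_d = 0.5 * (ds_dx + ds_dy * 1j)  # df/dz*, expected ~1
print(d, conj_d)

Running this prints values close to 0 and 1, which is why the gradcheck loop above can compare the recombined numerical jacobians against the analytical one with torch.allclose.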