7 changes: 7 additions & 0 deletions test/test_autograd.py
@@ -2669,6 +2669,13 @@ def backward(ctx, grad):
         # check one of them is using the computed buffer
         self.assertTrue(p_a == p_g or p_b == p_g)
 
+    def test_gradcheck_single_input(self):
+        def f(inp):
+            return inp.mul(5)
+
+        gradcheck(f, torch.rand(10, dtype=torch.float64, requires_grad=True))
+        gradgradcheck(f, torch.rand(10, dtype=torch.float64, requires_grad=True))
+
 
 def index_variable(shape, max_indices):
     if not isinstance(shape, tuple):
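The new test exercises exactly what this PR enables: passing a bare Tensor where a tuple of Tensors was previously required. A minimal standalone sketch of the now-equivalent call styles (function and shapes chosen for illustration):

import torch
from torch.autograd import gradcheck, gradgradcheck

def f(inp):
    return inp.mul(5)

x = torch.rand(10, dtype=torch.float64, requires_grad=True)

assert gradcheck(f, (x,))   # tuple form, accepted before this change
assert gradcheck(f, x)      # bare Tensor, accepted after this change
assert gradgradcheck(f, x)  # same convenience for the second-order check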
29 changes: 16 additions & 13 deletions torch/autograd/gradcheck.py
@@ -147,7 +147,7 @@ def gradcheck(func, inputs, eps=1e-6, atol=1e-5, rtol=1e-3, raise_exception=True
     Args:
         func (function): a Python function that takes Tensor inputs and returns
             a Tensor or a tuple of Tensors
-        inputs (tuple of Tensor): inputs to the function
+        inputs (tuple of Tensor or Tensor): inputs to the function
         eps (float, optional): perturbation for finite differences
         atol (float, optional): absolute tolerance
         rtol (float, optional): relative tolerance
@@ -178,7 +178,7 @@ def gradcheck(func, inputs, eps=1e-6, atol=1e-5, rtol=1e-3, raise_exception=True
             'gradcheck expects at least one input tensor to require gradient, '
             'but none of the them have requires_grad=True.')
 
-    output = _differentiable_outputs(func(*inputs))
+    output = _differentiable_outputs(func(*tupled_inputs))
 
     def fail_test(msg):
         if raise_exception:
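The pattern in this hunk repeats through the rest of the file: every internal use of `inputs` switches to `tupled_inputs`, so the tuple-vs-Tensor distinction is resolved once at the top of the function. A sketch of what the private `_as_tuple` helper is assumed to do (the name matches the diff; this body is illustrative, not the library's source):

def _as_tuple(x):
    # Normalize a bare Tensor (or a list) into a tuple so downstream
    # code can iterate over inputs uniformly.
    if isinstance(x, tuple):
        return x
    if isinstance(x, list):
        return tuple(x)
    return (x,)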
@@ -193,7 +193,7 @@ def fn(input):
             return _as_tuple(func(*input))[i]
 
         analytical, reentrant, correct_grad_sizes = get_analytical_jacobian(tupled_inputs, o)
-        numerical = get_numerical_jacobian(fn, inputs, eps=eps)
+        numerical = get_numerical_jacobian(fn, tupled_inputs, eps=eps)
 
         if not correct_grad_sizes:
             return fail_test('Analytical gradient has incorrect size')
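`get_numerical_jacobian` estimates the Jacobian by perturbing each input element by `eps`. A toy central-difference version for a single 1-D double-precision input, just to show the role `eps` plays (the real helper supports arbitrary shapes and multiple inputs):

import torch

def toy_numerical_jacobian(fn, x, eps=1e-6):
    # d fn / d x at x, one perturbed element at a time.
    x = x.detach().clone()
    out_numel = fn(x).numel()
    jac = torch.zeros(x.numel(), out_numel, dtype=x.dtype)
    for i in range(x.numel()):
        orig = x[i].item()
        x[i] = orig + eps
        f_plus = fn(x).clone().view(-1)
        x[i] = orig - eps
        f_minus = fn(x).clone().view(-1)
        x[i] = orig
        jac[i] = (f_plus - f_minus) / (2 * eps)
    return jac

# toy_numerical_jacobian(lambda t: t.mul(5), torch.rand(3, dtype=torch.float64))
# approximates 5 * identity.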
@@ -210,9 +210,9 @@ def fn(input):
                              'although analytical gradient matches numerical gradient')
 
     # check if the backward multiplies by grad_output
-    output = _differentiable_outputs(func(*inputs))
+    output = _differentiable_outputs(func(*tupled_inputs))
     if any([o.requires_grad for o in output]):
-        diff_input_list = list(iter_tensors(inputs, True))
+        diff_input_list = list(iter_tensors(tupled_inputs, True))
         if not diff_input_list:
             raise RuntimeError("no Tensors requiring grad found in input")
         grads_input = torch.autograd.grad(output, diff_input_list, [torch.zeros_like(o) for o in output],
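Per the comment in this hunk, the code feeds all-zero grad_outputs through torch.autograd.grad: a backward that correctly scales by grad_output must then return zero gradients everywhere, and the assertion consuming grads_input sits just below the visible diff. The idea in isolation, with illustrative names:

import torch

x = torch.rand(4, dtype=torch.float64, requires_grad=True)
y = x * 5
(g,) = torch.autograd.grad(y, (x,), grad_outputs=torch.zeros_like(y))
# A backward that ignored grad_output (e.g. returned a constant) would
# break this invariant:
assert (g == 0).all()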
@@ -258,9 +258,9 @@ def gradgradcheck(func, inputs, grad_outputs=None, eps=1e-6, atol=1e-5, rtol=1e-
     Args:
         func (function): a Python function that takes Tensor inputs and returns
             a Tensor or a tuple of Tensors
-        inputs (tuple of Tensor): inputs to the function
-        grad_outputs (tuple of Tensor, optional): The gradients with respect to
-            the function's outputs.
+        inputs (tuple of Tensor or Tensor): inputs to the function
+        grad_outputs (tuple of Tensor or Tensor, optional): The gradients with
+            respect to the function's outputs.
         eps (float, optional): perturbation for finite differences
         atol (float, optional): absolute tolerance
         rtol (float, optional): relative tolerance
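With both docstring updates in place, gradgradcheck mirrors gradcheck: `inputs` and `grad_outputs` may each be a bare Tensor. A small usage sketch (the supplied grad_outputs are made to require grad, as the auto-generated ones are, so the check can also differentiate with respect to them):

import torch
from torch.autograd import gradgradcheck

def f(inp):
    return inp.mul(5)

x = torch.rand(6, dtype=torch.float64, requires_grad=True)
go = torch.rand(6, dtype=torch.float64, requires_grad=True)

assert gradgradcheck(f, x, grad_outputs=go)  # both bare Tensors now accepted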
@@ -274,6 +274,8 @@ def gradgradcheck(func, inputs, grad_outputs=None, eps=1e-6, atol=1e-5, rtol=1e-
     Returns:
         True if all differences satisfy allclose condition
     """
+    tupled_inputs = _as_tuple(inputs)
+
     if grad_outputs is None:
         # If grad_outputs is not specified, create random Tensors of the same
         # shape, type, and device as the outputs
@@ -282,11 +284,12 @@ def randn_like(x):
             if gen_non_contig_grad_outputs:
                 y = torch.testing.make_non_contiguous(y)
             return y.requires_grad_()
-        outputs = _as_tuple(func(*inputs))
-        grad_outputs_gen = (randn_like(x) for x in outputs)
-        grad_outputs = list(grad_outputs_gen) if not isinstance(inputs, tuple) else tuple(grad_outputs_gen)
+        outputs = _as_tuple(func(*tupled_inputs))
+        tupled_grad_outputs = tuple(randn_like(x) for x in outputs)
+    else:
+        tupled_grad_outputs = _as_tuple(grad_outputs)
 
-    num_outputs = len(grad_outputs)
+    num_outputs = len(tupled_grad_outputs)
 
     def new_func(*args):
         input_args = args[:-num_outputs]
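One detail worth noting in the branch above: each synthesized grad_output is made to require grad, so the downstream gradcheck also probes derivatives with respect to the grad_outputs themselves. A condensed, self-contained sketch of that generation step (hypothetical helper name, assuming plain torch.randn_like and skipping the optional non-contiguous transform):

import torch

def make_grad_outputs(outputs):
    # One random, differentiable grad_output per output tensor; the real
    # branch can additionally make these non-contiguous when
    # gen_non_contig_grad_outputs is set.
    return tuple(torch.randn_like(o).requires_grad_() for o in outputs)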
@@ -296,4 +299,4 @@ def new_func(*args):
         grad_inputs = torch.autograd.grad(outputs, input_args, grad_outputs, create_graph=True)
         return grad_inputs
 
-    return gradcheck(new_func, inputs + grad_outputs, eps, atol, rtol, raise_exception)
+    return gradcheck(new_func, tupled_inputs + tupled_grad_outputs, eps, atol, rtol, raise_exception)
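This last line is the reduction at the heart of gradgradcheck: a second-order check is just a first-order gradcheck of the derivative function, run over the concatenation of the original inputs and the grad_outputs. Spelled out for one input and one output with illustrative names:

import torch
from torch.autograd import gradcheck

def f(inp):
    return inp.mul(5)

x = torch.rand(3, dtype=torch.float64, requires_grad=True)
go = torch.rand(3, dtype=torch.float64, requires_grad=True)

def double_backward(inp, grad_out):
    out = f(inp)
    (grad_inp,) = torch.autograd.grad(out, (inp,), grad_out, create_graph=True)
    return grad_inp

# Equivalent in spirit to gradgradcheck(f, x, grad_outputs=go):
assert gradcheck(double_backward, (x, go))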