Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -1221,6 +1221,34 @@ def test_backward_no_grad(self):
with self.assertRaises(RuntimeError):
torch.autograd.backward([b], [None])

def test_backward_twice_with_saved_values(self):
b = torch.randn(3, requires_grad=True, dtype=torch.double)
c = torch.zeros(3, dtype=torch.double)
c[[1, 2]] = b[[1, 1]]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not completely obvious that this indexing expression causes a TensorList to be saved for backwards and that is what we're fixing in this PR (that backward twice with a saved TensorList works as expected). I don't see any other functions in derivatives.yaml that saves a TensorList, though, so I don't have ideas on how to make this test better.

c.backward(torch.tensor([1, 1, 1], dtype=torch.double))
self.assertRaisesRegex(RuntimeError, 'Specify retain_graph=True',
lambda: c.backward(torch.tensor([1, 1, 1], dtype=torch.double)))

def test_backward_twice_retained_graph_with_saved_values(self):
b = torch.randn(3, requires_grad=True, dtype=torch.double)
c = torch.zeros(3, dtype=torch.double)
c[[1, 2]] = b[[1, 1]]
c.backward(torch.tensor([1, 1, 1], dtype=torch.double), retain_graph=True)
c.backward(torch.tensor([1, 1, 1], dtype=torch.double))

def test_backward_twice_without_saved_values(self):
b = torch.randn(3, requires_grad=True, dtype=torch.double)
c = b + 1
c.backward(torch.tensor([1, 1, 1], dtype=torch.double))
c.backward(torch.tensor([1, 1, 1], dtype=torch.double))

def test_backward_twice_retained_graph_without_saved_values(self):
b = torch.randn(3, requires_grad=True, dtype=torch.double)
c = torch.zeros(3, dtype=torch.double)
c[[1, 2]] = b[[1, 1]]
c.backward(torch.tensor([1, 1, 1], dtype=torch.double), retain_graph=True)
c.backward(torch.tensor([1, 1, 1], dtype=torch.double))

def test_next_functions(self):
x = torch.randn(5, 5, requires_grad=True)
y = torch.randn(5, 5, requires_grad=True)
Expand Down
8 changes: 8 additions & 0 deletions tools/autograd/gen_autograd_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

FUNCTION_DEFINITION = CodeTemplate("""\
variable_list ${op}::apply(variable_list&& grads) {
${asserts}
IndexRangeGenerator gen;
${compute_index_ranges}
variable_list grad_inputs(gen.size());
Expand Down Expand Up @@ -126,6 +127,7 @@ def process_function(func):
release_variables = []
saved_list_sizes = []
unpack = []
asserts = []

env['compute_index_ranges'] = []
for arg in func['args_with_derivatives']:
Expand All @@ -146,8 +148,13 @@ def save_arg(arg, is_output):
unpack.append('auto {} = {}_.unpack({});'.format(name, name, ptr))
elif arg['type'] == 'TensorList':
saved_variables.append('std::vector<SavedVariable> {}_;'.format(name))
saved_variables.append('bool {}_released_ = false;'.format(name))
# Just clear() is sufficient, we don't need to loop and clear each variable.
# Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well.
release_variables.append('{}_.clear();'.format(name))
release_variables.append('{}_released_ = true;'.format(name))
unpack.append('auto {} = unpack_list({}_);'.format(name, name))
asserts.append('TORCH_CHECK(!{}_released_, ERR_BACKWARD_TWICE);'.format(name))
elif arg['type'] == 'IntArrayRef':
saved_variables.append('std::vector<int64_t> {};'.format(name))
elif arg['type'] == 'int64_t':
Expand All @@ -162,6 +169,7 @@ def save_arg(arg, is_output):
env['saved_variables'] = saved_variables
env['release_variables'] = release_variables
env['saved_list_sizes'] = saved_list_sizes
env['asserts'] = asserts

if uses_retain_variables(func):
env['will_release_variables'] = WILL_RELEASE_VARIABLES.substitute()
Expand Down