Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 0 additions & 26 deletions test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,32 +372,6 @@ def bw_hook(grad):
self.assertEqual(counter[0], 1, 'bw_hook not called')
self.assertEqual(x.grad.data, torch.ones(5, 5) * 2)

@unittest.skipIf(sys.version_info[0] == 2, "Python 2 doesn't collect cycles involving __del__")
def test_hooks_cycle(self):
import gc
counter = [0]

class GradHook(object):
def __init__(self, var):
self.var = var

def __del__(self):
counter[0] += 1

def __call__(self, *args):
pass

def run_test():
x = Variable(torch.ones(5, 5), requires_grad=True)
y = x * 2
x.register_hook(GradHook(x))
y.register_hook(GradHook(y))
y._backward_hooks[1] = GradHook(y)

run_test()
gc.collect()
self.assertEqual(counter[0], 3)

def test_hook_none(self):
# WARNING: this is a test for autograd internals.
# You should never have to use such things in your code.
Expand Down
33 changes: 14 additions & 19 deletions torch/csrc/autograd/python_variable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,26 +66,21 @@ static int THPVariable_traverse(THPVariable *self, visitproc visit, void *arg)
{
Py_VISIT(self->data);
Py_VISIT(self->backward_hooks);
// We don't want to traverse the grad_fn, even if the Variable owns it and the
// shared pointer's use count is 1. This is because we would need to treat
// the grad_fn as part of the Python state and hold the GIL sometimes when
// grad_fn's shared_ptr is copied, otherwise a race condition with the Python
// GC could occur. Holding the GIL when the shared_ptr is copied adds
// undesirable complexity/overhead.
//
// When hooks, a Variable, and its grad_fn are involved in a Python reference
// cycle, because we're not traversing the grad_fn, the reference cycle will
// in fact leak.
//
// See https://gist.github.com/zou3519/7ac92b84dd7d206dcc6eae55fee8372c
// for more details about the race condition involving traversing the grad_fn
// and the python GC.
if (self->cdata.defined()) {
// Only visit this if we actually own it (no one else use the shared pointer)
auto& grad_fn = self->cdata.grad_fn();
if (grad_fn.use_count() == 1) {
if (auto fn = dynamic_cast<PyFunction*>(grad_fn.get())) {
Py_VISIT(fn->obj);
} else {
// visit hooks in C++ implemented autograd functions
for (auto& hook : grad_fn->pre_hooks) {
if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
Py_VISIT(pyhook->dict);
}
}
for (auto& hook : grad_fn->post_hooks) {
if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
Py_VISIT(pyhook->dict);
}
}
}
}
for (auto& hook : self->cdata.hooks()) {
if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
Py_VISIT(pyhook->dict);
Expand Down