Skip to content

Commit 4cec94d

Browse files
zou3519 authored and
soumith committed
Fix python gc race condition with THPVariable_traverse (#4437)
1 parent d721743 commit 4cec94d

File tree

2 files changed

+14
-45
lines changed

2 files changed

+14
-45
lines changed

test/test_autograd.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -345,32 +345,6 @@ def bw_hook(grad):
345345
self.assertEqual(counter[0], 1, 'bw_hook not called')
346346
self.assertEqual(x.grad.data, torch.ones(5, 5) * 2)
347347

348-
@unittest.skipIf(sys.version_info[0] == 2, "Python 2 doesn't collect cycles involving __del__")
349-
def test_hooks_cycle(self):
350-
import gc
351-
counter = [0]
352-
353-
class GradHook(object):
354-
def __init__(self, var):
355-
self.var = var
356-
357-
def __del__(self):
358-
counter[0] += 1
359-
360-
def __call__(self, *args):
361-
pass
362-
363-
def run_test():
364-
x = Variable(torch.ones(5, 5), requires_grad=True)
365-
y = x * 2
366-
x.register_hook(GradHook(x))
367-
y.register_hook(GradHook(y))
368-
y._backward_hooks[1] = GradHook(y)
369-
370-
run_test()
371-
gc.collect()
372-
self.assertEqual(counter[0], 3)
373-
374348
def test_hook_none(self):
375349
# WARNING: this is a test for autograd internals.
376350
# You should never have to use such things in your code.

torch/csrc/autograd/python_variable.cpp

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -100,26 +100,21 @@ static int THPVariable_traverse(THPVariable *self, visitproc visit, void *arg)
100100
{
101101
Py_VISIT(self->data);
102102
Py_VISIT(self->backward_hooks);
103+
// We don't want to traverse the grad_fn, even if the Variable owns it and the
104+
// shared pointer's use count is 1. This is because we would need to treat
105+
// the grad_fn as part of the Python state and hold the GIL sometimes when
106+
// grad_fn's shared_ptr is copied, otherwise a race condition with the Python
107+
// GC could occur. Holding the GIL when the shared_ptr is copied adds
108+
// undesirable complexity/overhead.
109+
//
110+
// When hooks, a Variable, and its grad_fn are involved in a Python reference
111+
// cycle, because we're not traversing the grad_fn, the reference cycle will
112+
// in fact leak.
113+
//
114+
// See https://gist.github.com/zou3519/7ac92b84dd7d206dcc6eae55fee8372c
115+
// for more details about the race condition involving traversing the grad_fn
116+
// and the python GC.
103117
if (self->cdata.defined()) {
104-
// Only visit this if we actually own it (no one else use the shared pointer)
105-
auto& grad_fn = self->cdata.grad_fn();
106-
if (grad_fn.use_count() == 1) {
107-
if (auto fn = dynamic_cast<PyFunction*>(grad_fn.get())) {
108-
Py_VISIT(fn->obj);
109-
} else {
110-
// visit hooks in C++ implemented autograd functions
111-
for (auto& hook : grad_fn->pre_hooks) {
112-
if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
113-
Py_VISIT(pyhook->dict);
114-
}
115-
}
116-
for (auto& hook : grad_fn->post_hooks) {
117-
if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
118-
Py_VISIT(pyhook->dict);
119-
}
120-
}
121-
}
122-
}
123118
for (auto& hook : self->cdata.hooks()) {
124119
if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
125120
Py_VISIT(pyhook->dict);

0 commit comments

Comments
 (0)