Skip to content

Commit 2060f35

Browse files
zou3519apaszke
authored andcommitted
Fix python gc race condition with THPVariable_traverse (#4437)
1 parent 18a866a commit 2060f35

File tree

2 files changed

+14
-45
lines changed

2 files changed

+14
-45
lines changed

test/test_autograd.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -372,32 +372,6 @@ def bw_hook(grad):
372372
self.assertEqual(counter[0], 1, 'bw_hook not called')
373373
self.assertEqual(x.grad.data, torch.ones(5, 5) * 2)
374374

375-
@unittest.skipIf(sys.version_info[0] == 2, "Python 2 doesn't collect cycles involving __del__")
376-
def test_hooks_cycle(self):
377-
import gc
378-
counter = [0]
379-
380-
class GradHook(object):
381-
def __init__(self, var):
382-
self.var = var
383-
384-
def __del__(self):
385-
counter[0] += 1
386-
387-
def __call__(self, *args):
388-
pass
389-
390-
def run_test():
391-
x = Variable(torch.ones(5, 5), requires_grad=True)
392-
y = x * 2
393-
x.register_hook(GradHook(x))
394-
y.register_hook(GradHook(y))
395-
y._backward_hooks[1] = GradHook(y)
396-
397-
run_test()
398-
gc.collect()
399-
self.assertEqual(counter[0], 3)
400-
401375
def test_hook_none(self):
402376
# WARNING: this is a test for autograd internals.
403377
# You should never have to use such things in your code.

torch/csrc/autograd/python_variable.cpp

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -66,26 +66,21 @@ static int THPVariable_traverse(THPVariable *self, visitproc visit, void *arg)
6666
{
6767
Py_VISIT(self->data);
6868
Py_VISIT(self->backward_hooks);
69+
// We don't want to traverse the grad_fn, even if the Variable owns it and the
70+
// shared pointer's use count is 1. This is because we would need to treat
71+
// the grad_fn as part of the Python state and hold the GIL sometimes when
72+
// grad_fn's shared_ptr is copied, otherwise a race condition with the Python
73+
// GC could occur. Holding the GIL when the shared_ptr is copied adds
74+
// undesirable complexity/overhead.
75+
//
76+
// When hooks, a Variable, and its grad_fn are involved in a Python reference
77+
// cycle, because we're not traversing the grad_fn, the reference cycle will
78+
// in fact leak.
79+
//
80+
// See https://gist.github.com/zou3519/7ac92b84dd7d206dcc6eae55fee8372c
81+
// for more details about the race condition involving traversing the grad_fn
82+
// and the python GC.
6983
if (self->cdata.defined()) {
70-
// Only visit this if we actually own it (no one else use the shared pointer)
71-
auto& grad_fn = self->cdata.grad_fn();
72-
if (grad_fn.use_count() == 1) {
73-
if (auto fn = dynamic_cast<PyFunction*>(grad_fn.get())) {
74-
Py_VISIT(fn->obj);
75-
} else {
76-
// visit hooks in C++ implemented autograd functions
77-
for (auto& hook : grad_fn->pre_hooks) {
78-
if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
79-
Py_VISIT(pyhook->dict);
80-
}
81-
}
82-
for (auto& hook : grad_fn->post_hooks) {
83-
if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
84-
Py_VISIT(pyhook->dict);
85-
}
86-
}
87-
}
88-
}
8984
for (auto& hook : self->cdata.hooks()) {
9085
if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
9186
Py_VISIT(pyhook->dict);

0 commit comments

Comments
 (0)