Commit c1fa449

colesbury authored and facebook-github-bot committed
Break reference cycle in load_state_dict (#20397)
Summary: load_state_dict includes a recursive inner function `load` that captures Tensors through the closed-over variable `state_dict`. Because it's recursive, it also captures itself, leading to a reference cycle.

This breaks the reference cycle so that any Tensors in state_dict can be collected immediately instead of waiting until the next GC cycle.

Alternatively, we could have passed `state_dict` and `metadata` as arguments to load to prevent capture of Tensors. (That would still result in cyclic garbage, but not any cyclic garbage of Tensors.)

See: #20199 (comment)

Pull Request resolved: #20397

Differential Revision: D15414834

Pulled By: colesbury

fbshipit-source-id: 4c2275a08b2d8043deb3779db28be03bda15872d
1 parent 796e359 commit c1fa449
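
The cycle described in the summary can be reproduced without PyTorch. The following is a minimal sketch, assuming CPython's reference counting; `Payload`, `consume`, and `walk` are hypothetical stand-ins for a Tensor, `load_state_dict`, and the inner `load`, and are not taken from the PyTorch source.

```python
import gc
import weakref


class Payload:
    """Hypothetical stand-in for a Tensor stored in state_dict."""


def consume(state):
    def walk(depth):
        # `walk` reads `state` and calls itself, so its closure holds
        # a cell for `state` and a cell that points back at `walk`.
        _ = state
        if depth > 0:
            walk(depth - 1)

    walk(2)


gc.collect()   # start from a clean slate
gc.disable()   # keep the automatic collector from running mid-experiment
try:
    payload = Payload()
    ref = weakref.ref(payload)
    consume({"weight": payload})
    del payload

    # The walk -> closure cell -> walk cycle still references the dict,
    # so reference counting alone cannot free the payload.
    alive_before_gc = ref() is not None

    gc.collect()
    alive_after_gc = ref() is not None
finally:
    gc.enable()

print(alive_before_gc, alive_after_gc)
```

In this sketch the payload survives the call and is released only once the cyclic collector runs, which mirrors the delayed Tensor collection the commit message describes.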

2 files changed, +14 -0 lines changed


test/test_nn.py

Lines changed: 13 additions & 0 deletions

@@ -4396,6 +4396,19 @@ def test_load_state_dict_BC(self):
         self.assertEqual(bn.num_batches_tracked.dtype, torch.long)
         self.assertEqual(bn.num_batches_tracked.item(), 0)

+    @unittest.skipIf(not PY3, 'Python 2.7 generates cyclic trash')
+    def test_load_state_dict_ref_cycle(self):
+        # load_state_dict shouldn't cause a reference cycle involving Tensors
+        import gc
+
+        m = torch.nn.LSTM(16, 16, bidirectional=True)
+
+        gc.collect()
+        m.load_state_dict(deepcopy(m).state_dict())
+        refcycles = gc.collect()
+
+        self.assertEqual(refcycles, 0)
+
     def test_parameter_assignment(self):
         l = nn.Linear(5, 5)
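
The test above relies on `gc.collect()` returning the number of unreachable objects it found. A minimal sketch of that measurement pattern in plain Python follows; `cyclic_garbage_after`, `make_cycle`, and `no_cycle` are hypothetical names used for illustration and are not part of the PyTorch test suite.

```python
import gc


def cyclic_garbage_after(fn):
    """Run fn() and return how many unreachable objects gc.collect() finds."""
    was_enabled = gc.isenabled()
    gc.collect()   # clear garbage left over from earlier code
    gc.disable()   # prevent an automatic collection from hiding the result
    try:
        fn()
        return gc.collect()
    finally:
        if was_enabled:
            gc.enable()


def make_cycle():
    a = []
    a.append(a)    # the list references itself


def no_cycle():
    a = [1, 2, 3]  # freed by reference counting when the function returns
    return len(a)


with_cycle = cyclic_garbage_after(make_cycle)
without_cycle = cyclic_garbage_after(no_cycle)
print(with_cycle, without_cycle)
```

Collecting once before the call matters: without it, garbage created by earlier code would be counted against the function under test.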

torch/nn/modules/module.py

Lines changed: 1 addition & 0 deletions

@@ -761,6 +761,7 @@ def load(module, prefix=''):
                     load(child, prefix + name + '.')

         load(self)
+        load = None  # break load->load reference cycle

         if strict:
             if len(unexpected_keys) > 0:
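
As a rough illustration of why the one-line change works, the sketch below compares the two options mentioned in the commit message: rebinding the inner function's name to `None`, and passing the state as an argument instead of capturing it. It assumes CPython; `Payload`, `consume_rebind`, `consume_args`, `walk`, and `probe` are hypothetical names, not PyTorch APIs.

```python
import gc
import weakref


class Payload:
    """Hypothetical stand-in for a Tensor stored in state_dict."""


def consume_rebind(state):
    def walk(depth):
        _ = state
        if depth > 0:
            walk(depth - 1)

    walk(2)
    walk = None  # same idea as `load = None`: empties the self-referencing cell


def consume_args(state):
    def walk(state, depth):
        # `state` is now a parameter, so only `walk` itself is captured.
        if depth > 0:
            walk(state, depth - 1)

    walk(state, 2)


def probe(consume):
    """Return (payload_alive_after_call, objects_found_by_gc)."""
    was_enabled = gc.isenabled()
    gc.collect()
    gc.disable()
    try:
        payload = Payload()
        ref = weakref.ref(payload)
        consume({"weight": payload})
        del payload
        alive = ref() is not None
        found = gc.collect()
    finally:
        if was_enabled:
            gc.enable()
    return alive, found


rebind_result = probe(consume_rebind)
args_result = probe(consume_args)
print(rebind_result, args_result)
```

Under these assumptions, both variants release the payload as soon as the call returns. Only the rebinding variant leaves nothing for the cyclic collector, while the argument-passing variant still leaves the function's self-referencing closure behind, which matches the commit message's parenthetical remark.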
