Commit da3c4cb

zou3519 authored and soumith committed
Fix refcycles in DataParallel scatter and gather (#4988)
* Eliminate reference cycles in scatter_gather
* Test for refcycles
* Better fix
* Add comments
1 parent 30bd9e6 commit da3c4cb

1 file changed: +16 -2 lines changed
torch/nn/parallel/scatter_gather.py

Lines changed: 16 additions & 2 deletions
@@ -22,7 +22,15 @@ def scatter_map(obj):
             return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
         return [obj for targets in target_gpus]
 
-    return scatter_map(inputs)
+    # After scatter_map is called, a scatter_map cell will exist. This cell
+    # has a reference to the actual function scatter_map, which has references
+    # to a closure that has a reference to the scatter_map cell (because the
+    # fn is recursive). To avoid this reference cycle, we set the function to
+    # None, clearing the cell
+    try:
+        return scatter_map(inputs)
+    finally:
+        scatter_map = None
 
 
 def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
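The cycle described in the scatter_map comment can be reproduced without torch. The following is a minimal sketch, assuming CPython; `copy_nested` and `walk` are hypothetical stand-ins for `scatter` and `scatter_map`, not names from this commit. It checks that the closure cell of the recursive helper points back at the helper itself, and that the helper therefore stays alive after the outer call returns, until the cyclic garbage collector runs.

```python
import gc
import weakref


def copy_nested(inputs):
    """Hypothetical stand-in for scatter(): uses a recursive nested helper."""
    def walk(obj):
        if isinstance(obj, (list, tuple)):
            # `walk` is a free variable here, so it is read through a cell.
            return list(map(walk, obj))
        return obj

    # The function holds its closure, the closure holds the cell,
    # and the cell holds the function: a reference cycle.
    cell = walk.__closure__[0]
    points_back = cell.cell_contents is walk
    return walk(inputs), weakref.ref(walk), points_back


gc.disable()  # keep the collector from running in the middle of the demo
try:
    result, walk_ref, points_back = copy_nested([1, (2, 3)])
    alive_before = walk_ref() is not None  # cycle keeps `walk` alive
    gc.collect()
    alive_after = walk_ref() is not None   # only the collector frees it
finally:
    gc.enable()
```

Reference counting alone cannot reclaim the helper in this sketch, which is why the commit clears the cell explicitly rather than waiting for a collection pass.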
@@ -50,4 +58,10 @@ def gather_map(outputs):
         if out is None:
             return None
         return type(out)(map(gather_map, zip(*outputs)))
-    return gather_map(outputs)
+
+    # Recursive function calls like this create reference cycles.
+    # Setting the function to None clears the refcycle.
+    try:
+        return gather_map(outputs)
+    finally:
+        gather_map = None
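The try/finally pattern used in both hunks can be exercised in isolation. This is a sketch under the same assumptions as the diff (CPython reference counting); `copy_nested_fixed` and `walk` are hypothetical names chosen for illustration. Rebinding the local name to None in the `finally` clause empties the closure cell, so the last strong reference to the helper disappears and it is freed without any help from the cyclic collector.

```python
import gc
import weakref


def copy_nested_fixed(inputs):
    """Hypothetical stand-in for gather(): recursive helper, cycle broken."""
    def walk(obj):
        if isinstance(obj, (list, tuple)):
            return list(map(walk, obj))
        return obj

    walk_ref = weakref.ref(walk)  # weak reference, does not keep `walk` alive
    try:
        return walk(inputs), walk_ref
    finally:
        # Overwrite the cell contents; the function's refcount drops to zero.
        walk = None


gc.disable()  # prove that no collector pass is needed
try:
    result, walk_ref = copy_nested_fixed([1, (2, 3)])
    freed_without_gc = walk_ref() is None
finally:
    gc.enable()
```

The `finally` clause runs whether the helper returns or raises, so the cell is cleared on both paths.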
