Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions test/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
def run_tests():
    """Run the calling test module through ``unittest.main``.

    ``UNITTEST_ARGS`` is passed as ``argv`` so unittest sees that argument
    list rather than the raw ``sys.argv``.
    """
    unittest.main(argv=UNITTEST_ARGS)

# True when running under Python 3 or newer.
PY3 = sys.version_info[0] >= 3

# True when running on Windows (sys.platform reports "win32" there).
IS_WINDOWS = sys.platform == "win32"

TEST_NUMPY = True
Expand Down
25 changes: 24 additions & 1 deletion test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
module_tests, criterion_tests, TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \
TEST_CUDNN_VERSION, loss_reference_fns, get_size_average, get_weight
from common import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, \
TEST_SCIPY, download_file, IS_WINDOWS
TEST_SCIPY, download_file, IS_WINDOWS, PY3

if TEST_SCIPY:
from scipy import stats
Expand Down Expand Up @@ -1710,6 +1710,29 @@ def test_data_parallel_small_back(self):
out = dp.data_parallel(l, i, (0, 1))
self.assertEqual(out, l(i))

@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
@unittest.skipIf(not PY3, "Python 2 always creates reference cycles here")
def test_data_parallel_model_no_refcycles(self):
    """Check that a DataParallel forward pass leaves nothing for the GC.

    Python 2.7 will create reference cycles with the following Module on
    multiple GPUs, but Python 3 shouldn't unless there are refcycles on
    the PyTorch side (or in the defined module). The two skip conditions
    are kept separate so the reported skip reason matches the real cause.
    """
    import gc

    class Model(nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.linear = nn.Linear(1, 1)

        def forward(self, x):
            return self.linear(x)

    # Clear out garbage left by earlier code so that the second collect()
    # below only counts objects produced by this forward pass.
    gc.collect()
    model = nn.DataParallel(Model().cuda())
    data = Variable(torch.randn(1).cuda())
    model(data)

    # gc.collect() returns the number of unreachable objects it found;
    # anything non-zero means the forward pass left cyclic garbage behind.
    refcycles = gc.collect()
    self.assertEqual(refcycles, 0)

@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_data_parallel_no_grad(self):
test = self
Expand Down
18 changes: 16 additions & 2 deletions torch/nn/parallel/scatter_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,15 @@ def scatter_map(obj):
return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
return [obj for targets in target_gpus]

return scatter_map(inputs)
# After scatter_map is called, a scatter_map cell will exist. This cell
# has a reference to the actual function scatter_map, which has references
# to a closure that has a reference to the scatter_map cell (because the
# fn is recursive). To avoid this reference cycle, we set the function to
# None, clearing the cell
try:
return scatter_map(inputs)
finally:
scatter_map = None


def scatter_kwargs(inputs, kwargs, target_gpus, dim=0):
Expand Down Expand Up @@ -50,4 +58,10 @@ def gather_map(outputs):
if out is None:
return None
return type(out)(map(gather_map, zip(*outputs)))
return gather_map(outputs)

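# gather_map refers to itself recursively, so the function object and the
# closure cell that holds it reference each other, forming a reference cycle.
# Rebinding the name to None empties the cell and breaks the cycle.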
try:
return gather_map(outputs)
finally:
gather_map = None