Skip to content

Commit 0bf5126

Browse files
committed
Test for refcycles
1 parent 2f4e03b commit 0bf5126

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

test/common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
def run_tests():
3838
unittest.main(argv=UNITTEST_ARGS)
3939

40+
PY3 = sys.version_info > (3, 0)
41+
4042
IS_WINDOWS = sys.platform == "win32"
4143

4244
TEST_NUMPY = True

test/test_nn.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
module_tests, criterion_tests, TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \
3131
TEST_CUDNN_VERSION, loss_reference_fns, get_size_average, get_weight
3232
from common import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, \
33-
TEST_SCIPY, download_file, IS_WINDOWS
33+
TEST_SCIPY, download_file, IS_WINDOWS, PY3
3434

3535
if TEST_SCIPY:
3636
from scipy import stats
@@ -1710,6 +1710,29 @@ def test_data_parallel_small_back(self):
17101710
out = dp.data_parallel(l, i, (0, 1))
17111711
self.assertEqual(out, l(i))
17121712

1713+
@unittest.skipIf(not TEST_MULTIGPU or not PY3, "multi-GPU not supported")
1714+
def test_data_parallel_model_no_refcycles(self):
1715+
# Python 2.7 will create reference cycles with the following
1716+
# Module on multiple GPUs, but Python 3 shouldn't unless
1717+
# there are refcycles on the PyTorch side (or the defined module)
1718+
import gc
1719+
1720+
class Model(nn.Module):
1721+
def __init__(self):
1722+
super(Model, self).__init__()
1723+
self.linear = nn.Linear(1, 1)
1724+
1725+
def forward(self, x):
1726+
return self.linear(x)
1727+
1728+
gc.collect()
1729+
model = nn.DataParallel(Model().cuda())
1730+
data = Variable(torch.randn(1).cuda())
1731+
model(data)
1732+
1733+
refcycles = gc.collect()
1734+
self.assertEqual(refcycles, 0)
1735+
17131736
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
17141737
def test_data_parallel_no_grad(self):
17151738
test = self

0 commit comments

Comments
 (0)