@@ -30,7 +30,7 @@
     module_tests, criterion_tests, TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \
     TEST_CUDNN_VERSION, loss_reference_fns, get_size_average, get_weight
 from common import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, \
-    TEST_SCIPY, download_file, IS_WINDOWS
+    TEST_SCIPY, download_file, IS_WINDOWS, PY3
 
 if TEST_SCIPY:
     from scipy import stats
@@ -1710,6 +1710,29 @@ def test_data_parallel_small_back(self):
         out = dp.data_parallel(l, i, (0, 1))
         self.assertEqual(out, l(i))
 
+    @unittest.skipIf(not TEST_MULTIGPU or not PY3, "multi-GPU not supported, or Python 2")
+    def test_data_parallel_model_no_refcycles(self):
+        # Python 2.7 creates reference cycles with the following Module on
+        # multiple GPUs, but Python 3 shouldn't unless there are refcycles
+        # on the PyTorch side (or in the module being replicated).
+        import gc
+
+        class Model(nn.Module):
+            def __init__(self):
+                super(Model, self).__init__()
+                self.linear = nn.Linear(1, 1)
+
+            def forward(self, x):
+                return self.linear(x)
+
+        gc.collect()
+        model = nn.DataParallel(Model().cuda())
+        data = Variable(torch.randn(1).cuda())
+        model(data)
+
+        refcycles = gc.collect()
+        self.assertEqual(refcycles, 0)
+
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_data_parallel_no_grad(self):
         test = self
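
The test above relies on gc.collect() returning the number of unreachable objects it found, so a nonzero return right after the forward pass signals a reference cycle. A minimal CPU-only sketch of the same detection pattern (the Leaky class here is illustrative, not part of the change):

    import gc

    class Leaky(object):
        def __init__(self):
            # Deliberate cycle: the instance holds a reference to itself,
            # so its refcount never drops to zero on `del` alone.
            self.self_ref = self

    gc.collect()              # start from a clean slate
    obj = Leaky()
    del obj                   # the cycle keeps the object alive
    assert gc.collect() > 0   # the collector reports the cycle it broke

    x = object()              # no cycle this time
    del x
    assert gc.collect() == 0  # nothing unreachable left behind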