
Commit 840760c

pemazare authored and soumith committed
Fix DataParallel scattering for empty lists / dicts / tuples (#3769)
* Fix DataParallel scattering for empty lists and dicts
* Fix DataParallel scattering for empty tuples
1 parent ee24a05 commit 840760c
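
As background, the pre-fix scatter_map recursed into every tuple, list, and dict via zip(*map(scatter_map, obj)); zip over an empty container yields nothing at all, so an empty keyword argument produced zero per-device copies instead of one per GPU. The following is a minimal plain-Python sketch of that failure mode; scatter_map_old, target_gpus, and the leaf branch are simplified stand-ins, not the real implementation:

    # Simplified sketch (assumed names) -- the real scatter_map also handles
    # Variables, dicts, and CUDA placement.
    target_gpus = [0, 1]                      # pretend we scatter to two devices

    def scatter_map_old(obj):
        if isinstance(obj, (tuple, list)):
            # Recurse into the container; zip(*...) regroups results per device.
            return list(zip(*map(scatter_map_old, obj)))
        return [obj for _ in target_gpus]     # leaf: replicate once per device

    print(scatter_map_old([3]))               # [(3,), (3,)] -> one chunk per device
    print(scatter_map_old([]))                # []           -> no per-device chunks at all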

File tree

2 files changed: +57 -3 lines changed

test/test_nn.py

Lines changed: 54 additions & 0 deletions
@@ -1609,6 +1609,60 @@ def forward(self, input):
         self.assertEqual(out.get_device(), 0)
         self.assertEqual(out.data, expected_out)
 
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_data_parallel_module_kwargs_only_empty_list(self):
+        class Net(nn.Module):
+            def __init__(self):
+                super(Net, self).__init__()
+                self.l = l
+
+            def forward(self, input):
+                return self.l(input['data'])
+
+        l = nn.Linear(10, 5).float().cuda()
+        i = Variable(torch.randn(20, 10).float().cuda())
+        expected_out = l(i).data
+        n = nn.DataParallel(Net())
+        out = n(input={'data': i, 'unused': []})
+        self.assertEqual(out.get_device(), 0)
+        self.assertEqual(out.data, expected_out)
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_data_parallel_module_kwargs_only_empty_dict(self):
+        class Net(nn.Module):
+            def __init__(self):
+                super(Net, self).__init__()
+                self.l = l
+
+            def forward(self, input):
+                return self.l(input['data'])
+
+        l = nn.Linear(10, 5).float().cuda()
+        i = Variable(torch.randn(20, 10).float().cuda())
+        expected_out = l(i).data
+        n = nn.DataParallel(Net())
+        out = n(input={'data': i, 'unused': {}})
+        self.assertEqual(out.get_device(), 0)
+        self.assertEqual(out.data, expected_out)
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_data_parallel_module_kwargs_only_empty_tuple(self):
+        class Net(nn.Module):
+            def __init__(self):
+                super(Net, self).__init__()
+                self.l = l
+
+            def forward(self, input):
+                return self.l(input['data'])
+
+        l = nn.Linear(10, 5).float().cuda()
+        i = Variable(torch.randn(20, 10).float().cuda())
+        expected_out = l(i).data
+        n = nn.DataParallel(Net())
+        out = n(input={'data': i, 'unused': ()})
+        self.assertEqual(out.get_device(), 0)
+        self.assertEqual(out.data, expected_out)
+
     def test_state_dict(self):
         l = nn.Linear(5, 5)
         block = nn.Module()

torch/nn/parallel/scatter_gather.py

Lines changed: 3 additions & 3 deletions
@@ -14,11 +14,11 @@ def scatter_map(obj):
         if isinstance(obj, Variable):
             return Scatter.apply(target_gpus, None, dim, obj)
         assert not torch.is_tensor(obj), "Tensors not supported in scatter."
-        if isinstance(obj, tuple):
+        if isinstance(obj, tuple) and len(obj) > 0:
             return list(zip(*map(scatter_map, obj)))
-        if isinstance(obj, list):
+        if isinstance(obj, list) and len(obj) > 0:
             return list(map(list, zip(*map(scatter_map, obj))))
-        if isinstance(obj, dict):
+        if isinstance(obj, dict) and len(obj) > 0:
             return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
         return [obj for targets in target_gpus]
 
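
With the len(obj) > 0 guards above, an empty tuple, list, or dict no longer enters the recursive branch and instead falls through to the final return [obj for targets in target_gpus], so every target device receives the empty container unchanged. Continuing the simplified plain-Python sketch from above (same assumed stand-in names, not the real module):

    target_gpus = [0, 1]                      # same illustrative two-device setup

    def scatter_map_new(obj):
        if isinstance(obj, (tuple, list)) and len(obj) > 0:
            return list(zip(*map(scatter_map_new, obj)))
        return [obj for _ in target_gpus]     # empties now land here as well

    print(scatter_map_new([]))                # [[], []] -> one empty-list entry per device
    print(scatter_map_new(()))                # [(), ()] -> one empty-tuple entry per device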
