Skip to content

Commit c76be41

Browse files
committed
Fix DataParallel scattering for empty tuples
1 parent ece8244 commit c76be41

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

test/test_nn.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1650,6 +1650,24 @@ def forward(self, input):
16501650
self.assertEqual(out.get_device(), 0)
16511651
self.assertEqual(out.data, expected_out)
16521652

1653+
@unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
def test_data_parallel_module_kwargs_only_empty_tuple(self):
    """Regression test: DataParallel must tolerate an empty tuple inside
    keyword-only inputs — scatter previously choked on ``()`` values."""

    # Wrapper module that pulls its real input out of a dict, so the
    # DataParallel call below can be made with kwargs only.
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.linear = linear

        def forward(self, input):
            return self.linear(input['data'])

    linear = nn.Linear(10, 5).float().cuda()
    inp = Variable(torch.randn(20, 10).float().cuda())

    # Reference output from the bare layer, computed on a single device.
    expected_out = linear(inp).data

    parallel_net = nn.DataParallel(Net())
    # The 'unused' key carries an empty tuple — the case being regression-tested.
    out = parallel_net(input={'data': inp, 'unused': ()})

    self.assertEqual(out.get_device(), 0)
    self.assertEqual(out.data, expected_out)
16531671
def test_state_dict(self):
16541672
l = nn.Linear(5, 5)
16551673
block = nn.Module()

torch/nn/parallel/scatter_gather.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,7 @@ def scatter_map(obj):
1414
if isinstance(obj, Variable):
1515
return Scatter.apply(target_gpus, None, dim, obj)
1616
assert not torch.is_tensor(obj), "Tensors not supported in scatter."
17-
if isinstance(obj, tuple):
17+
if isinstance(obj, tuple) and len(obj) > 0:
1818
return list(zip(*map(scatter_map, obj)))
1919
if isinstance(obj, list) and len(obj) > 0:
2020
return list(map(list, zip(*map(scatter_map, obj))))

0 commit comments

Comments
 (0)