54 changes: 54 additions & 0 deletions test/test_nn.py
@@ -1614,6 +1614,60 @@ def forward(self, input):
         self.assertEqual(out.get_device(), 0)
         self.assertEqual(out.data, expected_out)
 
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_data_parallel_module_kwargs_only_empty_list(self):
+        class Net(nn.Module):
+            def __init__(self):
+                super(Net, self).__init__()
+                self.l = l
+
+            def forward(self, input):
+                return self.l(input['data'])
+
+        l = nn.Linear(10, 5).float().cuda()
+        i = Variable(torch.randn(20, 10).float().cuda())
+        expected_out = l(i).data
+        n = nn.DataParallel(Net())
+        out = n(input={'data': i, 'unused': []})
+        self.assertEqual(out.get_device(), 0)
+        self.assertEqual(out.data, expected_out)
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_data_parallel_module_kwargs_only_empty_dict(self):
+        class Net(nn.Module):
+            def __init__(self):
+                super(Net, self).__init__()
+                self.l = l
+
+            def forward(self, input):
+                return self.l(input['data'])
+
+        l = nn.Linear(10, 5).float().cuda()
+        i = Variable(torch.randn(20, 10).float().cuda())
+        expected_out = l(i).data
+        n = nn.DataParallel(Net())
+        out = n(input={'data': i, 'unused': {}})
+        self.assertEqual(out.get_device(), 0)
+        self.assertEqual(out.data, expected_out)
+
+    @unittest.skipIf(not TEST_CUDA, "CUDA unavailable")
+    def test_data_parallel_module_kwargs_only_empty_tuple(self):
+        class Net(nn.Module):
+            def __init__(self):
+                super(Net, self).__init__()
+                self.l = l
+
+            def forward(self, input):
+                return self.l(input['data'])
+
+        l = nn.Linear(10, 5).float().cuda()
+        i = Variable(torch.randn(20, 10).float().cuda())
+        expected_out = l(i).data
+        n = nn.DataParallel(Net())
+        out = n(input={'data': i, 'unused': ()})
+        self.assertEqual(out.get_device(), 0)
+        self.assertEqual(out.data, expected_out)
+
     def test_state_dict(self):
         l = nn.Linear(5, 5)
         block = nn.Module()
6 changes: 3 additions & 3 deletions torch/nn/parallel/scatter_gather.py
@@ -14,11 +14,11 @@ def scatter_map(obj):
         if isinstance(obj, Variable):
             return Scatter.apply(target_gpus, None, dim, obj)
         assert not torch.is_tensor(obj), "Tensors not supported in scatter."
-        if isinstance(obj, tuple):
+        if isinstance(obj, tuple) and len(obj) > 0:
             return list(zip(*map(scatter_map, obj)))
-        if isinstance(obj, list):
+        if isinstance(obj, list) and len(obj) > 0:
             return list(map(list, zip(*map(scatter_map, obj))))
-        if isinstance(obj, dict):
+        if isinstance(obj, dict) and len(obj) > 0:
             return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
         return [obj for targets in target_gpus]
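Why the len(obj) > 0 guards matter: zip(*...) over an empty container collapses to an empty list, so the old code produced zero per-GPU copies of an empty list, dict, or tuple instead of one empty copy per GPU, and the scattered kwargs for each replica came out malformed. A minimal standalone sketch of the before/after behaviour -- the simplified scatter_map_old/scatter_map_new helpers and the target_gpus list are illustrative stand-ins, not the PyTorch source:

    # Two devices; the real code scatters across CUDA devices.
    target_gpus = [0, 1]

    def scatter_map_old(obj):
        # Old behaviour: recurse into every tuple, even an empty one.
        if isinstance(obj, tuple):
            return list(zip(*map(scatter_map_old, obj)))
        # Fallback: replicate the object once per target GPU.
        return [obj for _ in target_gpus]

    def scatter_map_new(obj):
        # Fixed behaviour: only recurse into non-empty tuples.
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map_new, obj)))
        return [obj for _ in target_gpus]

    print(scatter_map_old(()))      # [] -- no entry for either GPU
    print(scatter_map_new(()))      # [(), ()] -- one empty tuple per GPU
    print(scatter_map_new((1, 2)))  # [(1, 2), (1, 2)] -- non-empty path unchanged

With the guards, an empty container falls through to the final "return [obj for targets in target_gpus]" branch and is replicated as-is to every device, which is what the three new kwargs-only tests assert.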
