Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2054,6 +2054,18 @@ def test_gather_cpu(self):
def test_gather_gpu(self):
self._test_gather(0)

@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_gather_different_len_dicts(self):
inputs = (
{'a': Variable(torch.randn(1, 2).cuda(0), requires_grad=True)},
{
'b': Variable(torch.randn(1, 2).cuda(1), requires_grad=True),
'a': Variable(torch.randn(1, 2).cuda(1), requires_grad=True)
}
)
with self.assertRaises(ValueError):
_ = dp.gather(inputs, target_device=0)

def _test_broadcast_double_backwards(self, *tensors):
variables = tuple(Variable(t, requires_grad=True) for t in tensors)
_assertGradAndGradgradChecks(self, lambda *i: Broadcast.apply((0, 1), *i), variables)
Expand Down Expand Up @@ -2297,7 +2309,10 @@ def test_data_parallel_sparse(self):
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_data_parallel_nested_output(self):
def fn(input):
return [input, (input.sin(), input.cos(), [input.add(1)]), input]
return [
input, (input.sin(), input.cos(), [input.add(1)]), input,
{'a': input, 'b': [input.sin()]}
]

class Net(nn.Module):
def forward(self, input):
Expand All @@ -2314,6 +2329,13 @@ def forward(self, input):
self.assertIsInstance(output[1][2], list)
self.assertIsInstance(output[1][2][0], Variable)
self.assertIsInstance(output[2], Variable)
self.assertIsInstance(output[3], dict)
self.assertEqual(len(output[3]), 2)
self.assertIn('a', output[3])
self.assertIn('b', output[3])
self.assertIsInstance(output[3]['a'], Variable)
self.assertIsInstance(output[3]['b'], list)
self.assertIsInstance(output[3]['b'][0], Variable)

@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_data_parallel_nested_input(self):
Expand Down
5 changes: 5 additions & 0 deletions torch/nn/parallel/scatter_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ def gather_map(outputs):
return Gather.apply(target_device, dim, *outputs)
if out is None:
return None
if isinstance(out, dict):
if not all((len(out) == len(d) for d in outputs)):
raise ValueError('All dicts must have the same number of keys')
return type(out)(((k, gather_map([d[k] for d in outputs]))
for k in out))
return type(out)(map(gather_map, zip(*outputs)))

# Recursive function calls like this create reference cycles.
Expand Down