
Commit 92a0f78

mseitzer authored and apaszke committed
Support returning dictionaries in DataParallel (#6113)
1 parent 0b17f4b commit 92a0f78

2 files changed (+28, -1 lines)


test/test_nn.py

Lines changed: 23 additions & 1 deletion
@@ -2047,6 +2047,18 @@ def test_gather_cpu(self):
     def test_gather_gpu(self):
         self._test_gather(0)
 
+    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
+    def test_gather_different_len_dicts(self):
+        inputs = (
+            {'a': Variable(torch.randn(1, 2).cuda(0), requires_grad=True)},
+            {
+                'b': Variable(torch.randn(1, 2).cuda(1), requires_grad=True),
+                'a': Variable(torch.randn(1, 2).cuda(1), requires_grad=True)
+            }
+        )
+        with self.assertRaises(ValueError):
+            _ = dp.gather(inputs, target_device=0)
+
     def _test_broadcast_double_backwards(self, *tensors):
         variables = tuple(Variable(t, requires_grad=True) for t in tensors)
         _assertGradAndGradgradChecks(self, lambda *i: Broadcast.apply((0, 1), *i), variables)
@@ -2290,7 +2302,10 @@ def test_data_parallel_sparse(self):
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_data_parallel_nested_output(self):
         def fn(input):
-            return [input, (input.sin(), input.cos(), [input.add(1)]), input]
+            return [
+                input, (input.sin(), input.cos(), [input.add(1)]), input,
+                {'a': input, 'b': [input.sin()]}
+            ]
 
         class Net(nn.Module):
             def forward(self, input):
@@ -2307,6 +2322,13 @@ def forward(self, input):
         self.assertIsInstance(output[1][2], list)
         self.assertIsInstance(output[1][2][0], Variable)
         self.assertIsInstance(output[2], Variable)
+        self.assertIsInstance(output[3], dict)
+        self.assertEqual(len(output[3]), 2)
+        self.assertIn('a', output[3])
+        self.assertIn('b', output[3])
+        self.assertIsInstance(output[3]['a'], Variable)
+        self.assertIsInstance(output[3]['b'], list)
+        self.assertIsInstance(output[3]['b'][0], Variable)
 
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_data_parallel_nested_input(self):
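
For contrast with the failing test_gather_different_len_dicts case above, here is a minimal sketch of the success path, assuming two visible GPUs and that dp refers to torch.nn.parallel as in this test file: when every dict carries the same keys, gather recurses into each key and concatenates the per-device values onto the target device.

# Sketch only (assumes two GPUs); mirrors the test's dp.gather usage.
import torch
import torch.nn.parallel as dp
from torch.autograd import Variable

inputs = (
    {'a': Variable(torch.randn(1, 2).cuda(0))},
    {'a': Variable(torch.randn(1, 2).cuda(1))},
)
result = dp.gather(inputs, target_device=0)
# result['a'] is a 2x2 Variable on device 0: the two rows are
# concatenated along dim 0, the default gather dimension.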

torch/nn/parallel/scatter_gather.py

Lines changed: 5 additions & 0 deletions
@@ -57,6 +57,11 @@ def gather_map(outputs):
             return Gather.apply(target_device, dim, *outputs)
         if out is None:
             return None
+        if isinstance(out, dict):
+            if not all((len(out) == len(d) for d in outputs)):
+                raise ValueError('All dicts must have the same number of keys')
+            return type(out)(((k, gather_map([d[k] for d in outputs]))
+                              for k in out))
         return type(out)(map(gather_map, zip(*outputs)))
 
     # Recursive function calls like this create reference cycles.
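
To see the new branch end to end, a minimal sketch of wrapping a dict-returning module in DataParallel (assuming two visible GPUs; DictNet and its keys are hypothetical names for illustration): each replica returns a dict with identical keys, so gather_map merges the replicas key-wise onto the output device.

# Hypothetical dict-returning module under DataParallel (assumes two GPUs).
import torch
import torch.nn as nn
from torch.autograd import Variable

class DictNet(nn.Module):
    def forward(self, input):
        # Every replica produces the same keys, so the dict branch of
        # gather_map above can merge them without raising ValueError.
        return {'out': input * 2, 'aux': input.sin()}

net = nn.DataParallel(DictNet().cuda(), device_ids=[0, 1])
output = net(Variable(torch.randn(4, 3).cuda(0)))
# output['out'] and output['aux'] are Variables gathered onto device 0.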
