@@ -2047,6 +2047,18 @@ def test_gather_cpu(self):
20472047 def test_gather_gpu (self ):
20482048 self ._test_gather (0 )
20492049
2050+ @unittest .skipIf (not TEST_MULTIGPU , "multi-GPU not supported" )
2051+ def test_gather_different_len_dicts (self ):
2052+ inputs = (
2053+ {'a' : Variable (torch .randn (1 , 2 ).cuda (0 ), requires_grad = True )},
2054+ {
2055+ 'b' : Variable (torch .randn (1 , 2 ).cuda (1 ), requires_grad = True ),
2056+ 'a' : Variable (torch .randn (1 , 2 ).cuda (1 ), requires_grad = True )
2057+ }
2058+ )
2059+ with self .assertRaises (ValueError ):
2060+ _ = dp .gather (inputs , target_device = 0 )
2061+
20502062 def _test_broadcast_double_backwards (self , * tensors ):
20512063 variables = tuple (Variable (t , requires_grad = True ) for t in tensors )
20522064 _assertGradAndGradgradChecks (self , lambda * i : Broadcast .apply ((0 , 1 ), * i ), variables )
@@ -2290,7 +2302,10 @@ def test_data_parallel_sparse(self):
22902302 @unittest .skipIf (not TEST_MULTIGPU , "multi-GPU not supported" )
22912303 def test_data_parallel_nested_output (self ):
22922304 def fn (input ):
2293- return [input , (input .sin (), input .cos (), [input .add (1 )]), input ]
2305+ return [
2306+ input , (input .sin (), input .cos (), [input .add (1 )]), input ,
2307+ {'a' : input , 'b' : [input .sin ()]}
2308+ ]
22942309
22952310 class Net (nn .Module ):
22962311 def forward (self , input ):
@@ -2307,6 +2322,13 @@ def forward(self, input):
23072322 self .assertIsInstance (output [1 ][2 ], list )
23082323 self .assertIsInstance (output [1 ][2 ][0 ], Variable )
23092324 self .assertIsInstance (output [2 ], Variable )
2325+ self .assertIsInstance (output [3 ], dict )
2326+ self .assertEqual (len (output [3 ]), 2 )
2327+ self .assertIn ('a' , output [3 ])
2328+ self .assertIn ('b' , output [3 ])
2329+ self .assertIsInstance (output [3 ]['a' ], Variable )
2330+ self .assertIsInstance (output [3 ]['b' ], list )
2331+ self .assertIsInstance (output [3 ]['b' ][0 ], Variable )
23102332
23112333 @unittest .skipIf (not TEST_MULTIGPU , "multi-GPU not supported" )
23122334 def test_data_parallel_nested_input (self ):
0 commit comments