@@ -2287,8 +2287,8 @@ def test_scatter_gpu(self):
22872287
22882288 def _test_gather (self , output_device ):
22892289 inputs = (
2290- Variable ( torch .randn (2 , 4 ). cuda ( 0 ) , requires_grad = True ),
2291- Variable ( torch .randn (2 , 4 ). cuda ( 1 ) , requires_grad = True )
2290+ torch .randn (2 , 4 , device = ' cuda:0' , requires_grad = True ),
2291+ torch .randn (2 , 4 , device = ' cuda:1' , requires_grad = True ),
22922292 )
22932293 result = dp .gather (inputs , output_device )
22942294 self .assertEqual (result .size (), torch .Size ([4 , 4 ]))
@@ -2306,6 +2306,27 @@ def _test_gather(self, output_device):
23062306 self .assertEqual (inputs [1 ].grad .data , grad [2 :])
23072307 _assertGradAndGradgradChecks (self , lambda x , y : dp .gather ((x , y ), output_device ), inputs )
23082308
2309+ # test scalar inputs, should stack into a vector in this case
2310+ inputs = (
2311+ torch .randn ((), device = 'cuda:0' , requires_grad = True ),
2312+ torch .randn ((), device = 'cuda:1' , requires_grad = True ),
2313+ )
2314+ result = dp .gather (inputs , output_device )
2315+ self .assertEqual (result .size (), torch .Size ([2 ]))
2316+ self .assertEqual (result [0 ], inputs [0 ])
2317+ self .assertEqual (result [1 ], inputs [1 ])
2318+ if output_device != - 1 :
2319+ self .assertEqual (result .get_device (), output_device )
2320+ else :
2321+ self .assertFalse (result .is_cuda )
2322+ grad = torch .randn (2 )
2323+ if output_device != - 1 :
2324+ grad = grad .cuda (output_device )
2325+ result .backward (grad )
2326+ self .assertEqual (inputs [0 ].grad , grad [0 ])
2327+ self .assertEqual (inputs [1 ].grad , grad [1 ])
2328+ _assertGradAndGradgradChecks (self , lambda x , y : dp .gather ((x , y ), output_device ), inputs )
2329+
@unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
def test_gather_cpu(self):
    """Gather multi-GPU inputs onto the CPU (output_device=-1 means CPU)."""
    self._test_gather(-1)
0 commit comments