@@ -1615,6 +1615,22 @@ def test_broadcast_double_backwards_gpu(self):
                                               torch.randn(4, 4).cuda(),
                                               torch.randn(4, 4).cuda())
 
+    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
+    def test_broadcast_not_requiring_grad(self):
+        variables = [
+            Variable(torch.randn(1, 2).cuda(), requires_grad=True),
+            Variable(torch.randn(1, 2).cuda(), requires_grad=False),
+            Variable(torch.randn(1, 2).cuda(), requires_grad=False),
+            Variable(torch.randn(1, 2).cuda(), requires_grad=True),
+            Variable(torch.randn(1, 2).cuda(), requires_grad=True),
+        ]
+        broadcasted_variables = Broadcast.apply((0, 1), *variables)
+        for output_idx, broadcasted_var in enumerate(broadcasted_variables):
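+            # Broadcast.apply returns a copy of every input for each target device,
+            # so the flat output index maps back to its source input via modulo.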
+            input_var = variables[output_idx % len(variables)]
+            self.assertEqual(input_var.requires_grad, broadcasted_var.requires_grad)
+
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_replicate(self):
         module = nn.Linear(10, 5).float().cuda()