@@ -3240,7 +3240,7 @@ def verify_reduction_scalars(input, reduction, output):
     @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
                      "Scipy v1.0 and/or numpy not found")
     def test_multihead_attention(self):
-        def _scaled_dot_attn_ref(Q, K, V, dims, unseen_mask=None, src_lengths=None, 
+        def _scaled_dot_attn_ref(Q, K, V, dims, unseen_mask=None, src_lengths=None,
                                  attn_mask=None, add_zero_attn=False):
             """ Numpy-based reference implementation of scaled dot attention
             for testing"""
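For reference, the computation that `_scaled_dot_attn_ref` checks against is plain scaled dot-product attention. Below is a minimal NumPy sketch of that core, with an illustrative function name and without the `unseen_mask`/`src_lengths`/`attn_mask`/`add_zero_attn` handling the real helper performs:

    import numpy as np

    def scaled_dot_attn(Q, K, V):
        # Q: (..., tgt_len, d_k); K: (..., src_len, d_k); V: (..., src_len, d_v)
        d_k = Q.shape[-1]
        scores = np.matmul(Q, np.swapaxes(K, -1, -2)) / np.sqrt(d_k)
        # numerically stable softmax over the source dimension
        scores = scores - scores.max(axis=-1, keepdims=True)
        weights = np.exp(scores)
        weights = weights / weights.sum(axis=-1, keepdims=True)
        return np.matmul(weights, V), weights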
@@ -3374,7 +3374,7 @@ def _multihead_attn_test_helper(add_key_padding_mask, add_bias_kv=False, add_zer
             decoder_state_tensor = torch.from_numpy(decoder_state).double()
             source_hid_tensor = torch.from_numpy(K).double().transpose(0, 1)

-            multihead_attn_module = MultiheadAttention(d_model, nheads, 
+            multihead_attn_module = MultiheadAttention(d_model, nheads,
                                                        add_bias_kv=add_bias_kv,
                                                        add_zero_attn=add_zero_attn)

@@ -3404,7 +3404,7 @@ def _multihead_attn_test_helper(add_key_padding_mask, add_bias_kv=False, add_zer
                 multihead_attn_module.bias_k, multihead_attn_module.bias_v,
                 multihead_attn_module.add_zero_attn, multihead_attn_module.dropout,
                 multihead_attn_module.out_proj.weight, multihead_attn_module.out_proj.bias,
-                multihead_attn_module.training, src_len_mask, True, attn_mask_tensor) 
+                multihead_attn_module.training, src_len_mask, True, attn_mask_tensor)

             result = result.squeeze(0).detach().numpy()

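The two whitespace-only hunks above sit inside `_multihead_attn_test_helper`, which builds an `nn.MultiheadAttention` module and then makes the call whose trailing arguments appear above, passing the module's own weights and flags explicitly. As a hedged sketch of the module-level API being exercised, using made-up sizes and the default `(seq_len, batch, embed_dim)` layout rather than the test's actual setup:

    import torch
    import torch.nn as nn

    embed_dim, nheads = 64, 4                    # illustrative sizes
    attn = nn.MultiheadAttention(embed_dim, nheads,
                                 add_bias_kv=False, add_zero_attn=False)

    query = torch.rand(1, 2, embed_dim)          # (tgt_len, batch, embed_dim)
    key = value = torch.rand(5, 2, embed_dim)    # (src_len, batch, embed_dim)

    attn_output, attn_weights = attn(query, key, value)
    print(attn_output.shape)   # torch.Size([1, 2, 64])
    print(attn_weights.shape)  # torch.Size([2, 1, 5]), averaged over heads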
@@ -3846,6 +3846,42 @@ def fn(t):

         torch.autograd.gradcheck(fn, (m.t_rg,))

+    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
+    @skipIfRocm
+    def test_data_parallel_rnn(self):
+
+        class TestModule(torch.nn.Module):
+
+            def __init__(self):
+                super(TestModule, self).__init__()
+                self.rnn = torch.nn.LSTM(300, 1024, 1, batch_first=True, bidirectional=True)
+
+            def forward(self, x):
+                self.rnn.flatten_parameters()
+                return self.rnn(x)
+
+        def step(model):
+            opt = torch.optim.SGD(model.parameters(), lr=0.1)
+            input = torch.ones(4, 4, 300).to(0)
+            output = model(input)
+            loss = F.mse_loss(output[0], torch.zeros_like(output[0]))
+            loss.backward()
+            opt.step()
+
+        with torch.no_grad():
+            model = TestModule().to(0)
+            model_dp = torch.nn.DataParallel(deepcopy(model))
+
+        # make sure DP does not crash when grad is disabled.
+        # See #21108
+        model_dp(torch.rand(2, 4, 300).to(0))
+
+        step(model)
+        step(model_dp)
+
+        for p1, p2 in zip(model.parameters(), model_dp.parameters()):
+            p1.allclose(p2)
+
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_parallel_apply(self):
         l1 = nn.Linear(10, 5).to("cuda:0", torch.float)
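The new `test_data_parallel_rnn` pins down the pattern of calling `flatten_parameters()` inside `forward()` under `DataParallel`, including a forward pass with grad disabled that used to crash (the #21108 reference above). A minimal stand-alone sketch of that pattern, assuming at least one CUDA device and using illustrative names and sizes rather than the test's own:

    import torch
    import torch.nn as nn

    class RNNWrapper(nn.Module):
        def __init__(self):
            super(RNNWrapper, self).__init__()
            self.rnn = nn.LSTM(16, 32, 1, batch_first=True, bidirectional=True)

        def forward(self, x):
            # re-compact the cuDNN weight buffers on whichever replica runs this
            self.rnn.flatten_parameters()
            return self.rnn(x)

    if torch.cuda.is_available():
        model = nn.DataParallel(RNNWrapper().to("cuda:0"))
        with torch.no_grad():
            # forward under no_grad should not crash
            out = model(torch.rand(2, 4, 16, device="cuda:0"))[0]
        print(out.shape)  # torch.Size([2, 4, 64]); bidirectional doubles hidden_size

Calling `flatten_parameters()` inside `forward()` rather than `__init__()` matters here because `DataParallel` replicates the module onto each device for every call, and the replicas' RNN weights are generally not contiguous until flattened.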