42 changes: 39 additions & 3 deletions test/test_nn.py
@@ -3240,7 +3240,7 @@ def verify_reduction_scalars(input, reduction, output):
@unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
"Scipy v1.0 and/or numpy not found")
def test_multihead_attention(self):
-def _scaled_dot_attn_ref(Q, K, V, dims, unseen_mask=None, src_lengths=None,
+def _scaled_dot_attn_ref(Q, K, V, dims, unseen_mask=None, src_lengths=None,
attn_mask=None, add_zero_attn=False):
""" Numpy-based reference implementation of scaled dot attention
for testing"""
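
As an aside (not part of the diff): the helper described by this docstring computes standard scaled dot-product attention in NumPy. A minimal sketch of that computation, ignoring the masking and add_zero_attn options the real helper handles, assuming Q, K, V are arrays whose last two dimensions are (sequence length, head dimension):

import numpy as np

def scaled_dot_attn_sketch(Q, K, V):
    # attention weights = softmax(Q @ K^T / sqrt(d_k)), taken along the last axis
    d_k = Q.shape[-1]
    scores = np.matmul(Q, np.swapaxes(K, -1, -2)) / np.sqrt(d_k)
    scores = scores - scores.max(axis=-1, keepdims=True)  # for numerical stability
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    # weighted sum of the values, plus the weights themselves for inspection in tests
    return np.matmul(weights, V), weights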
@@ -3374,7 +3374,7 @@ def _multihead_attn_test_helper(add_key_padding_mask, add_bias_kv=False, add_zer
decoder_state_tensor = torch.from_numpy(decoder_state).double()
source_hid_tensor = torch.from_numpy(K).double().transpose(0, 1)

-multihead_attn_module = MultiheadAttention(d_model, nheads,
+multihead_attn_module = MultiheadAttention(d_model, nheads,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn)

@@ -3404,7 +3404,7 @@ def _multihead_attn_test_helper(add_key_padding_mask, add_bias_kv=False, add_zer
multihead_attn_module.bias_k, multihead_attn_module.bias_v,
multihead_attn_module.add_zero_attn, multihead_attn_module.dropout,
multihead_attn_module.out_proj.weight, multihead_attn_module.out_proj.bias,
-multihead_attn_module.training, src_len_mask, True, attn_mask_tensor)
+multihead_attn_module.training, src_len_mask, True, attn_mask_tensor)

result = result.squeeze(0).detach().numpy()

@@ -3846,6 +3846,42 @@ def fn(t):

torch.autograd.gradcheck(fn, (m.t_rg,))

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    @skipIfRocm
    def test_data_parallel_rnn(self):

        class TestModule(torch.nn.Module):

            def __init__(self):
                super(TestModule, self).__init__()
                self.rnn = torch.nn.LSTM(300, 1024, 1, batch_first=True, bidirectional=True)

            def forward(self, x):
                self.rnn.flatten_parameters()
                return self.rnn(x)

        def step(model):
            opt = torch.optim.SGD(model.parameters(), lr=0.1)
            input = torch.ones(4, 4, 300).to(0)
            output = model(input)
            loss = F.mse_loss(output[0], torch.zeros_like(output[0]))
            loss.backward()
            opt.step()

        with torch.no_grad():
            model = TestModule().to(0)
            model_dp = torch.nn.DataParallel(deepcopy(model))

            # make sure DP does not crash when grad is disabled.
            # See #21108
            model_dp(torch.rand(2, 4, 300).to(0))

        step(model)
        step(model_dp)

        for p1, p2 in zip(model.parameters(), model_dp.parameters()):
            p1.allclose(p2)

    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
    def test_parallel_apply(self):
        l1 = nn.Linear(10, 5).to("cuda:0", torch.float)
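Outside the diff itself, the scenario the new test_data_parallel_rnn exercises can be reduced to the following standalone sketch: a DataParallel forward pass issued while autograd is globally disabled. This is an illustrative sketch only; it assumes at least two visible CUDA devices, and RNNWrapper is a hypothetical name mirroring TestModule above.

import torch
import torch.nn as nn


class RNNWrapper(nn.Module):  # hypothetical name, mirrors TestModule in the test above
    def __init__(self):
        super(RNNWrapper, self).__init__()
        self.rnn = nn.LSTM(300, 1024, 1, batch_first=True, bidirectional=True)

    def forward(self, x):
        # flatten_parameters() on the replicas is part of the path the test covers
        self.rnn.flatten_parameters()
        return self.rnn(x)


model_dp = nn.DataParallel(RNNWrapper().to(0))

with torch.no_grad():
    # Before this change, DataParallel could crash here when grad was disabled
    # (see #21108); with replicate(..., not torch.is_grad_enabled()) it runs cleanly.
    output, _ = model_dp(torch.rand(2, 4, 300).to(0))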
2 changes: 1 addition & 1 deletion torch/nn/parallel/data_parallel.py
@@ -153,7 +153,7 @@ def forward(self, *inputs, **kwargs):
return self.gather(outputs, self.output_device)

def replicate(self, module, device_ids):
-return replicate(module, device_ids)
+return replicate(module, device_ids, not torch.is_grad_enabled())

def scatter(self, inputs, kwargs, device_ids):
return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
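The one-line change above passes the current autograd state through to replicate(); as I read it, a True value asks for detached replicas, which is what lets the no_grad() forward pass in the new test run without crashing. For reference, a quick sketch of how not torch.is_grad_enabled() evaluates in the two paths DataParallel.forward can take:

import torch

# Default path: gradients enabled, the extra argument is False, so
# replicate() should behave as it did before this change.
print(not torch.is_grad_enabled())      # False

# Inference path: inside torch.no_grad() the argument becomes True,
# which is the case the new test_data_parallel_rnn covers (see #21108).
with torch.no_grad():
    print(not torch.is_grad_enabled())  # True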