
Commit f62a006

mrshenli authored and facebook-github-bot committed
Retry Fix Python DataParallel RNN in no_grad mode (#21262)
Summary: Retry of #21197. The previous attempt failed because it used some Python3-only syntax. ezyang: do we still have multi-GPU py2 tests? I am curious why the CI tests did not catch this error.

Pull Request resolved: #21262
Differential Revision: D15598941
Pulled By: mrshenli
fbshipit-source-id: 95f416589448c443685d6d236d205b011998a715
1 parent 0c6efbd commit f62a006
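
For context, a minimal sketch (not part of the commit) of the scenario this change addresses, assuming at least one CUDA device is available. Before this fix, running a DataParallel-wrapped RNN module with autograd disabled could fail during replication (see #21108); with the change, the forward pass completes normally:

import torch
import torch.nn as nn

class RNNModule(nn.Module):
    def __init__(self):
        super(RNNModule, self).__init__()
        self.rnn = nn.LSTM(300, 1024, 1, batch_first=True, bidirectional=True)

    def forward(self, x):
        # Compact the LSTM weights into a single contiguous chunk.
        self.rnn.flatten_parameters()
        return self.rnn(x)

model = nn.DataParallel(RNNModule().to(0))

with torch.no_grad():
    # Inference-only forward pass; this is the call path that used to crash.
    output, _ = model(torch.rand(2, 4, 300).to(0))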

File tree: 2 files changed, +40 -4 lines changed

test/test_nn.py

Lines changed: 39 additions & 3 deletions
@@ -3240,7 +3240,7 @@ def verify_reduction_scalars(input, reduction, output):
     @unittest.skipIf((not TEST_NUMPY) or (not TEST_SCIPY) or (scipy.__version__ < '1.0.0'),
                      "Scipy v1.0 and/or numpy not found")
     def test_multihead_attention(self):
-        def _scaled_dot_attn_ref(Q, K, V, dims, unseen_mask=None, src_lengths=None,
+        def _scaled_dot_attn_ref(Q, K, V, dims, unseen_mask=None, src_lengths=None,
                                  attn_mask=None, add_zero_attn=False):
             """ Numpy-based reference implementation of scaled dot attention
             for testing"""
@@ -3374,7 +3374,7 @@ def _multihead_attn_test_helper(add_key_padding_mask, add_bias_kv=False, add_zer
         decoder_state_tensor = torch.from_numpy(decoder_state).double()
         source_hid_tensor = torch.from_numpy(K).double().transpose(0, 1)

-        multihead_attn_module = MultiheadAttention(d_model, nheads,
+        multihead_attn_module = MultiheadAttention(d_model, nheads,
                                                    add_bias_kv=add_bias_kv,
                                                    add_zero_attn=add_zero_attn)

@@ -3404,7 +3404,7 @@ def _multihead_attn_test_helper(add_key_padding_mask, add_bias_kv=False, add_zer
             multihead_attn_module.bias_k, multihead_attn_module.bias_v,
             multihead_attn_module.add_zero_attn, multihead_attn_module.dropout,
             multihead_attn_module.out_proj.weight, multihead_attn_module.out_proj.bias,
-            multihead_attn_module.training, src_len_mask, True, attn_mask_tensor)
+            multihead_attn_module.training, src_len_mask, True, attn_mask_tensor)

         result = result.squeeze(0).detach().numpy()

@@ -3846,6 +3846,42 @@ def fn(t):

         torch.autograd.gradcheck(fn, (m.t_rg,))

+    @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
+    @skipIfRocm
+    def test_data_parallel_rnn(self):
+
+        class TestModule(torch.nn.Module):
+
+            def __init__(self):
+                super(TestModule, self).__init__()
+                self.rnn = torch.nn.LSTM(300, 1024, 1, batch_first=True, bidirectional=True)
+
+            def forward(self, x):
+                self.rnn.flatten_parameters()
+                return self.rnn(x)
+
+        def step(model):
+            opt = torch.optim.SGD(model.parameters(), lr=0.1)
+            input = torch.ones(4, 4, 300).to(0)
+            output = model(input)
+            loss = F.mse_loss(output[0], torch.zeros_like(output[0]))
+            loss.backward()
+            opt.step()
+
+        with torch.no_grad():
+            model = TestModule().to(0)
+            model_dp = torch.nn.DataParallel(deepcopy(model))
+
+            # make sure DP does not crash when grad is disabled.
+            # See #21108
+            model_dp(torch.rand(2, 4, 300).to(0))
+
+        step(model)
+        step(model_dp)
+
+        for p1, p2 in zip(model.parameters(), model_dp.parameters()):
+            p1.allclose(p2)
+
     @unittest.skipIf(not TEST_MULTIGPU, "multi-GPU not supported")
     def test_parallel_apply(self):
         l1 = nn.Linear(10, 5).to("cuda:0", torch.float)

torch/nn/parallel/data_parallel.py

Lines changed: 1 addition & 1 deletion
@@ -153,7 +153,7 @@ def forward(self, *inputs, **kwargs):
         return self.gather(outputs, self.output_device)

     def replicate(self, module, device_ids):
-        return replicate(module, device_ids)
+        return replicate(module, device_ids, not torch.is_grad_enabled())

     def scatter(self, inputs, kwargs, device_ids):
         return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
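
In effect, the overridden replicate() now tells the underlying torch.nn.parallel.replicate call whether autograd is currently enabled, so replicas created inside a no_grad block do not participate in autograd (the third positional argument is assumed here to be a detach-style flag). A quick sketch of the value that gets forwarded:

import torch

# The changed line passes `not torch.is_grad_enabled()` through to replicate().
with torch.no_grad():
    print(not torch.is_grad_enabled())  # True  -> replicas are created detached
print(not torch.is_grad_enabled())      # False -> replicas stay attached to autograd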
