
Commit c238ee3

apaszke authored and soumith committed
Fix issues with lazy grad initialization (#912)
1 parent f17cfe4 commit c238ee3
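
This commit addresses "lazy grad initialization": in the Variable-era API used here, a parameter's .grad attribute is not allocated up front but is only materialized by the first backward pass, so code that touches p.grad must tolerate it being None. The following minimal sketch is not part of the commit; it only assumes the torch, torch.nn, and torch.autograd.Variable APIs that appear in the diffs below, and illustrates the behavior the changes are built around.

    import torch
    import torch.nn as nn
    from torch.autograd import Variable

    module = nn.Linear(5, 5)

    # Gradients are initialized lazily: before any backward pass the
    # parameter's .grad attribute is still None.
    assert module.weight.grad is None

    # A backward pass materializes the gradient.
    x = Variable(torch.randn(2, 5), requires_grad=True)
    module(x).sum().backward()
    assert module.weight.grad is not None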

File tree

3 files changed: +17 -5 lines changed

test/test_nn.py
torch/nn/modules/module.py
torch/nn/utils/clip_grad.py

test/test_nn.py

Lines changed: 15 additions & 3 deletions
@@ -357,19 +357,31 @@ def bw_hook(module, grad_input, grad_output):
         self.assertEqual(input.grad.data, expected_grad)
 
     def test_zero_grad(self):
+        i = Variable(torch.randn(2, 5), requires_grad=True)
         module = nn.Linear(5, 5)
         for p in module.parameters():
             p.requires_grad = False
         module.zero_grad()
 
         module.weight.requires_grad = True
-        module.weight._grad = Variable(module.weight.data.clone().fill_(1))
+        module.zero_grad()
+        self.assertIsNone(module.weight.grad)  # uninitialized grad
+
+        module(i).sum().backward()
+        self.assertIsNotNone(module.weight.grad)
+        self.assertGreater(module.weight.grad.data.abs().sum(), 0)
         module.zero_grad()
         self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
 
         module.bias.requires_grad = True
-        module.weight._grad = Variable(module.weight.data.clone().fill_(1))
-        module.bias._grad = Variable(module.bias.data.clone().fill_(1))
+        module.zero_grad()
+        self.assertIsNotNone(module.weight.grad)
+        self.assertIsNone(module.bias.grad)
+        module(i).sum().backward()
+        self.assertIsNotNone(module.weight.grad)
+        self.assertIsNotNone(module.bias.grad)
+        self.assertGreater(module.weight.grad.data.abs().sum(), 0)
+        self.assertGreater(module.bias.grad.data.abs().sum(), 0)
         module.zero_grad()
         self.assertEqual(module.weight.grad.data, module.weight.data.clone().zero_())
         self.assertEqual(module.bias.grad.data, module.bias.data.clone().zero_())

torch/nn/modules/module.py

Lines changed: 1 addition & 1 deletion
@@ -377,7 +377,7 @@ def eval(self):
     def zero_grad(self):
         """Sets gradients of all model parameters to zero."""
         for p in self.parameters():
-            if p.requires_grad:
+            if p.grad is not None:
                 p.grad.data.zero_()
 
     def share_memory(self):
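
With lazy initialization, a parameter can have requires_grad=True while its .grad is still None (no backward has run yet), so the old check on p.requires_grad would reach p.grad.data and fail. Checking p.grad is not None skips gradients that were never materialized. A hedged usage sketch, assuming only the nn.Linear module already used in the tests above:

    import torch
    import torch.nn as nn

    model = nn.Linear(5, 5)
    model.zero_grad()         # safe even though no backward pass has run yet
    print(model.weight.grad)  # None -- the gradient was never materialized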

torch/nn/utils/clip_grad.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ def clip_grad_norm(parameters, max_norm, norm_type=2):
         max_norm (float or int): max norm of the gradients
         norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
     """
-    parameters = list(parameters)
+    parameters = list(filter(lambda p: p.grad is not None, parameters))
     max_norm = float(max_norm)
     norm_type = float(norm_type)
     if norm_type == float('inf'):
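
The same consideration applies here: clip_grad_norm previously assumed every parameter in the iterable carried a gradient, which no longer holds when some parameters never took part in a backward pass. Filtering out parameters whose .grad is None computes the norm over only the gradients that actually exist. A hedged usage sketch, assuming clip_grad_norm is importable from torch.nn.utils as in this source tree:

    import torch
    import torch.nn as nn
    from torch.autograd import Variable
    from torch.nn.utils import clip_grad_norm

    used = nn.Linear(5, 5)
    unused = nn.Linear(5, 5)  # never backpropagated, so its .grad stays None

    used(Variable(torch.randn(2, 5))).sum().backward()

    # Parameters whose .grad is None are now skipped instead of raising.
    clip_grad_norm(list(used.parameters()) + list(unused.parameters()), max_norm=1.0)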
