Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1281,6 +1281,16 @@ def compare_scaling(grads):
scale = compare_scaling(grads)
self.assertEqual(scale, 1)

# Should accept a single Tensor as input
p1, p2 = torch.randn(10, 10), torch.randn(10, 10)
g = torch.arange(1., 101).view(10, 10)
p1._grad = g.clone()
p2._grad = g.clone()
for norm_type in [0.5, 1.5, 2, 4, 'inf']:
clip_grad_norm_(p1, max_norm, norm_type=norm_type)
clip_grad_norm_([p2], max_norm, norm_type=norm_type)
self.assertEqual(p1.grad, p2.grad)

def test_clip_grad_value(self):
l = nn.Linear(10, 10)
clip_value = 2.5
Expand All @@ -1295,6 +1305,15 @@ def test_clip_grad_value(self):
self.assertLessEqual(p.grad.data.max(), clip_value)
self.assertGreaterEqual(p.grad.data.min(), -clip_value)

# Should accept a single Tensor as input
p1, p2 = torch.randn(10, 10), torch.randn(10, 10)
g = torch.arange(-50., 50).view(10, 10).div_(5)
p1._grad = g.clone()
p2._grad = g.clone()
clip_grad_value_(p1, clip_value)
clip_grad_value_([p2], clip_value)
self.assertEqual(p1.grad, p2.grad)

def test_parameters_to_vector(self):
conv1 = nn.Conv2d(3, 10, 5)
fc1 = nn.Linear(10, 20)
Expand Down
13 changes: 9 additions & 4 deletions torch/nn/utils/clip_grad.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import warnings
import torch


def clip_grad_norm_(parameters, max_norm, norm_type=2):
Expand All @@ -8,15 +9,17 @@ def clip_grad_norm_(parameters, max_norm, norm_type=2):
concatenated into a single vector. Gradients are modified in-place.

Arguments:
parameters (Iterable[Tensor]): an iterable of Tensors that will have
gradients normalized
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.

Returns:
Total norm of the parameters (viewed as a single vector).
"""
if torch.is_tensor(parameters):

This comment was marked as off-topic.

This comment was marked as off-topic.

parameters = [parameters]
parameters = list(filter(lambda p: p.grad is not None, parameters))
max_norm = float(max_norm)
norm_type = float(norm_type)
Expand Down Expand Up @@ -53,11 +56,13 @@ def clip_grad_value_(parameters, clip_value):
Gradients are modified in-place.

Arguments:
parameters (Iterable[Tensor]): an iterable of Tensors that will have
gradients normalized
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
clip_value (float or int): maximum allowed value of the gradients
The gradients are clipped in the range [-clip_value, clip_value]
"""
if torch.is_tensor(parameters):
parameters = [parameters]
clip_value = float(clip_value)
for p in filter(lambda p: p.grad is not None, parameters):
p.grad.data.clamp_(min=-clip_value, max=clip_value)