
Commit 7fcaf3b

tonybeltramelli authored and ezyang committed
Update torch.nn.init and torch.nn.utils.clip_grad (#6173)
Introducing two updates.

1. Add a parameter for the non-linearity to the He initialization scheme in torch.nn.init.
   Problem solved: the function calculate_gain can take an argument to specify the type of non-linearity used, but it was not possible to pass this argument directly to the He / Kaiming weight initialization functions.

2. Add a utility to clip gradients by value in torch.nn.utils.clip_grad.
   Problem solved: deep learning libraries typically give users easy access to functions for clipping gradients both by norm and by a fixed value, but clip_grad.py only had a function to clip the gradient norm.

* add param to He initialization scheme in torch.nn.init
* add util to clip gradient value in torch/nn/utils/clip_grad.py
* update doc in torch.nn.utils.clip_grad
* update and add test for torch.nn.utils.clip_grad
* update function signature in torch.nn.utils.clip_grad to match suffix_ convention
* ensure backward compatibility in torch.nn.utils.clip_grad
* remove DeprecationWarning in torch.nn.utils.clip_grad
* extend test and implementation of torch.nn.utils.clip_grad
* update test and implementation of torch.nn.utils.clip_grad
1 parent 1e34493 commit 7fcaf3b
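
Taken together, the diffs below add two user-facing capabilities. A minimal usage sketch, assuming only the APIs introduced by this commit (the model, input shapes, and clipping threshold here are illustrative, not part of the change):

import torch
import torch.nn as nn
from torch.nn import init
from torch.nn.utils import clip_grad_value_

# Illustrative two-layer model; any module with parameters behaves the same.
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))

# 1. He/Kaiming init now accepts the non-linearity that follows the layer,
#    instead of always assuming 'leaky_relu'.
init.kaiming_normal_(model[0].weight, mode='fan_in', nonlinearity='relu')

# 2. Gradients can now be clipped to a fixed value, not only by norm.
model(torch.randn(4, 10)).sum().backward()
clip_grad_value_(model.parameters(), clip_value=0.5)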

File tree

5 files changed: +67 -14 lines changed


docs/source/nn.rst

Lines changed: 7 additions & 2 deletions
@@ -700,10 +700,15 @@ DataParallel layers (multi-GPU, distributed)
 Utilities
 ---------
 
-:hidden:`clip_grad_norm`
+:hidden:`clip_grad_norm_`
 ~~~~~~~~~~~~~~~~~~~~~~~~
 
-.. autofunction:: torch.nn.utils.clip_grad_norm
+.. autofunction:: torch.nn.utils.clip_grad_norm_
+
+:hidden:`clip_grad_value_`
+~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: torch.nn.utils.clip_grad_value_
 
 :hidden:`weight_norm`
 ~~~~~~~~~~~~~~~~~~~~~

test/test_nn.py

Lines changed: 17 additions & 3 deletions
@@ -20,7 +20,7 @@
 import torch.nn.init as init
 import torch.nn.utils.rnn as rnn_utils
 import torch.legacy.nn as legacy
-from torch.nn.utils import clip_grad_norm
+from torch.nn.utils import clip_grad_norm_, clip_grad_value_
 from torch.nn.utils import parameters_to_vector, vector_to_parameters
 from torch.autograd import Variable, gradcheck
 from torch.autograd.gradcheck import gradgradcheck
@@ -1238,7 +1238,7 @@ def compare_scaling(grads):
             for p, g in zip(l.parameters(), grads):
                 p._grad = Variable(g.clone().view_as(p.data))
             norm_before = compute_norm(norm_type)
-            norm = clip_grad_norm(l.parameters(), max_norm, norm_type=norm_type)
+            norm = clip_grad_norm_(l.parameters(), max_norm, norm_type=norm_type)
             norm_after = compute_norm(norm_type)
             self.assertEqual(norm, norm_before)
             self.assertEqual(norm_after, max_norm)
@@ -1251,14 +1251,28 @@ def compare_scaling(grads):
             for p, g in zip(l.parameters(), grads):
                 p.grad.data.copy_(g)
             norm_before = compute_norm(norm_type)
-            norm = clip_grad_norm(l.parameters(), max_norm, norm_type=norm_type)
+            norm = clip_grad_norm_(l.parameters(), max_norm, norm_type=norm_type)
             norm_after = compute_norm(norm_type)
             self.assertEqual(norm, norm_before)
             self.assertEqual(norm_before, norm_after)
             self.assertLessEqual(norm_after, max_norm)
             scale = compare_scaling(grads)
             self.assertEqual(scale, 1)
 
+    def test_clip_grad_value(self):
+        l = nn.Linear(10, 10)
+        clip_value = 2.5
+
+        grad_w, grad_b = torch.arange(-50, 50).view(10, 10).div(5), torch.ones(10).mul(2)
+        for grad_list in [[grad_w, grad_b], [grad_w, None]]:
+            for p, g in zip(l.parameters(), grad_list):
+                p._grad = Variable(g.clone().view_as(p.data)) if g is not None else g
+
+            clip_grad_value_(l.parameters(), clip_value)
+            for p in filter(lambda p: p.grad is not None, l.parameters()):
+                self.assertLessEqual(p.grad.data.max(), clip_value)
+                self.assertGreaterEqual(p.grad.data.min(), -clip_value)
+
     def test_parameters_to_vector(self):
         conv1 = nn.Conv2d(3, 10, 5)
         fc1 = nn.Linear(10, 20)
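
The [grad_w, None] case in the new test exercises parameters whose .grad was never populated. A standalone sketch of that behavior outside the test harness (the Linear shape and clip value are arbitrary; setting .grad directly is only for illustration):

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_value_

lin = nn.Linear(10, 10)

# Populate only the weight gradient; leave lin.bias.grad as None.
lin.weight.grad = torch.arange(-50., 50.).view(10, 10).div(5)

# Parameters without gradients are skipped rather than raising an error.
clip_grad_value_(lin.parameters(), clip_value=2.5)
assert lin.weight.grad.abs().max() <= 2.5
assert lin.bias.grad is None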

torch/nn/init.py

Lines changed: 10 additions & 6 deletions
@@ -230,7 +230,7 @@ def _calculate_correct_fan(tensor, mode):
     return fan_in if mode == 'fan_in' else fan_out
 
 
-def kaiming_uniform_(tensor, a=0, mode='fan_in'):
+def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
     r"""Fills the input `Tensor` with values according to the method
     described in "Delving deep into rectifiers: Surpassing human-level
     performance on ImageNet classification" - He, K. et al. (2015), using a
@@ -250,20 +250,22 @@ def kaiming_uniform_(tensor, a=0, mode='fan_in'):
             preserves the magnitude of the variance of the weights in the
             forward pass. Choosing `fan_out` preserves the magnitudes in the
             backwards pass.
+        nonlinearity: the non-linear function (`nn.functional` name),
+            recommended to use only with 'relu' or 'leaky_relu' (default).
 
     Examples:
         >>> w = torch.Tensor(3, 5)
-        >>> nn.init.kaiming_uniform_(w, mode='fan_in')
+        >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
     """
     fan = _calculate_correct_fan(tensor, mode)
-    gain = calculate_gain('leaky_relu', a)
+    gain = calculate_gain(nonlinearity, a)
     std = gain / math.sqrt(fan)
     bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
     with torch.no_grad():
         return tensor.uniform_(-bound, bound)
 
 
-def kaiming_normal_(tensor, a=0, mode='fan_in'):
+def kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
     r"""Fills the input `Tensor` with values according to the method
     described in "Delving deep into rectifiers: Surpassing human-level
     performance on ImageNet classification" - He, K. et al. (2015), using a
@@ -283,13 +285,15 @@ def kaiming_normal_(tensor, a=0, mode='fan_in'):
             preserves the magnitude of the variance of the weights in the
             forward pass. Choosing `fan_out` preserves the magnitudes in the
             backwards pass.
+        nonlinearity: the non-linear function (`nn.functional` name),
+            recommended to use only with 'relu' or 'leaky_relu' (default).
 
     Examples:
         >>> w = torch.Tensor(3, 5)
-        >>> nn.init.kaiming_normal_(w, mode='fan_out')
+        >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
     """
     fan = _calculate_correct_fan(tensor, mode)
-    gain = calculate_gain('leaky_relu', a)
+    gain = calculate_gain(nonlinearity, a)
     std = gain / math.sqrt(fan)
     with torch.no_grad():
         return tensor.normal_(0, std)
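
Since the gain now depends on the nonlinearity argument, the resulting std / uniform bound changes accordingly. A quick numeric sketch using the formulas visible in the diff above (fan_in = 5 for a 3x5 weight; printed values are rounded):

import math
from torch.nn import init

# calculate_gain('relu') = sqrt(2); calculate_gain('leaky_relu', a) = sqrt(2 / (1 + a**2))
print(init.calculate_gain('relu'))             # ~1.4142
print(init.calculate_gain('leaky_relu', 0.2))  # ~1.3868

# kaiming_uniform_ then samples from U(-bound, bound) with
# bound = sqrt(3) * gain / sqrt(fan), here fan = fan_in = 5.
bound = math.sqrt(3.0) * init.calculate_gain('relu') / math.sqrt(5)
print(round(bound, 4))                         # 1.0954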

torch/nn/utils/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
 from . import rnn
-from .clip_grad import clip_grad_norm
+from .clip_grad import clip_grad_norm, clip_grad_norm_, clip_grad_value_
 from .weight_norm import weight_norm, remove_weight_norm
 from .convert_parameters import parameters_to_vector, vector_to_parameters

torch/nn/utils/clip_grad.py

Lines changed: 32 additions & 2 deletions
@@ -1,12 +1,14 @@
+import warnings
 
-def clip_grad_norm(parameters, max_norm, norm_type=2):
+
+def clip_grad_norm_(parameters, max_norm, norm_type=2):
     r"""Clips gradient norm of an iterable of parameters.
 
     The norm is computed over all gradients together, as if they were
     concatenated into a single vector. Gradients are modified in-place.
 
     Arguments:
-        parameters (Iterable[Variable]): an iterable of Variables that will have
+        parameters (Iterable[Tensor]): an iterable of Tensors that will have
             gradients normalized
         max_norm (float or int): max norm of the gradients
         norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
@@ -31,3 +33,31 @@ def clip_grad_norm(parameters, max_norm, norm_type=2):
         for p in parameters:
             p.grad.data.mul_(clip_coef)
     return total_norm
+
+
+def clip_grad_norm(parameters, max_norm, norm_type=2):
+    r"""Clips gradient norm of an iterable of parameters.
+
+    .. warning::
+        This method is now deprecated in favor of
+        :func:`torch.nn.utils.clip_grad_norm_`.
+    """
+    warnings.warn("torch.nn.utils.clip_grad_norm is now deprecated in favor "
+                  "of torch.nn.utils.clip_grad_norm_.", stacklevel=2)
+    return clip_grad_norm_(parameters, max_norm, norm_type)
+
+
+def clip_grad_value_(parameters, clip_value):
+    r"""Clips gradient of an iterable of parameters at specified value.
+
+    Gradients are modified in-place.
+
+    Arguments:
+        parameters (Iterable[Tensor]): an iterable of Tensors that will have
+            gradients normalized
+        clip_value (float or int): maximum allowed value of the gradients
+            The gradients are clipped in the range [-clip_value, clip_value]
+    """
+    clip_value = float(clip_value)
+    for p in filter(lambda p: p.grad is not None, parameters):
+        p.grad.data.clamp_(min=-clip_value, max=clip_value)
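
A short usage sketch of the reworked module, assuming only what the diff above defines (the module, loss, and thresholds are illustrative); note that calling the old clip_grad_norm name still works but emits a warning and forwards to clip_grad_norm_:

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_, clip_grad_value_

# Illustrative module and loss; only the clipping calls matter here.
lin = nn.Linear(10, 10)
lin(torch.randn(8, 10)).pow(2).sum().backward()

# Norm-based clipping (renamed with a trailing underscore to mark it as in-place).
total_norm = clip_grad_norm_(lin.parameters(), max_norm=2, norm_type=2)

# Value-based clipping: every gradient entry ends up in [-0.25, 0.25].
clip_grad_value_(lin.parameters(), clip_value=0.25)
assert all(p.grad.data.abs().max() <= 0.25 for p in lin.parameters())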
