Commit e4d83d5

milesial authored and pytorchmergebot committed
Foreach gradient clipping (#91846)
Faster gradient clipping using the foreach functions:

```
[------------------------ (tensors, scalar) -------------------------]
                                 |  without foreach  |  with foreach  |     apex
1 threads: ------------------------------------------------------------------------
      10 tensors of size 4       |        120.5      |       61.1     |     50.3
     100 tensors of size 4       |        946.2      |      239.5     |    136.3
    1000 tensors of size 4       |       9808.5      |     2151.1     |   1006.9
   10000 tensors of size 4       |      96871.2      |    22637.4     |  10119.1
      10 tensors of size 16      |        121.0      |       64.1     |     52.5
     100 tensors of size 16      |        993.4      |      252.6     |    136.7
    1000 tensors of size 16      |       9427.7      |     2151.2     |   1049.5
   10000 tensors of size 16      |      97437.1      |    22203.1     |  10340.0
      10 tensors of size 256     |        118.9      |       62.3     |     51.5
     100 tensors of size 256     |        955.2      |      243.1     |    134.2
    1000 tensors of size 256     |       9374.9      |     2140.7     |   1009.6
   10000 tensors of size 256     |      95302.5      |    21849.4     |  10215.5
      10 tensors of size 65536   |        118.5      |       62.4     |     51.1
     100 tensors of size 65536   |       1740.7      |      243.3     |    225.3
    1000 tensors of size 65536   |      17364.1      |     2228.7     |   2004.5
   10000 tensors of size 65536   |     177510.1      |    25410.4     |  20678.2
```

Pull Request resolved: #91846
Approved by: https://github.com/janeyx99
1 parent 44b7a0b commit e4d83d5
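
As context for the diffs below, here is a minimal usage sketch of the new `foreach` argument added by this commit. The toy model, sizes, and the `print` are illustrative choices of mine, not part of the commit:

```python
import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_, clip_grad_value_

# Hypothetical model; any module whose parameters carry gradients works the same way.
model = nn.Linear(128, 128)
model(torch.randn(32, 128)).sum().backward()

# foreach=None (the default) uses the fused foreach kernels for CPU/CUDA tensors and
# silently falls back to the per-tensor path for other device types; foreach=True
# forces the fused path and raises a RuntimeError on unsupported device types.
total_norm = clip_grad_norm_(model.parameters(), max_norm=1.0, foreach=True)
clip_grad_value_(model.parameters(), clip_value=0.5, foreach=True)
print(float(total_norm))
```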

File tree

5 files changed: +156 -104 lines changed


.github/ci_commit_pins/xla.txt

Lines changed: 1 addition & 1 deletion
```diff
@@ -1 +1 @@
-3ab8494305810d3c943f670bc6b028514942c7a0
+eac4e547138ab22a9b41c6f96208613fd7dd19d5
```

test/test_nn.py

Lines changed: 90 additions & 82 deletions
```diff
@@ -13,6 +13,7 @@
 from functools import partial
 from collections import OrderedDict
 from tempfile import NamedTemporaryFile
+from unittest import SkipTest
 
 import torch
 
@@ -1654,85 +1655,6 @@ def assign_weight():
         # This should work though
         l2.weight = Parameter(torch.randn(10, 10))
 
-    def test_clip_grad_norm(self):
-        l = nn.Linear(10, 10)
-        max_norm = 2
-
-        def compute_norm(norm_type):
-            norm_type = float(norm_type)
-            if norm_type != inf:
-                total_norm = 0
-                for p in l.parameters():
-                    total_norm += p.grad.data.abs().pow(norm_type).sum()
-                return pow(total_norm, 1. / norm_type)
-            else:
-                return max(p.grad.data.abs().max() for p in l.parameters())
-
-        def compare_scaling(grads):
-            p_scale = [p.grad.data.div(g).view(-1) for p, g in zip(l.parameters(), grads)]
-            scale = torch.cat(p_scale)
-            self.assertEqual(scale.std(), 0)
-            return scale[0]
-
-        grads = torch.arange(1., 101).view(10, 10), torch.ones(10).div(1000)
-        for norm_type in [0.5, 1.5, 2, 4, 'inf']:
-            for p, g in zip(l.parameters(), grads):
-                p._grad = g.clone().view_as(p.data)
-            norm_before = compute_norm(norm_type)
-            norm = clip_grad_norm_(l.parameters(), max_norm, norm_type=norm_type)
-            norm_after = compute_norm(norm_type)
-            self.assertEqual(norm, norm_before)
-            self.assertEqual(norm_after, max_norm)
-            self.assertLessEqual(norm_after, norm_before)
-            compare_scaling(grads)
-
-        # Small gradients should be left unchanged
-        grads = torch.rand(10, 10).div(10000), torch.ones(10).div(500)
-        for norm_type in [0.5, 1.5, 2, 4, 'inf']:
-            for p, g in zip(l.parameters(), grads):
-                p.grad.data.copy_(g)
-            norm_before = compute_norm(norm_type)
-            norm = clip_grad_norm_(l.parameters(), max_norm, norm_type=norm_type)
-            norm_after = compute_norm(norm_type)
-            self.assertEqual(norm, norm_before)
-            self.assertEqual(norm_before, norm_after)
-            self.assertLessEqual(norm_after, max_norm)
-            scale = compare_scaling(grads)
-            self.assertEqual(scale, 1)
-
-        # Should accept a single Tensor as input
-        p1, p2 = torch.randn(10, 10), torch.randn(10, 10)
-        g = torch.arange(1., 101).view(10, 10)
-        p1._grad = g.clone()
-        p2._grad = g.clone()
-        for norm_type in [0.5, 1.5, 2, 4, 'inf']:
-            clip_grad_norm_(p1, max_norm, norm_type=norm_type)
-            clip_grad_norm_([p2], max_norm, norm_type=norm_type)
-            self.assertEqual(p1.grad, p2.grad)
-
-    def test_clip_grad_value(self):
-        l = nn.Linear(10, 10)
-        clip_value = 2.5
-
-        grad_w, grad_b = torch.arange(-50., 50).view(10, 10).div_(5), torch.ones(10).mul_(2)
-        for grad_list in [[grad_w, grad_b], [grad_w, None]]:
-            for p, g in zip(l.parameters(), grad_list):
-                p._grad = g.clone().view_as(p.data) if g is not None else g
-
-            clip_grad_value_(l.parameters(), clip_value)
-            for p in filter(lambda p: p.grad is not None, l.parameters()):
-                self.assertLessEqual(p.grad.data.max(), clip_value)
-                self.assertGreaterEqual(p.grad.data.min(), -clip_value)
-
-        # Should accept a single Tensor as input
-        p1, p2 = torch.randn(10, 10), torch.randn(10, 10)
-        g = torch.arange(-50., 50).view(10, 10).div_(5)
-        p1._grad = g.clone()
-        p2._grad = g.clone()
-        clip_grad_value_(p1, clip_value)
-        clip_grad_value_([p2], clip_value)
-        self.assertEqual(p1.grad, p2.grad)
-
     def test_parameters_to_vector(self):
         conv1 = nn.Conv2d(3, 10, 5)
         fc1 = nn.Linear(10, 20)
@@ -11473,7 +11395,8 @@ def run_test_case(norm_type, error_if_nonfinite, scalar, grad_only_one_elem, pre
 
     @onlyCUDA
    @deviceCountAtLeast(2)
-    def test_clip_grad_norm_multi_device(self, devices):
+    @parametrize_test('foreach', (False, True))
+    def test_clip_grad_norm_multi_device(self, devices, foreach):
         class TestModel(nn.Module):
             def __init__(self):
                 super(TestModel, self).__init__()
@@ -11489,8 +11412,8 @@ def __init__(self):
                 p.grad = torch.ones_like(p)
             for p in ref_model.parameters():
                 p.grad = torch.ones_like(p)
-            norm = clip_grad_norm_(test_model.parameters(), 0.5, norm_type=norm_type)
-            expected = clip_grad_norm_(ref_model.parameters(), 0.5, norm_type=norm_type)
+            norm = clip_grad_norm_(test_model.parameters(), 0.5, norm_type=norm_type, foreach=foreach)
+            expected = clip_grad_norm_(ref_model.parameters(), 0.5, norm_type=norm_type, foreach=foreach)
             self.assertEqual(norm, expected)
             for p, pe in zip(test_model.parameters(), ref_model.parameters()):
                 self.assertEqual(p.grad.to(devices[0]), pe.grad)
@@ -12042,6 +11965,91 @@ def perm_fn(x):
         with cm:
             _test(activation=activation, batch_first=batch_first, training=training)
 
+    @parametrize_test('foreach', (False, True))
+    def test_clip_grad_value(self, foreach, device):
+        if torch.device(device).type == 'xla' and foreach:
+            raise SkipTest('foreach not supported on XLA')
+
+        l = nn.Linear(10, 10).to(device)
+        clip_value = 2.5
+
+        grad_w, grad_b = torch.arange(-50., 50, device=device).view(10, 10).div_(5), torch.ones(10, device=device).mul_(2)
+        for grad_list in [[grad_w, grad_b], [grad_w, None]]:
+            for p, g in zip(l.parameters(), grad_list):
+                p._grad = g.clone().view_as(p.data) if g is not None else g
+
+            clip_grad_value_(l.parameters(), clip_value, foreach=foreach)
+            for p in filter(lambda p: p.grad is not None, l.parameters()):
+                self.assertLessEqual(p.grad.data.max(), clip_value)
+                self.assertGreaterEqual(p.grad.data.min(), -clip_value)
+
+        # Should accept a single Tensor as input
+        p1, p2 = torch.randn(10, 10, device=device), torch.randn(10, 10, device=device)
+        g = torch.arange(-50., 50, device=device).view(10, 10).div_(5)
+        p1._grad = g.clone()
+        p2._grad = g.clone()
+        clip_grad_value_(p1, clip_value, foreach=foreach)
+        clip_grad_value_([p2], clip_value, foreach=foreach)
+        self.assertEqual(p1.grad, p2.grad)
+
+    @parametrize_test('foreach', (False, True))
+    @parametrize_test('norm_type', (0.5, 1.5, 2, 4, 'inf'))
+    def test_clip_grad_norm(self, norm_type, foreach, device):
+        if torch.device(device).type == 'xla' and foreach:
+            raise SkipTest('foreach not supported on XLA')
+
+        l = nn.Linear(10, 10).to(device)
+        max_norm = 2
+
+        def compute_norm(norm_type):
+            norm_type = float(norm_type)
+            if norm_type != inf:
+                total_norm = 0
+                for p in l.parameters():
+                    total_norm += p.grad.data.abs().pow(norm_type).sum()
+                return pow(total_norm, 1. / norm_type)
+            else:
+                return max(p.grad.data.abs().max() for p in l.parameters())
+
+        def compare_scaling(grads):
+            p_scale = [p.grad.data.div(g).view(-1) for p, g in zip(l.parameters(), grads)]
+            scale = torch.cat(p_scale)
+            self.assertEqual(scale.std(), 0)
+            return scale[0]
+
+        grads = torch.arange(1., 101, device=device).view(10, 10), torch.ones(10, device=device).div(1000)
+        for p, g in zip(l.parameters(), grads):
+            p._grad = g.clone().view_as(p.data)
+        norm_before = compute_norm(norm_type)
+        norm = clip_grad_norm_(l.parameters(), max_norm, norm_type=norm_type, foreach=foreach)
+        norm_after = compute_norm(norm_type)
+        self.assertEqual(norm, norm_before)
+        self.assertEqual(norm_after, max_norm)
+        self.assertLessEqual(norm_after, norm_before)
+        compare_scaling(grads)
+
+        # Small gradients should be left unchanged
+        grads = torch.rand(10, 10, device=device).div(10000), torch.ones(10, device=device).div(500)
+        for p, g in zip(l.parameters(), grads):
+            p.grad.data.copy_(g)
+        norm_before = compute_norm(norm_type)
+        norm = clip_grad_norm_(l.parameters(), max_norm, norm_type=norm_type, foreach=foreach)
+        norm_after = compute_norm(norm_type)
+        self.assertEqual(norm, norm_before)
+        self.assertEqual(norm_before, norm_after)
+        self.assertLessEqual(norm_after, max_norm)
+        scale = compare_scaling(grads)
+        self.assertEqual(scale, 1)
+
+        # Should accept a single Tensor as input
+        p1, p2 = torch.randn(10, 10, device=device), torch.randn(10, 10, device=device)
+        g = torch.arange(1., 101, device=device).view(10, 10)
+        p1._grad = g.clone()
+        p2._grad = g.clone()
+        clip_grad_norm_(p1, max_norm, norm_type=norm_type, foreach=foreach)
+        clip_grad_norm_([p2], max_norm, norm_type=norm_type, foreach=foreach)
+        self.assertEqual(p1.grad, p2.grad)
+
 
 class TestFunctionalPickle(TestCase):
```

torch/nn/utils/clip_grad.py

Lines changed: 56 additions & 12 deletions
```diff
@@ -1,15 +1,18 @@
 import warnings
+from typing import Union, Iterable, List, Dict, Tuple, Optional
+
 import torch
+from torch import Tensor
 from torch._six import inf
-from typing import Union, Iterable
+from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
 
 _tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
 
 __all__ = ['clip_grad_norm_', 'clip_grad_norm', 'clip_grad_value_']
 
 def clip_grad_norm_(
         parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
-        error_if_nonfinite: bool = False) -> torch.Tensor:
+        error_if_nonfinite: bool = False, foreach: Optional[bool] = None) -> torch.Tensor:
     r"""Clips gradient norm of an iterable of parameters.
 
     The norm is computed over all gradients together, as if they were
@@ -24,6 +27,10 @@ def clip_grad_norm_(
         error_if_nonfinite (bool): if True, an error is thrown if the total
             norm of the gradients from :attr:`parameters` is ``nan``,
            ``inf``, or ``-inf``. Default: False (will switch to True in the future)
+        foreach (bool): use the faster foreach-based implementation.
+            If ``None``, use the foreach implementation for CUDA and CPU tensors and silently fall back to the slow
+            implementation for other device types.
+            Default: ``None``
 
     Returns:
         Total norm of the parameter gradients (viewed as a single vector).
@@ -35,12 +42,25 @@ def clip_grad_norm_(
     norm_type = float(norm_type)
     if len(grads) == 0:
         return torch.tensor(0.)
-    device = grads[0].device
+    first_device = grads[0].device
+    grouped_grads: Dict[Tuple[torch.device, torch.dtype], List[List[Tensor]]] \
+        = _group_tensors_by_device_and_dtype([[g.detach() for g in grads]])  # type: ignore[assignment]
+
     if norm_type == inf:
-        norms = [g.detach().abs().max().to(device) for g in grads]
+        norms = [g.detach().abs().max().to(first_device) for g in grads]
         total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
     else:
-        total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type)
+        norms = []
+        for ((device, _), [grads]) in grouped_grads.items():
+            if (foreach is None or foreach) and device.type in {'cpu', 'cuda'}:
+                norms.extend(torch._foreach_norm(grads, norm_type))
+            elif foreach:
+                raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors')
+            else:
+                norms.extend([torch.norm(g, norm_type) for g in grads])
+
+        total_norm = torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)
+
     if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
         raise RuntimeError(
             f'The total norm of order {norm_type} for gradients from '
@@ -52,14 +72,22 @@ def clip_grad_norm_(
     # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
     # when the gradients do not reside in CPU memory.
     clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
-    for g in grads:
-        g.detach().mul_(clip_coef_clamped.to(g.device))
+    for ((device, _), [grads]) in grouped_grads.items():
+        if (foreach is None or foreach) and device.type in ('cpu', 'cuda'):
+            torch._foreach_mul_(grads, clip_coef_clamped.to(device))  # type: ignore[call-overload]
+        elif foreach:
+            raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors')
+        else:
+            clip_coef_clamped_device = clip_coef_clamped.to(device)
+            for g in grads:
+                g.detach().mul_(clip_coef_clamped_device)
+
     return total_norm
 
 
 def clip_grad_norm(
         parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.,
-        error_if_nonfinite: bool = False) -> torch.Tensor:
+        error_if_nonfinite: bool = False, foreach: Optional[bool] = None) -> torch.Tensor:
     r"""Clips gradient norm of an iterable of parameters.
 
     .. warning::
@@ -68,10 +96,10 @@ def clip_grad_norm(
     """
     warnings.warn("torch.nn.utils.clip_grad_norm is now deprecated in favor "
                   "of torch.nn.utils.clip_grad_norm_.", stacklevel=2)
-    return clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite)
+    return clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite, foreach)
 
 
-def clip_grad_value_(parameters: _tensor_or_tensors, clip_value: float) -> None:
+def clip_grad_value_(parameters: _tensor_or_tensors, clip_value: float, foreach: Optional[bool] = None) -> None:
     r"""Clips gradient of an iterable of parameters at specified value.
 
     Gradients are modified in-place.
@@ -82,9 +110,25 @@ def clip_grad_value_(parameters: _tensor_or_tensors, clip_value: float) -> None:
         clip_value (float): maximum allowed value of the gradients.
             The gradients are clipped in the range
            :math:`\left[\text{-clip\_value}, \text{clip\_value}\right]`
+        foreach (bool): use the faster foreach-based implementation
+            If ``None``, use the foreach implementation for CUDA and CPU tensors and silently fall back to the slow
+            implementation for other device types.
+            Default: ``None``
     """
     if isinstance(parameters, torch.Tensor):
         parameters = [parameters]
     clip_value = float(clip_value)
-    for p in filter(lambda p: p.grad is not None, parameters):
-        p.grad.data.clamp_(min=-clip_value, max=clip_value)
+
+    grads = [p.grad for p in parameters if p.grad is not None]
+    grouped_grads: Dict[Tuple[torch.device, torch.dtype], List[List[Tensor]]] \
+        = _group_tensors_by_device_and_dtype([grads])  # type: ignore[assignment]
+
+    for ((device, _), [grads]) in grouped_grads.items():
+        if (foreach is None or foreach) and device.type in {'cpu', 'cuda'}:
+            torch._foreach_clamp_min_(grads, -clip_value)
+            torch._foreach_clamp_max_(grads, clip_value)
+        elif foreach:
+            raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors')
+        else:
+            for grad in grads:
+                grad.data.clamp_(min=-clip_value, max=clip_value)
```
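
For readers new to the foreach pattern, here is a simplified sketch of the p-norm path added above: group gradients by device and dtype, compute per-group norms with one fused call, then reduce with a norm-of-norms. It is a stripped-down illustration (no `foreach` flag handling, no inf-norm branch, no finiteness checks), not the shipped code:

```python
import torch
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

def total_grad_norm_sketch(grads, norm_type=2.0):
    first_device = grads[0].device
    # Each foreach kernel needs inputs on a single device and dtype, so group first.
    grouped = _group_tensors_by_device_and_dtype([[g.detach() for g in grads]])
    norms = []
    for (device, _dtype), [device_grads] in grouped.items():
        # One fused launch per group instead of one torch.norm call per tensor.
        norms.extend(torch._foreach_norm(device_grads, norm_type))
    # For p-norms, the p-norm of the per-tensor p-norms equals the p-norm of the
    # concatenation of all gradients, so this matches the old single-loop result.
    return torch.norm(torch.stack([n.to(first_device) for n in norms]), norm_type)
```

The benchmark in the commit message suggests the win comes largely from cutting per-tensor kernel-launch overhead, which is why the speedup is largest for many small tensors.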

torch/optim/adam.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -604,8 +604,8 @@ def _fused_adam(
             device_state_steps,
         ) = grouped_tensors[(device, dtype)]
         if grad_scale is not None and found_inf is not None:
-            device_grad_scale = grad_scale.get(device)
-            device_found_inf = found_inf.get(device)
+            device_grad_scale = grad_scale.get(str(device))
+            device_found_inf = found_inf.get(str(device))
         else:
             device_grad_scale = None
             device_found_inf = None
```

torch/utils/_foreach_utils.py

Lines changed: 7 additions & 7 deletions
```diff
@@ -3,11 +3,11 @@
 
 import torch
 from torch import Tensor
+from torch.autograd.grad_mode import no_grad
 
 
-# _group_tensors_by_device_and_dtype is a util function that splits tensors into groups by device and dtype,
-# which is useful before sending tensors off to a foreach implementation, which requires tensors to be on
-# one device and dtype.
+# This util function splits tensors into groups by device and dtype, which is useful before sending
+# tensors off to a foreach implementation, which requires tensors to be on one device and dtype.
 # If tensorlistlist contains more than one tensorlist, the following assumptions are made BUT NOT verified:
 #   - tensorlists CAN be None
 #   - all tensors in the first specified list cannot be None
@@ -17,16 +17,16 @@
 # Whereas mutating a tensor in the resulting split-up tensorlists WILL propagate changes back to the
 # original input tensorlists, changing up Nones/literals WILL NOT propagate, and manual propagation
 # may be necessary. Check out torch/optim/sgd.py for an example.
-@torch.no_grad()
+@no_grad()
 def _group_tensors_by_device_and_dtype(tensorlistlist: List[List[Tensor]],
                                        with_indices: Optional[bool] = False) -> \
-        Dict[Tuple[str, torch.dtype], List[List[Union[Tensor, int]]]]:
+        Dict[Tuple[torch.device, torch.dtype], List[List[Union[Tensor, int]]]]:
     assert all([not x or len(x) == len(tensorlistlist[0]) for x in tensorlistlist]), (
         "all specified tensorlists must match in length")
-    per_device_and_dtype_tensors: Dict[Tuple[str, torch.dtype], List[List[Union[Tensor, int]]]] = defaultdict(
+    per_device_and_dtype_tensors: Dict[Tuple[torch.device, torch.dtype], List[List[Union[Tensor, int]]]] = defaultdict(
         lambda: [[] for _ in range(len(tensorlistlist) + (1 if with_indices else 0))])
     for i, t in enumerate(tensorlistlist[0]):
-        key = (str(t.device), t.dtype)
+        key = (t.device, t.dtype)
         for j in range(len(tensorlistlist)):
             # a tensorlist may be empty/None
             if tensorlistlist[j]:
```
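
To make the key change concrete (dictionary keys are now `(torch.device, torch.dtype)` pairs rather than `(str, torch.dtype)`), here is a small illustrative call; the tensor shapes and the `print` are arbitrary choices of mine:

```python
import torch
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

a = torch.zeros(4)                        # cpu / float32
b = torch.zeros(4, dtype=torch.float64)   # cpu / float64
grouped = _group_tensors_by_device_and_dtype([[a, b]])

for (device, dtype), [tensors] in grouped.items():
    # e.g. (device(type='cpu'), torch.float32) -> 1 tensor
    #      (device(type='cpu'), torch.float64) -> 1 tensor
    print(device, dtype, len(tensors))
```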
