2 changes: 1 addition & 1 deletion .github/ci_commit_pins/xla.txt
@@ -1 +1 @@
3ab8494305810d3c943f670bc6b028514942c7a0
eac4e547138ab22a9b41c6f96208613fd7dd19d5
172 changes: 90 additions & 82 deletions test/test_nn.py
@@ -13,6 +13,7 @@
from functools import partial
from collections import OrderedDict
from tempfile import NamedTemporaryFile
from unittest import SkipTest

import torch

@@ -1654,85 +1655,6 @@ def assign_weight():
# This should work though
l2.weight = Parameter(torch.randn(10, 10))

def test_clip_grad_norm(self):
l = nn.Linear(10, 10)
max_norm = 2

def compute_norm(norm_type):
norm_type = float(norm_type)
if norm_type != inf:
total_norm = 0
for p in l.parameters():
total_norm += p.grad.data.abs().pow(norm_type).sum()
return pow(total_norm, 1. / norm_type)
else:
return max(p.grad.data.abs().max() for p in l.parameters())

def compare_scaling(grads):
p_scale = [p.grad.data.div(g).view(-1) for p, g in zip(l.parameters(), grads)]
scale = torch.cat(p_scale)
self.assertEqual(scale.std(), 0)
return scale[0]

grads = torch.arange(1., 101).view(10, 10), torch.ones(10).div(1000)
for norm_type in [0.5, 1.5, 2, 4, 'inf']:
for p, g in zip(l.parameters(), grads):
p._grad = g.clone().view_as(p.data)
norm_before = compute_norm(norm_type)
norm = clip_grad_norm_(l.parameters(), max_norm, norm_type=norm_type)
norm_after = compute_norm(norm_type)
self.assertEqual(norm, norm_before)
self.assertEqual(norm_after, max_norm)
self.assertLessEqual(norm_after, norm_before)
compare_scaling(grads)

# Small gradients should be left unchanged
grads = torch.rand(10, 10).div(10000), torch.ones(10).div(500)
for norm_type in [0.5, 1.5, 2, 4, 'inf']:
for p, g in zip(l.parameters(), grads):
p.grad.data.copy_(g)
norm_before = compute_norm(norm_type)
norm = clip_grad_norm_(l.parameters(), max_norm, norm_type=norm_type)
norm_after = compute_norm(norm_type)
self.assertEqual(norm, norm_before)
self.assertEqual(norm_before, norm_after)
self.assertLessEqual(norm_after, max_norm)
scale = compare_scaling(grads)
self.assertEqual(scale, 1)

# Should accept a single Tensor as input
p1, p2 = torch.randn(10, 10), torch.randn(10, 10)
g = torch.arange(1., 101).view(10, 10)
p1._grad = g.clone()
p2._grad = g.clone()
for norm_type in [0.5, 1.5, 2, 4, 'inf']:
clip_grad_norm_(p1, max_norm, norm_type=norm_type)
clip_grad_norm_([p2], max_norm, norm_type=norm_type)
self.assertEqual(p1.grad, p2.grad)

def test_clip_grad_value(self):
l = nn.Linear(10, 10)
clip_value = 2.5

grad_w, grad_b = torch.arange(-50., 50).view(10, 10).div_(5), torch.ones(10).mul_(2)
for grad_list in [[grad_w, grad_b], [grad_w, None]]:
for p, g in zip(l.parameters(), grad_list):
p._grad = g.clone().view_as(p.data) if g is not None else g

clip_grad_value_(l.parameters(), clip_value)
for p in filter(lambda p: p.grad is not None, l.parameters()):
self.assertLessEqual(p.grad.data.max(), clip_value)
self.assertGreaterEqual(p.grad.data.min(), -clip_value)

# Should accept a single Tensor as input
p1, p2 = torch.randn(10, 10), torch.randn(10, 10)
g = torch.arange(-50., 50).view(10, 10).div_(5)
p1._grad = g.clone()
p2._grad = g.clone()
clip_grad_value_(p1, clip_value)
clip_grad_value_([p2], clip_value)
self.assertEqual(p1.grad, p2.grad)

def test_parameters_to_vector(self):
conv1 = nn.Conv2d(3, 10, 5)
fc1 = nn.Linear(10, 20)
@@ -11473,7 +11395,8 @@ def run_test_case(norm_type, error_if_nonfinite, scalar, grad_only_one_elem, pre

@onlyCUDA
@deviceCountAtLeast(2)
def test_clip_grad_norm_multi_device(self, devices):
@parametrize_test('foreach', (False, True))
def test_clip_grad_norm_multi_device(self, devices, foreach):
Contributor: Not a concern with your PR, but I am realizing we never run this in CI because we only have one CI config where there is more than one GPU and we don't run this test in that config. 🤔 Filed #92173

Contributor: And then you would also need to add the ciflow/periodic label to get the multigpu tests to trigger.

class TestModel(nn.Module):
def __init__(self):
super(TestModel, self).__init__()
@@ -11489,8 +11412,8 @@ def __init__(self):
p.grad = torch.ones_like(p)
for p in ref_model.parameters():
p.grad = torch.ones_like(p)
norm = clip_grad_norm_(test_model.parameters(), 0.5, norm_type=norm_type)
expected = clip_grad_norm_(ref_model.parameters(), 0.5, norm_type=norm_type)
norm = clip_grad_norm_(test_model.parameters(), 0.5, norm_type=norm_type, foreach=foreach)
expected = clip_grad_norm_(ref_model.parameters(), 0.5, norm_type=norm_type, foreach=foreach)
self.assertEqual(norm, expected)
for p, pe in zip(test_model.parameters(), ref_model.parameters()):
self.assertEqual(p.grad.to(devices[0]), pe.grad)
@@ -12042,6 +11965,91 @@ def perm_fn(x):
with cm:
_test(activation=activation, batch_first=batch_first, training=training)

@parametrize_test('foreach', (False, True))
def test_clip_grad_value(self, foreach, device):
if torch.device(device).type == 'xla' and foreach:
raise SkipTest('foreach not supported on XLA')

l = nn.Linear(10, 10).to(device)
clip_value = 2.5

grad_w, grad_b = torch.arange(-50., 50, device=device).view(10, 10).div_(5), torch.ones(10, device=device).mul_(2)
for grad_list in [[grad_w, grad_b], [grad_w, None]]:
for p, g in zip(l.parameters(), grad_list):
p._grad = g.clone().view_as(p.data) if g is not None else g

clip_grad_value_(l.parameters(), clip_value, foreach=foreach)
for p in filter(lambda p: p.grad is not None, l.parameters()):
self.assertLessEqual(p.grad.data.max(), clip_value)
self.assertGreaterEqual(p.grad.data.min(), -clip_value)

# Should accept a single Tensor as input
p1, p2 = torch.randn(10, 10, device=device), torch.randn(10, 10, device=device)
g = torch.arange(-50., 50, device=device).view(10, 10).div_(5)
p1._grad = g.clone()
p2._grad = g.clone()
clip_grad_value_(p1, clip_value, foreach=foreach)
clip_grad_value_([p2], clip_value, foreach=foreach)
self.assertEqual(p1.grad, p2.grad)

@parametrize_test('foreach', (False, True))
@parametrize_test('norm_type', (0.5, 1.5, 2, 4, 'inf'))
def test_clip_grad_norm(self, norm_type, foreach, device):
if torch.device(device).type == 'xla' and foreach:
raise SkipTest('foreach not supported on XLA')

l = nn.Linear(10, 10).to(device)
max_norm = 2

def compute_norm(norm_type):
norm_type = float(norm_type)
if norm_type != inf:
total_norm = 0
for p in l.parameters():
total_norm += p.grad.data.abs().pow(norm_type).sum()
return pow(total_norm, 1. / norm_type)
else:
return max(p.grad.data.abs().max() for p in l.parameters())

def compare_scaling(grads):
p_scale = [p.grad.data.div(g).view(-1) for p, g in zip(l.parameters(), grads)]
scale = torch.cat(p_scale)
self.assertEqual(scale.std(), 0)
return scale[0]

grads = torch.arange(1., 101, device=device).view(10, 10), torch.ones(10, device=device).div(1000)
for p, g in zip(l.parameters(), grads):
p._grad = g.clone().view_as(p.data)
norm_before = compute_norm(norm_type)
norm = clip_grad_norm_(l.parameters(), max_norm, norm_type=norm_type, foreach=foreach)
norm_after = compute_norm(norm_type)
self.assertEqual(norm, norm_before)
self.assertEqual(norm_after, max_norm)
self.assertLessEqual(norm_after, norm_before)
compare_scaling(grads)

# Small gradients should be left unchanged
grads = torch.rand(10, 10, device=device).div(10000), torch.ones(10, device=device).div(500)
for p, g in zip(l.parameters(), grads):
p.grad.data.copy_(g)
norm_before = compute_norm(norm_type)
norm = clip_grad_norm_(l.parameters(), max_norm, norm_type=norm_type, foreach=foreach)
norm_after = compute_norm(norm_type)
self.assertEqual(norm, norm_before)
self.assertEqual(norm_before, norm_after)
self.assertLessEqual(norm_after, max_norm)
scale = compare_scaling(grads)
self.assertEqual(scale, 1)

# Should accept a single Tensor as input
p1, p2 = torch.randn(10, 10, device=device), torch.randn(10, 10, device=device)
g = torch.arange(1., 101, device=device).view(10, 10)
p1._grad = g.clone()
p2._grad = g.clone()
clip_grad_norm_(p1, max_norm, norm_type=norm_type, foreach=foreach)
clip_grad_norm_([p2], max_norm, norm_type=norm_type, foreach=foreach)
self.assertEqual(p1.grad, p2.grad)


class TestFunctionalPickle(TestCase):

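For reference, a minimal usage sketch of the new foreach argument exercised by the tests above (the model, shapes, and values below are illustrative only, and assume a build with this patch applied):

    import torch
    import torch.nn as nn
    from torch.nn.utils import clip_grad_norm_, clip_grad_value_

    model = nn.Linear(10, 10)
    model(torch.randn(4, 10)).sum().backward()

    # foreach=None (the default) picks the fused multi-tensor path for CPU/CUDA
    # gradients and silently falls back to the per-tensor loop elsewhere (e.g. XLA).
    total_norm = clip_grad_norm_(model.parameters(), max_norm=2.0, norm_type=2.0, foreach=None)

    # foreach=True forces the multi-tensor path and raises a RuntimeError on
    # device types it does not support.
    clip_grad_value_(model.parameters(), clip_value=2.5, foreach=True)
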
68 changes: 56 additions & 12 deletions torch/nn/utils/clip_grad.py
@@ -1,15 +1,18 @@
import warnings
from typing import Union, Iterable, List, Dict, Tuple, Optional

import torch
from torch import Tensor
from torch._six import inf
from typing import Union, Iterable
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]

__all__ = ['clip_grad_norm_', 'clip_grad_norm', 'clip_grad_value_']

def clip_grad_norm_(
parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
error_if_nonfinite: bool = False) -> torch.Tensor:
error_if_nonfinite: bool = False, foreach: Optional[bool] = None) -> torch.Tensor:
r"""Clips gradient norm of an iterable of parameters.

The norm is computed over all gradients together, as if they were
@@ -24,6 +27,10 @@ def clip_grad_norm_(
error_if_nonfinite (bool): if True, an error is thrown if the total
norm of the gradients from :attr:`parameters` is ``nan``,
``inf``, or ``-inf``. Default: False (will switch to True in the future)
foreach (bool): use the faster foreach-based implementation.
If ``None``, use the foreach implementation for CUDA and CPU tensors and silently fall back to the slow
implementation for other device types.
Default: ``None``

Returns:
Total norm of the parameter gradients (viewed as a single vector).
@@ -35,12 +42,25 @@
norm_type = float(norm_type)
if len(grads) == 0:
return torch.tensor(0.)
device = grads[0].device
first_device = grads[0].device
grouped_grads: Dict[Tuple[torch.device, torch.dtype], List[List[Tensor]]] \
= _group_tensors_by_device_and_dtype([[g.detach() for g in grads]]) # type: ignore[assignment]

if norm_type == inf:
norms = [g.detach().abs().max().to(device) for g in grads]
norms = [g.detach().abs().max().to(first_device) for g in grads]
total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
else:
total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type)
norms = []
for ((device, _), [grads]) in grouped_grads.items():
if (foreach is None or foreach) and device.type in {'cpu', 'cuda'}:
norms.extend(torch._foreach_norm(grads, norm_type))
elif foreach:
raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors')
else:
norms.extend([torch.norm(g, norm_type) for g in grads])

total_norm = torch.norm(torch.stack([norm.to(first_device) for norm in norms]), norm_type)

if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
raise RuntimeError(
f'The total norm of order {norm_type} for gradients from '
@@ -52,14 +72,22 @@
# avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
# when the gradients do not reside in CPU memory.
clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
for g in grads:
g.detach().mul_(clip_coef_clamped.to(g.device))
for ((device, _), [grads]) in grouped_grads.items():
if (foreach is None or foreach) and device.type in ('cpu', 'cuda'):
torch._foreach_mul_(grads, clip_coef_clamped.to(device)) # type: ignore[call-overload]
elif foreach:
raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors')
else:
clip_coef_clamped_device = clip_coef_clamped.to(device)
for g in grads:
g.detach().mul_(clip_coef_clamped_device)

return total_norm


def clip_grad_norm(
parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.,
error_if_nonfinite: bool = False) -> torch.Tensor:
error_if_nonfinite: bool = False, foreach: Optional[bool] = None) -> torch.Tensor:
r"""Clips gradient norm of an iterable of parameters.

.. warning::
@@ -68,10 +96,10 @@ def clip_grad_norm(
"""
warnings.warn("torch.nn.utils.clip_grad_norm is now deprecated in favor "
"of torch.nn.utils.clip_grad_norm_.", stacklevel=2)
return clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite)
return clip_grad_norm_(parameters, max_norm, norm_type, error_if_nonfinite, foreach)


def clip_grad_value_(parameters: _tensor_or_tensors, clip_value: float) -> None:
def clip_grad_value_(parameters: _tensor_or_tensors, clip_value: float, foreach: Optional[bool] = None) -> None:
r"""Clips gradient of an iterable of parameters at specified value.

Gradients are modified in-place.
@@ -82,9 +110,25 @@ def clip_grad_value_(parameters: _tensor_or_tensors, clip_value: float) -> None:
clip_value (float): maximum allowed value of the gradients.
The gradients are clipped in the range
:math:`\left[\text{-clip\_value}, \text{clip\_value}\right]`
foreach (bool): use the faster foreach-based implementation
If ``None``, use the foreach implementation for CUDA and CPU tensors and silently fall back to the slow
implementation for other device types.
Default: ``None``
"""
if isinstance(parameters, torch.Tensor):
parameters = [parameters]
clip_value = float(clip_value)
for p in filter(lambda p: p.grad is not None, parameters):
p.grad.data.clamp_(min=-clip_value, max=clip_value)

grads = [p.grad for p in parameters if p.grad is not None]
grouped_grads: Dict[Tuple[torch.device, torch.dtype], List[List[Tensor]]] \
= _group_tensors_by_device_and_dtype([grads]) # type: ignore[assignment]

for ((device, _), [grads]) in grouped_grads.items():
if (foreach is None or foreach) and device.type in {'cpu', 'cuda'}:
torch._foreach_clamp_min_(grads, -clip_value)
torch._foreach_clamp_max_(grads, clip_value)
elif foreach:
raise RuntimeError(f'foreach=True was passed, but can\'t use the foreach API on {device.type} tensors')
else:
for grad in grads:
grad.data.clamp_(min=-clip_value, max=clip_value)
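
As a rough summary of the pattern above, here is a condensed sketch of the grouping-then-foreach flow that the updated clip_grad_value_ follows. The function name is invented for illustration; the real implementation is the one shown in the diff:

    from collections import defaultdict
    import torch

    def clip_value_sketch(grads, clip_value, foreach=None):
        # Group gradients by (device, dtype), mirroring _group_tensors_by_device_and_dtype.
        grouped = defaultdict(list)
        for g in grads:
            grouped[(g.device, g.dtype)].append(g)
        for (device, _), gs in grouped.items():
            if (foreach is None or foreach) and device.type in ('cpu', 'cuda'):
                # One fused multi-tensor clamp per group instead of one clamp_ per tensor.
                torch._foreach_clamp_min_(gs, -clip_value)
                torch._foreach_clamp_max_(gs, clip_value)
            elif foreach:
                raise RuntimeError(f"foreach=True was passed, but can't use the foreach API on {device.type} tensors")
            else:
                for g in gs:
                    g.clamp_(min=-clip_value, max=clip_value)

    # Usage: clip_value_sketch([p.grad for p in model.parameters() if p.grad is not None], 2.5)
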
4 changes: 2 additions & 2 deletions torch/optim/adam.py
@@ -604,8 +604,8 @@ def _fused_adam(
device_state_steps,
) = grouped_tensors[(device, dtype)]
if grad_scale is not None and found_inf is not None:
device_grad_scale = grad_scale.get(device)
device_found_inf = found_inf.get(device)
device_grad_scale = grad_scale.get(str(device))
device_found_inf = found_inf.get(str(device))
else:
device_grad_scale = None
device_found_inf = None
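
The adam.py tweak follows from the key-type change in _foreach_utils.py below: the grouping keys are now torch.device objects rather than strings, while the grad_scale / found_inf side tables are assumed here to still be keyed by device strings, so the lookups stringify the device. A small hypothetical illustration:

    import torch

    grad_scale = {"cuda:0": torch.tensor(2.0)}   # hypothetical per-device scale table keyed by str
    device = torch.device("cuda:0")              # key shape produced by the new grouping

    print(grad_scale.get(device))       # None: a torch.device does not hash like its string form
    print(grad_scale.get(str(device)))  # tensor(2.), the lookup the patch switches to
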
14 changes: 7 additions & 7 deletions torch/utils/_foreach_utils.py
@@ -3,11 +3,11 @@

import torch
from torch import Tensor
from torch.autograd.grad_mode import no_grad


# _group_tensors_by_device_and_dtype is a util function that splits tensors into groups by device and dtype,
# which is useful before sending tensors off to a foreach implementation, which requires tensors to be on
# one device and dtype.
# This util function splits tensors into groups by device and dtype, which is useful before sending
# tensors off to a foreach implementation, which requires tensors to be on one device and dtype.
# If tensorlistlist contains more than one tensorlist, the following assumptions are made BUT NOT verified:
# - tensorlists CAN be None
# - all tensors in the first specified list cannot be None
@@ -17,16 +17,16 @@
# Whereas mutating a tensor in the resulting split-up tensorlists WILL propagate changes back to the
# original input tensorlists, changing up Nones/literals WILL NOT propagate, and manual propagation
# may be necessary. Check out torch/optim/sgd.py for an example.
@torch.no_grad()
@no_grad()
def _group_tensors_by_device_and_dtype(tensorlistlist: List[List[Tensor]],
with_indices: Optional[bool] = False) -> \
Dict[Tuple[str, torch.dtype], List[List[Union[Tensor, int]]]]:
Dict[Tuple[torch.device, torch.dtype], List[List[Union[Tensor, int]]]]:
assert all([not x or len(x) == len(tensorlistlist[0]) for x in tensorlistlist]), (
"all specified tensorlists must match in length")
per_device_and_dtype_tensors: Dict[Tuple[str, torch.dtype], List[List[Union[Tensor, int]]]] = defaultdict(
per_device_and_dtype_tensors: Dict[Tuple[torch.device, torch.dtype], List[List[Union[Tensor, int]]]] = defaultdict(
lambda: [[] for _ in range(len(tensorlistlist) + (1 if with_indices else 0))])
for i, t in enumerate(tensorlistlist[0]):
key = (str(t.device), t.dtype)
key = (t.device, t.dtype)
for j in range(len(tensorlistlist)):
# a tensorlist may be empty/None
if tensorlistlist[j]:
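
A quick sketch of what the updated helper returns after this change, assuming the Python implementation shown above; the tensors are arbitrary:

    import torch
    from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

    grads = [torch.ones(2), torch.ones(3, dtype=torch.float64)]
    grouped = _group_tensors_by_device_and_dtype([grads])

    for (device, dtype), [tensors] in grouped.items():
        # Keys are now (torch.device, dtype) tuples rather than (str, dtype).
        assert isinstance(device, torch.device)
        print(device, dtype, [t.shape for t in tensors])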