173 changes: 92 additions & 81 deletions torch/optim/adamw.py
@@ -1,7 +1,9 @@
import torch
from torch import Tensor
from .optimizer import Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt, _stack_if_compiling
from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt,
_stack_if_compiling, _default_to_foreach)
from typing import List, Optional
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

__all__ = ["AdamW", "adamw"]

@@ -60,8 +62,10 @@ class AdamW(Optimizer):
(default: False)
maximize (bool, optional): maximize the params based on the objective, instead of
minimizing (default: False)
foreach (bool, optional): whether foreach implementation of optimizer
is used (default: None)
foreach (bool, optional): whether foreach implementation of optimizer is used.
If unspecified by the user (so foreach is None), we will try to use foreach
over the for-loop implementation on CUDA, since it is usually significantly
more performant. (default: None)
capturable (bool, optional): whether this instance is safe to capture in a CUDA graph.
Passing True can impair ungraphed performance, so if you don't intend to
graph capture this instance, leave it False (default: False)
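A quick illustration of what the reworded foreach docstring above means in practice (illustrative usage only, not part of this diff; it relies only on the documented AdamW constructor arguments):

import torch

model = torch.nn.Linear(8, 8)

# foreach left as None: with this change the optimizer is free to pick the
# foreach path on its own, typically when parameters and state live on CUDA.
opt_default = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Explicit settings are always respected, whatever the device.
opt_forloop = torch.optim.AdamW(model.parameters(), lr=1e-3, foreach=False)
opt_foreach = torch.optim.AdamW(model.parameters(), lr=1e-3, foreach=True)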
@@ -223,7 +227,7 @@ def adamw(
state_steps: List[Tensor],
# kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
# setting this as kwarg for now as functional API is compiled by torch/distributed/optim
foreach: bool = None,
foreach: Optional[bool] = None,
capturable: bool = False,
differentiable: bool = False,
*,
@@ -245,9 +249,11 @@ def adamw(
"API has changed, `state_steps` argument must contain a list of singleton tensors"
)

# Respect when the user inputs False/True for foreach.
if foreach is None:
# Placeholder for more complex foreach logic to be added when value is not set
foreach = False
foreach = _default_to_foreach(
[params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps],
differentiable=differentiable)

if foreach and torch.jit.is_scripting():
raise RuntimeError("torch.jit.script not supported with foreach optimizers")
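For context, a rough sketch of the heuristic a helper like _default_to_foreach plausibly applies; the real implementation lives in torch/optim/optimizer.py and may differ, so the body below is an assumption rather than the actual helper:

from typing import List, Optional
import torch
from torch import Tensor

def default_to_foreach_sketch(tensorlists: List[List[Optional[Tensor]]],
                              differentiable: bool = False) -> bool:
    # Assumed heuristic: foreach kernels don't differentiate through the step
    # and aren't scriptable, so bail out of the default in those cases.
    if differentiable or torch.jit.is_scripting():
        return False
    # Only default to the multi-tensor path when every tensor is a plain CUDA
    # tensor, which is where it is usually significantly more performant.
    return all(
        t is None or (type(t) is Tensor and t.is_cuda)
        for tensors in tensorlists
        for t in tensors
    )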
@@ -401,92 +407,97 @@ def _multi_tensor_adamw(
p.is_cuda and step.is_cuda for p, step in zip(params, state_steps)
), "If capturable=True, params and state_steps must be CUDA tensors."

if maximize:
grads = torch._foreach_neg(tuple(grads)) # type: ignore[assignment]

assert not differentiable, "_foreach ops don't support autograd"

grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in grads]
exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_avgs]
exp_avg_sqs = [
torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_avg_sqs
]
params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in params]
grouped_tensors = _group_tensors_by_device_and_dtype([
params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
for (device_params, device_grads, device_exp_avgs, device_exp_avg_sqs,
device_max_exp_avg_sqs, device_state_steps) in grouped_tensors.values():
if maximize:
device_grads = torch._foreach_neg(tuple(device_grads)) # type: ignore[assignment]

# update steps
torch._foreach_add_(state_steps, 1)

# Perform stepweight decay
torch._foreach_mul_(params, 1 - lr * weight_decay)
device_grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_grads]
device_exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_exp_avgs]
device_exp_avg_sqs = [
torch.view_as_real(x) if torch.is_complex(x) else x for x in device_exp_avg_sqs
]
device_params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_params]

# Decay the first and second moment running average coefficient
torch._foreach_mul_(exp_avgs, beta1)
torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)
# update steps
torch._foreach_add_(device_state_steps, 1)

torch._foreach_mul_(exp_avg_sqs, beta2)
torch._foreach_addcmul_(exp_avg_sqs, grads, grads, 1 - beta2)
# Perform stepweight decay
torch._foreach_mul_(device_params, 1 - lr * weight_decay)

if capturable:
# TODO: use foreach_pow if/when foreach_pow is added
bias_correction1 = [torch.pow(beta1, step) for step in state_steps]
bias_correction2 = [torch.pow(beta2, step) for step in state_steps]
# foreach_sub doesn't allow a scalar as the first arg
torch._foreach_sub_(bias_correction1, 1)
torch._foreach_sub_(bias_correction2, 1)
torch._foreach_neg_(bias_correction1)
torch._foreach_neg_(bias_correction2)

# foreach_div doesn't allow a scalar as the first arg
step_size = torch._foreach_div(bias_correction1, lr)
torch._foreach_reciprocal_(step_size)
torch._foreach_neg_(step_size)

bias_correction2_sqrt = torch._foreach_sqrt(bias_correction2)

if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch._foreach_maximum_(max_exp_avg_sqs, exp_avg_sqs)

# Use the max. for normalizing running avg. of gradient
max_exp_avg_sq_sqrt = torch._foreach_sqrt(max_exp_avg_sqs)
# Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
# (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
torch._foreach_div_(
max_exp_avg_sq_sqrt,
torch._foreach_mul(bias_correction2_sqrt, step_size),
)
eps_over_step_size = torch._foreach_div(step_size, eps)
torch._foreach_reciprocal_(eps_over_step_size)
denom = torch._foreach_add(max_exp_avg_sq_sqrt, eps_over_step_size)
else:
exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sqs)
torch._foreach_div_(
exp_avg_sq_sqrt, torch._foreach_mul(bias_correction2_sqrt, step_size)
)
eps_over_step_size = torch._foreach_div(step_size, eps)
torch._foreach_reciprocal_(eps_over_step_size)
denom = torch._foreach_add(exp_avg_sq_sqrt, eps_over_step_size)
# Decay the first and second moment running average coefficient
torch._foreach_mul_(device_exp_avgs, beta1)
torch._foreach_add_(device_exp_avgs, device_grads, alpha=1 - beta1)

torch._foreach_addcdiv_(params, exp_avgs, denom)
else:
bias_correction1 = [1 - beta1 ** _get_value(step) for step in state_steps]
bias_correction2 = [1 - beta2 ** _get_value(step) for step in state_steps]
torch._foreach_mul_(device_exp_avg_sqs, beta2)
torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads, 1 - beta2)

step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])
if capturable:
# TODO: use foreach_pow if/when foreach_pow is added
bias_correction1 = [torch.pow(beta1, step) for step in device_state_steps]
bias_correction2 = [torch.pow(beta2, step) for step in device_state_steps]
# foreach_sub doesn't allow a scalar as the first arg
torch._foreach_sub_(bias_correction1, 1)
torch._foreach_sub_(bias_correction2, 1)
torch._foreach_neg_(bias_correction1)
torch._foreach_neg_(bias_correction2)

# foreach_div doesn't allow a scalar as the first arg
step_size = torch._foreach_div(bias_correction1, lr)
torch._foreach_reciprocal_(step_size)
torch._foreach_neg_(step_size)

bias_correction2_sqrt = torch._foreach_sqrt(bias_correction2)

bias_correction2_sqrt = [_dispatch_sqrt(bc) for bc in bias_correction2]
if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)

if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch._foreach_maximum_(max_exp_avg_sqs, exp_avg_sqs)
# Use the max. for normalizing running avg. of gradient
max_exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
# Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
# (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
torch._foreach_div_(
max_exp_avg_sq_sqrt,
torch._foreach_mul(bias_correction2_sqrt, step_size),
)
eps_over_step_size = torch._foreach_div(step_size, eps)
torch._foreach_reciprocal_(eps_over_step_size)
denom = torch._foreach_add(max_exp_avg_sq_sqrt, eps_over_step_size)
else:
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
torch._foreach_div_(
exp_avg_sq_sqrt, torch._foreach_mul(bias_correction2_sqrt, step_size)
)
eps_over_step_size = torch._foreach_div(step_size, eps)
torch._foreach_reciprocal_(eps_over_step_size)
denom = torch._foreach_add(exp_avg_sq_sqrt, eps_over_step_size)

# Use the max. for normalizing running avg. of gradient
max_exp_avg_sq_sqrt = torch._foreach_sqrt(max_exp_avg_sqs)
torch._foreach_div_(max_exp_avg_sq_sqrt, bias_correction2_sqrt)
denom = torch._foreach_add(max_exp_avg_sq_sqrt, eps)
torch._foreach_addcdiv_(device_params, device_exp_avgs, denom)
else:
exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sqs)
torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
denom = torch._foreach_add(exp_avg_sq_sqrt, eps)
bias_correction1 = [1 - beta1 ** _get_value(step) for step in device_state_steps]
bias_correction2 = [1 - beta2 ** _get_value(step) for step in device_state_steps]

step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])

bias_correction2_sqrt = [_dispatch_sqrt(bc) for bc in bias_correction2]

if amsgrad:
# Maintains the maximum of all 2nd moment running avg. till now
torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)

# Use the max. for normalizing running avg. of gradient
max_exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
torch._foreach_div_(max_exp_avg_sq_sqrt, bias_correction2_sqrt)
denom = torch._foreach_add(max_exp_avg_sq_sqrt, eps)
else:
exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
denom = torch._foreach_add(exp_avg_sq_sqrt, eps)

torch._foreach_addcdiv_(params, exp_avgs, denom, step_size)
torch._foreach_addcdiv_(device_params, device_exp_avgs, denom, step_size)
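
The structural change in _multi_tensor_adamw above is the per-(device, dtype) loop. A stripped-down sketch of that pattern, reusing the private helper this diff imports (illustrative only; the toy scaled-gradient update stands in for the full AdamW math, and the helper's exact return shape is an internal detail that may change across versions):

import torch
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

params = [torch.randn(4), torch.randn(4, dtype=torch.float64)]
grads = [torch.randn_like(p) for p in params]

# Bucket the parallel lists by (device, dtype) so every _foreach_* call below
# only ever sees tensors it can fuse into a single kernel launch.
grouped = _group_tensors_by_device_and_dtype([params, grads])
for device_params, device_grads in grouped.values():
    # Same shape as the loop in the diff: unpack one bucket's tensor lists and
    # apply the fused update to all of them at once.
    torch._foreach_add_(device_params, device_grads, alpha=-1e-3)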
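
A scalar sanity check of the capturable-branch algebra kept by this diff: folding step_size into denom is what allows the final _foreach_addcdiv_ to run with its default value=1, since addcdiv_ only accepts a Number, not a Tensor, for value. The function names below are made up for illustration:

import math

def capturable_style_update(param, exp_avg, max_exp_avg_sq, step, lr, beta1, beta2, eps):
    # Mirrors the capturable branch: step_size is negative and divided into denom.
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    step_size = -lr / bias_correction1
    denom = (math.sqrt(max_exp_avg_sq) / math.sqrt(bias_correction2) + eps) / step_size
    return param + exp_avg / denom          # addcdiv with the default value=1

def textbook_update(param, exp_avg, max_exp_avg_sq, step, lr, beta1, beta2, eps):
    # The same step written the textbook way (weight decay is applied separately, as above).
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    denom = math.sqrt(max_exp_avg_sq) / math.sqrt(bias_correction2) + eps
    return param - lr * (exp_avg / bias_correction1) / denom

assert math.isclose(
    capturable_style_update(1.0, 0.1, 0.04, 10, 1e-3, 0.9, 0.999, 1e-8),
    textbook_update(1.0, 0.1, 0.04, 10, 1e-3, 0.9, 0.999, 1e-8),
)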