Commit d490eb6

[optim][adamw] default to foreach when CUDA + differentiable=False
ghstack-source-id: fc22c8b
Pull Request resolved: #92306
1 parent c37b451 commit d490eb6
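
From the caller's side, the change means that leaving foreach unset lets AdamW opt into the multi-tensor (foreach) implementation on CUDA, while an explicit value is still respected. A minimal sketch of that behavior, assuming a CUDA device is available; the model, shapes, and learning rate below are illustrative only, not taken from the diff:

import torch

model = torch.nn.Linear(10, 10).cuda()

# foreach is left as None, so with CUDA parameters and differentiable=False
# (the default) the optimizer may now choose the foreach implementation itself.
opt = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Passing an explicit value is still honored; this forces the for-loop path.
opt_forloop = torch.optim.AdamW(model.parameters(), lr=1e-3, foreach=False)

loss = model(torch.randn(4, 10, device="cuda")).sum()
loss.backward()
opt.step()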

File tree: 1 file changed (+92, -81 lines)

torch/optim/adamw.py

Lines changed: 92 additions & 81 deletions
@@ -1,7 +1,9 @@
 import torch
 from torch import Tensor
-from .optimizer import Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt, _stack_if_compiling
+from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt,
+                        _stack_if_compiling, _default_to_foreach)
 from typing import List, Optional
+from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

 __all__ = ["AdamW", "adamw"]

@@ -60,8 +62,10 @@ class AdamW(Optimizer):
             (default: False)
         maximize (bool, optional): maximize the params based on the objective, instead of
             minimizing (default: False)
-        foreach (bool, optional): whether foreach implementation of optimizer
-            is used (default: None)
+        foreach (bool, optional): whether foreach implementation of optimizer is used.
+            If unspecified by the user (so foreach is None), we will try to use foreach
+            over the for-loop implementation on CUDA, since it is usually significantly
+            more performant. (default: None)
         capturable (bool, optional): whether this instance is safe to capture in a CUDA graph.
             Passing True can impair ungraphed performance, so if you don't intend to
             graph capture this instance, leave it False (default: False)

@@ -223,7 +227,7 @@ def adamw(
     state_steps: List[Tensor],
     # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
     # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
-    foreach: bool = None,
+    foreach: Optional[bool] = None,
     capturable: bool = False,
     differentiable: bool = False,
     *,

@@ -245,9 +249,11 @@ def adamw(
             "API has changed, `state_steps` argument must contain a list of singleton tensors"
         )

+    # Respect when the user inputs False/True for foreach.
     if foreach is None:
-        # Placeholder for more complex foreach logic to be added when value is not set
-        foreach = False
+        foreach = _default_to_foreach(
+            [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps],
+            differentiable=differentiable)

     if foreach and torch.jit.is_scripting():
         raise RuntimeError("torch.jit.script not supported with foreach optimizers")
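
The body of _default_to_foreach lives in torch/optim/optimizer.py and is not part of this diff. The sketch below is only an assumption about the kind of check such a helper performs, inferred from the commit title (CUDA + differentiable=False) and the docstring change above; the name suffix and details are illustrative, not the real implementation:

from typing import List

import torch

def _default_to_foreach_sketch(tensorlists: List[List[torch.Tensor]],
                               differentiable: bool = False) -> bool:
    # Assumed logic: never default to foreach when higher-order gradients are
    # requested (the foreach path asserts `not differentiable`) or while
    # scripting, and only do so when every tensor involved lives on CUDA.
    if differentiable or torch.jit.is_scripting():
        return False
    return all(t.is_cuda for tensorlist in tensorlists for t in tensorlist)
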
@@ -401,92 +407,97 @@ def _multi_tensor_adamw(
         p.is_cuda and step.is_cuda for p, step in zip(params, state_steps)
     ), "If capturable=True, params and state_steps must be CUDA tensors."

-    if maximize:
-        grads = torch._foreach_neg(tuple(grads))  # type: ignore[assignment]
-
     assert not differentiable, "_foreach ops don't support autograd"

-    grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in grads]
-    exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_avgs]
-    exp_avg_sqs = [
-        torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_avg_sqs
-    ]
-    params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in params]
+    grouped_tensors = _group_tensors_by_device_and_dtype([
+        params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
+    for (device_params, device_grads, device_exp_avgs, device_exp_avg_sqs,
+         device_max_exp_avg_sqs, device_state_steps) in grouped_tensors.values():
+        if maximize:
+            device_grads = torch._foreach_neg(tuple(device_grads))  # type: ignore[assignment]

-    # update steps
-    torch._foreach_add_(state_steps, 1)

-    # Perform stepweight decay
-    torch._foreach_mul_(params, 1 - lr * weight_decay)
+        device_grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_grads]
+        device_exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_exp_avgs]
+        device_exp_avg_sqs = [
+            torch.view_as_real(x) if torch.is_complex(x) else x for x in device_exp_avg_sqs
+        ]
+        device_params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_params]

-    # Decay the first and second moment running average coefficient
-    torch._foreach_mul_(exp_avgs, beta1)
-    torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)
+        # update steps
+        torch._foreach_add_(device_state_steps, 1)

-    torch._foreach_mul_(exp_avg_sqs, beta2)
-    torch._foreach_addcmul_(exp_avg_sqs, grads, grads, 1 - beta2)
+        # Perform stepweight decay
+        torch._foreach_mul_(device_params, 1 - lr * weight_decay)

-    if capturable:
-        # TODO: use foreach_pow if/when foreach_pow is added
-        bias_correction1 = [torch.pow(beta1, step) for step in state_steps]
-        bias_correction2 = [torch.pow(beta2, step) for step in state_steps]
-        # foreach_sub doesn't allow a scalar as the first arg
-        torch._foreach_sub_(bias_correction1, 1)
-        torch._foreach_sub_(bias_correction2, 1)
-        torch._foreach_neg_(bias_correction1)
-        torch._foreach_neg_(bias_correction2)
-
-        # foreach_div doesn't allow a scalar as the first arg
-        step_size = torch._foreach_div(bias_correction1, lr)
-        torch._foreach_reciprocal_(step_size)
-        torch._foreach_neg_(step_size)
-
-        bias_correction2_sqrt = torch._foreach_sqrt(bias_correction2)
-
-        if amsgrad:
-            # Maintains the maximum of all 2nd moment running avg. till now
-            torch._foreach_maximum_(max_exp_avg_sqs, exp_avg_sqs)
-
-            # Use the max. for normalizing running avg. of gradient
-            max_exp_avg_sq_sqrt = torch._foreach_sqrt(max_exp_avg_sqs)
-            # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
-            # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
-            torch._foreach_div_(
-                max_exp_avg_sq_sqrt,
-                torch._foreach_mul(bias_correction2_sqrt, step_size),
-            )
-            eps_over_step_size = torch._foreach_div(step_size, eps)
-            torch._foreach_reciprocal_(eps_over_step_size)
-            denom = torch._foreach_add(max_exp_avg_sq_sqrt, eps_over_step_size)
-        else:
-            exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sqs)
-            torch._foreach_div_(
-                exp_avg_sq_sqrt, torch._foreach_mul(bias_correction2_sqrt, step_size)
-            )
-            eps_over_step_size = torch._foreach_div(step_size, eps)
-            torch._foreach_reciprocal_(eps_over_step_size)
-            denom = torch._foreach_add(exp_avg_sq_sqrt, eps_over_step_size)
+        # Decay the first and second moment running average coefficient
+        torch._foreach_mul_(device_exp_avgs, beta1)
+        torch._foreach_add_(device_exp_avgs, device_grads, alpha=1 - beta1)

-        torch._foreach_addcdiv_(params, exp_avgs, denom)
-    else:
-        bias_correction1 = [1 - beta1 ** _get_value(step) for step in state_steps]
-        bias_correction2 = [1 - beta2 ** _get_value(step) for step in state_steps]
+        torch._foreach_mul_(device_exp_avg_sqs, beta2)
+        torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads, 1 - beta2)

-        step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])
+        if capturable:
+            # TODO: use foreach_pow if/when foreach_pow is added
+            bias_correction1 = [torch.pow(beta1, step) for step in device_state_steps]
+            bias_correction2 = [torch.pow(beta2, step) for step in device_state_steps]
+            # foreach_sub doesn't allow a scalar as the first arg
+            torch._foreach_sub_(bias_correction1, 1)
+            torch._foreach_sub_(bias_correction2, 1)
+            torch._foreach_neg_(bias_correction1)
+            torch._foreach_neg_(bias_correction2)
+
+            # foreach_div doesn't allow a scalar as the first arg
+            step_size = torch._foreach_div(bias_correction1, lr)
+            torch._foreach_reciprocal_(step_size)
+            torch._foreach_neg_(step_size)
+
+            bias_correction2_sqrt = torch._foreach_sqrt(bias_correction2)

-        bias_correction2_sqrt = [_dispatch_sqrt(bc) for bc in bias_correction2]
+            if amsgrad:
+                # Maintains the maximum of all 2nd moment running avg. till now
+                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)

-        if amsgrad:
-            # Maintains the maximum of all 2nd moment running avg. till now
-            torch._foreach_maximum_(max_exp_avg_sqs, exp_avg_sqs)
+                # Use the max. for normalizing running avg. of gradient
+                max_exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
+                # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
+                # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
+                torch._foreach_div_(
+                    max_exp_avg_sq_sqrt,
+                    torch._foreach_mul(bias_correction2_sqrt, step_size),
+                )
+                eps_over_step_size = torch._foreach_div(step_size, eps)
+                torch._foreach_reciprocal_(eps_over_step_size)
+                denom = torch._foreach_add(max_exp_avg_sq_sqrt, eps_over_step_size)
+            else:
+                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
+                torch._foreach_div_(
+                    exp_avg_sq_sqrt, torch._foreach_mul(bias_correction2_sqrt, step_size)
+                )
+                eps_over_step_size = torch._foreach_div(step_size, eps)
+                torch._foreach_reciprocal_(eps_over_step_size)
+                denom = torch._foreach_add(exp_avg_sq_sqrt, eps_over_step_size)

-            # Use the max. for normalizing running avg. of gradient
-            max_exp_avg_sq_sqrt = torch._foreach_sqrt(max_exp_avg_sqs)
-            torch._foreach_div_(max_exp_avg_sq_sqrt, bias_correction2_sqrt)
-            denom = torch._foreach_add(max_exp_avg_sq_sqrt, eps)
+            torch._foreach_addcdiv_(device_params, device_exp_avgs, denom)
         else:
-            exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sqs)
-            torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
-            denom = torch._foreach_add(exp_avg_sq_sqrt, eps)
+            bias_correction1 = [1 - beta1 ** _get_value(step) for step in device_state_steps]
+            bias_correction2 = [1 - beta2 ** _get_value(step) for step in device_state_steps]
+
+            step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])
+
+            bias_correction2_sqrt = [_dispatch_sqrt(bc) for bc in bias_correction2]
+
+            if amsgrad:
+                # Maintains the maximum of all 2nd moment running avg. till now
+                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)
+
+                # Use the max. for normalizing running avg. of gradient
+                max_exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
+                torch._foreach_div_(max_exp_avg_sq_sqrt, bias_correction2_sqrt)
+                denom = torch._foreach_add(max_exp_avg_sq_sqrt, eps)
+            else:
+                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
+                torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
+                denom = torch._foreach_add(exp_avg_sq_sqrt, eps)

-    torch._foreach_addcdiv_(params, exp_avgs, denom, step_size)
+        torch._foreach_addcdiv_(device_params, device_exp_avgs, denom, step_size)
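
The other structural change in this hunk is that the multi-tensor path now walks groups of tensors sharing a device and dtype (via the newly imported _group_tensors_by_device_and_dtype), since each torch._foreach_* call needs homogeneous inputs. Below is a small stand-in sketch of that grouping idea using a plain dict instead of the private helper; the tensors, dtypes, and the update it applies are made up for illustration:

from collections import defaultdict

import torch

params = [torch.zeros(3), torch.zeros(2, dtype=torch.float64)]
grads = [torch.ones(3), torch.ones(2, dtype=torch.float64)]

# Bucket matching (param, grad) pairs by (device, dtype) so each bucket can be
# handed to a single fused foreach call, mirroring the per-device loop above.
buckets = defaultdict(lambda: ([], []))
for p, g in zip(params, grads):
    bucket_params, bucket_grads = buckets[(p.device, p.dtype)]
    bucket_params.append(p)
    bucket_grads.append(g)

for (device, dtype), (bucket_params, bucket_grads) in buckets.items():
    # One foreach call per homogeneous bucket: params -= 0.1 * grads.
    torch._foreach_add_(bucket_params, bucket_grads, alpha=-0.1)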
