Commit c8f0206

[optim][nadam] group tensors in foreach, make it default
[ghstack-poisoned]
1 parent a592359 commit c8f0206

File tree: 2 files changed, +40 -35 lines

2 files changed

+40
-35
lines changed

torch/optim/nadam.py

Lines changed: 37 additions & 32 deletions
@@ -1,8 +1,9 @@
 import torch
 from torch import Tensor
 from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt, _stack_if_compiling,
-                        _differentiable_doc)
+                        _differentiable_doc, _foreach_doc, _default_to_foreach)
 from typing import List, Optional
+from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
 
 __all__ = ['NAdam', 'nadam']
 
@@ -147,14 +148,13 @@ def step(self, closure=None):
             numerical stability (default: 1e-8)
         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
         momentum_decay (float, optional): momentum momentum_decay (default: 4e-3)
-        foreach (bool, optional): whether foreach implementation of optimizer
-            is used (default: None)
+        {foreach}
         {differentiable}
 
     .. _Incorporating Nesterov Momentum into Adam:
         https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ
 
-    """.format(differentiable=_differentiable_doc)
+    """.format(foreach=_foreach_doc, differentiable=_differentiable_doc)
 
 
 def nadam(params: List[Tensor],
@@ -165,7 +165,7 @@ def nadam(params: List[Tensor],
           state_steps: List[Tensor],
           # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
           # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
-          foreach: bool = None,
+          foreach: Optional[bool] = None,
           differentiable: bool = False,
           *,
           beta1: float,
@@ -187,8 +187,8 @@ def nadam(params: List[Tensor],
         raise RuntimeError("API has changed, `mu_products` argument must contain a list of singleton tensors")
 
     if foreach is None:
-        # Placeholder for more complex foreach logic to be added when value is not set
-        foreach = False
+        foreach = _default_to_foreach([params, grads, exp_avgs, exp_avg_sqs,
+                                       mu_products, state_steps], differentiable=differentiable)
 
     if foreach and torch.jit.is_scripting():
         raise RuntimeError('torch.jit.script not supported with foreach optimizers')
@@ -292,36 +292,41 @@ def _multi_tensor_nadam(params: List[Tensor],
 
     assert not differentiable, "_foreach ops don't support autograd"
 
-    # update steps
-    torch._foreach_add_(state_steps, 1)
+    grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs,
+                                                          mu_products, state_steps])
+    for (grouped_params, grouped_grads, grouped_exp_avgs,
+         grouped_exp_avg_sqs, grouped_mu_products, grouped_state_steps) in grouped_tensors.values():
 
-    bias_correction2 = [1 - beta2 ** _get_value(step) for step in state_steps]
-    mus = [beta1 * (1. - 0.5 * (0.96 ** (_get_value(step) * momentum_decay))) for step in state_steps]
-    mu_nexts = [beta1 * (1. - 0.5 * (0.96 ** ((_get_value(step) + 1) * momentum_decay)))
-                for step in state_steps]
+        # update steps
+        torch._foreach_add_(grouped_state_steps, 1)
 
-    # update mu_products
-    torch._foreach_mul_(mu_products, mus)
+        bias_correction2 = [1 - beta2 ** _get_value(step) for step in grouped_state_steps]
+        mus = [beta1 * (1. - 0.5 * (0.96 ** (_get_value(step) * momentum_decay))) for step in grouped_state_steps]
+        mu_nexts = [beta1 * (1. - 0.5 * (0.96 ** ((_get_value(step) + 1) * momentum_decay)))
+                    for step in grouped_state_steps]
 
-    if weight_decay != 0:
-        grads = torch._foreach_add(grads, params, alpha=weight_decay)
+        # update mu_products
+        torch._foreach_mul_(grouped_mu_products, mus)
 
-    # Decay the first and second moment running average coefficient
-    torch._foreach_mul_(exp_avgs, beta1)
-    torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)
+        if weight_decay != 0:
+            grouped_grads = torch._foreach_add(grouped_grads, grouped_params, alpha=weight_decay)
+
+        # Decay the first and second moment running average coefficient
+        torch._foreach_mul_(grouped_exp_avgs, beta1)
+        torch._foreach_add_(grouped_exp_avgs, grouped_grads, alpha=1 - beta1)
 
-    torch._foreach_mul_(exp_avg_sqs, beta2)
-    torch._foreach_addcmul_(exp_avg_sqs, grads, grads, 1 - beta2)
+        torch._foreach_mul_(grouped_exp_avg_sqs, beta2)
+        torch._foreach_addcmul_(grouped_exp_avg_sqs, grouped_grads, grouped_grads, 1 - beta2)
 
-    exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sqs)
-    bias_correction_sqrt = [_dispatch_sqrt(bc) for bc in bias_correction2]
-    torch._foreach_div_(exp_avg_sq_sqrt, bias_correction_sqrt)
-    denom = torch._foreach_add(exp_avg_sq_sqrt, eps)
+        exp_avg_sq_sqrt = torch._foreach_sqrt(grouped_exp_avg_sqs)
+        bias_correction_sqrt = [_dispatch_sqrt(bc) for bc in bias_correction2]
+        torch._foreach_div_(exp_avg_sq_sqrt, bias_correction_sqrt)
+        denom = torch._foreach_add(exp_avg_sq_sqrt, eps)
 
-    step_size_grads = _stack_if_compiling([(lr * (1. - mu) / (1. - _get_value(mu_product))) * -1
-                                           for mu_product, mu in zip(mu_products, mus)])
-    step_size_expavg = _stack_if_compiling([(lr * mu_next / (1. - _get_value(mu_product) * mu_next)) * -1
-                                            for mu_product, mu_next in zip(mu_products, mu_nexts)])
+        step_size_grads = _stack_if_compiling([(lr * (1. - mu) / (1. - _get_value(mu_product))) * -1
+                                               for mu_product, mu in zip(grouped_mu_products, mus)])
+        step_size_expavg = _stack_if_compiling([(lr * mu_next / (1. - _get_value(mu_product) * mu_next)) * -1
+                                                for mu_product, mu_next in zip(grouped_mu_products, mu_nexts)])
 
-    torch._foreach_addcdiv_(params, grads, denom, step_size_grads)
-    torch._foreach_addcdiv_(params, exp_avgs, denom, step_size_expavg)
+        torch._foreach_addcdiv_(grouped_params, grouped_grads, denom, step_size_grads)
+        torch._foreach_addcdiv_(grouped_params, grouped_exp_avgs, denom, step_size_expavg)
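
Context for the hunk above: the multi-tensor path now buckets the six parallel state lists by (device, dtype) before issuing any torch._foreach_* calls, because those fused ops require every tensor in a single call to share a device and dtype. The sketch below only illustrates that bucketing idea; it is not the private _group_tensors_by_device_and_dtype utility itself, and the helper name group_by_device_and_dtype plus the toy tensors are made up for the example.

from collections import defaultdict
from typing import Dict, List, Tuple

import torch


def group_by_device_and_dtype(
    tensorlists: List[List[torch.Tensor]],
) -> Dict[Tuple[torch.device, torch.dtype], List[List[torch.Tensor]]]:
    # Bucket parallel tensor lists so each bucket is homogeneous in (device, dtype).
    grouped = defaultdict(lambda: [[] for _ in tensorlists])
    # The first list (the params) picks the bucket; the parallel lists follow it.
    for i, p in enumerate(tensorlists[0]):
        key = (p.device, p.dtype)
        for j, tensorlist in enumerate(tensorlists):
            grouped[key][j].append(tensorlist[i])
    return dict(grouped)


# Toy usage: mixed-dtype params land in separate buckets, and each bucket can
# then go through a single fused foreach call.
params = [torch.randn(3), torch.randn(3, dtype=torch.float64)]
grads = [torch.randn_like(p) for p in params]
for (device, dtype), (ps, gs) in group_by_device_and_dtype([params, grads]).items():
    torch._foreach_add_(ps, gs, alpha=0.5)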

torch/optim/optimizer.py

Lines changed: 3 additions & 3 deletions
@@ -59,12 +59,12 @@ def _dispatch_sqrt(x: float):  # float annotation is needed because of torchscript
 # it is faster than the for-loop implementation. However, the foreach
 # implementation is not differentiable, so we must check differentiable=False.
 def _default_to_foreach(tensorlists: List[List[torch.Tensor]], differentiable: bool = False) -> bool:
+    if torch.jit.is_scripting() or differentiable:
+        return False
     all_tensors = []
     for tensorlist in tensorlists:
         all_tensors.extend(tensorlist)
-    return not torch.jit.is_scripting() and not differentiable and all(
-        p.is_cuda for p in all_tensors
-    )
+    return all(p.is_cuda for p in all_tensors)
 
 
 # Common doc strings among optimizers
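
With this commit, leaving foreach at its None default makes nadam() resolve the flag via _default_to_foreach: the fused multi-tensor path is chosen automatically when torch.jit.script is not involved, differentiable is False, and every tensor lives on CUDA; otherwise it falls back to the single-tensor for-loop. A minimal usage sketch, assuming only the public torch.optim.NAdam constructor (the model and shapes here are arbitrary):

import torch

model = torch.nn.Linear(10, 10)
if torch.cuda.is_available():
    # All-CUDA params (and state) means the foreach path is now the default.
    model = model.cuda()

opt = torch.optim.NAdam(model.parameters(), lr=2e-3)  # foreach left as None

x = torch.randn(4, 10, device=next(model.parameters()).device)
loss = model(x).sum()
loss.backward()
opt.step()  # on CUDA this now takes the foreach path; on CPU it stays on the for-loop path

Passing foreach=False still forces the for-loop implementation, and foreach=True still opts in explicitly, regardless of this new default.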
