 import torch
 from torch import Tensor
 from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt, _stack_if_compiling,
-                        _differentiable_doc)
+                        _differentiable_doc, _foreach_doc, _default_to_foreach)
 from typing import List, Optional
+from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
 
 __all__ = ['NAdam', 'nadam']
 
@@ -147,14 +148,13 @@ def step(self, closure=None):
             numerical stability (default: 1e-8)
         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
         momentum_decay (float, optional): momentum momentum_decay (default: 4e-3)
-        foreach (bool, optional): whether foreach implementation of optimizer
-            is used (default: None)
+        {foreach}
         {differentiable}
 
     .. _Incorporating Nesterov Momentum into Adam:
         https://openreview.net/forum?id=OM0jvwB8jIp57ZJjtNEZ
 
-    """.format(differentiable=_differentiable_doc)
+    """.format(foreach=_foreach_doc, differentiable=_differentiable_doc)
 
 
 def nadam(params: List[Tensor],
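From the caller's side nothing changes: `foreach` remains an ordinary keyword argument of the optimizer constructor, with `None` meaning "let the implementation decide". A minimal usage sketch (the hyperparameter values here are arbitrary, not recommendations):

```python
import torch

model = torch.nn.Linear(10, 1)
# foreach=None defers the choice to the optimizer; foreach=True forces the
# multi-tensor (_foreach_*) implementation, foreach=False the per-tensor loop.
optim = torch.optim.NAdam(model.parameters(), lr=2e-3, momentum_decay=4e-3, foreach=None)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
optim.step()
```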
@@ -165,7 +165,7 @@ def nadam(params: List[Tensor],
           state_steps: List[Tensor],
           # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
           # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
-          foreach: bool = None,
+          foreach: Optional[bool] = None,
           differentiable: bool = False,
           *,
           beta1: float,
@@ -187,8 +187,8 @@ def nadam(params: List[Tensor],
         raise RuntimeError("API has changed, `mu_products` argument must contain a list of singleton tensors")
 
     if foreach is None:
-        # Placeholder for more complex foreach logic to be added when value is not set
-        foreach = False
+        foreach = _default_to_foreach([params, grads, exp_avgs, exp_avg_sqs,
+                                       mu_products, state_steps], differentiable=differentiable)
 
     if foreach and torch.jit.is_scripting():
         raise RuntimeError('torch.jit.script not supported with foreach optimizers')
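With this change, leaving `foreach=None` defers the choice to `_default_to_foreach` instead of always falling back to the single-tensor path. The real heuristic lives in `torch/optim/optimizer.py`; the sketch below is only a rough, hypothetical approximation of the kind of check involved (the function name and the exact conditions are assumptions, not the helper's actual logic):

```python
from typing import List
import torch

def _should_default_to_foreach_sketch(tensorlists: List[List[torch.Tensor]],
                                      differentiable: bool) -> bool:
    # Hypothetical approximation: only pick the fused multi-tensor path when
    # the step does not need to be differentiable and every tensor is on CUDA,
    # where the _foreach_* kernels provide the main speedup.
    if differentiable or torch.jit.is_scripting():
        return False
    return all(t.is_cuda for tensors in tensorlists for t in tensors)
```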
@@ -292,36 +292,41 @@ def _multi_tensor_nadam(params: List[Tensor],
 
     assert not differentiable, "_foreach ops don't support autograd"
 
-    # update steps
-    torch._foreach_add_(state_steps, 1)
+    grouped_tensors = _group_tensors_by_device_and_dtype([params, grads, exp_avgs, exp_avg_sqs,
+                                                          mu_products, state_steps])
+    for (grouped_params, grouped_grads, grouped_exp_avgs,
+         grouped_exp_avg_sqs, grouped_mu_products, grouped_state_steps) in grouped_tensors.values():
 
-    bias_correction2 = [1 - beta2 ** _get_value(step) for step in state_steps]
-    mus = [beta1 * (1. - 0.5 * (0.96 ** (_get_value(step) * momentum_decay))) for step in state_steps]
-    mu_nexts = [beta1 * (1. - 0.5 * (0.96 ** ((_get_value(step) + 1) * momentum_decay)))
-                for step in state_steps]
+        # update steps
+        torch._foreach_add_(grouped_state_steps, 1)
 
-    # update mu_products
-    torch._foreach_mul_(mu_products, mus)
+        bias_correction2 = [1 - beta2 ** _get_value(step) for step in grouped_state_steps]
+        mus = [beta1 * (1. - 0.5 * (0.96 ** (_get_value(step) * momentum_decay))) for step in grouped_state_steps]
+        mu_nexts = [beta1 * (1. - 0.5 * (0.96 ** ((_get_value(step) + 1) * momentum_decay)))
+                    for step in grouped_state_steps]
 
-    if weight_decay != 0:
-        grads = torch._foreach_add(grads, params, alpha=weight_decay)
+        # update mu_products
+        torch._foreach_mul_(grouped_mu_products, mus)
 
-    # Decay the first and second moment running average coefficient
-    torch._foreach_mul_(exp_avgs, beta1)
-    torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)
+        if weight_decay != 0:
+            grouped_grads = torch._foreach_add(grouped_grads, grouped_params, alpha=weight_decay)
+
+        # Decay the first and second moment running average coefficient
+        torch._foreach_mul_(grouped_exp_avgs, beta1)
+        torch._foreach_add_(grouped_exp_avgs, grouped_grads, alpha=1 - beta1)
 
-    torch._foreach_mul_(exp_avg_sqs, beta2)
-    torch._foreach_addcmul_(exp_avg_sqs, grads, grads, 1 - beta2)
+        torch._foreach_mul_(grouped_exp_avg_sqs, beta2)
+        torch._foreach_addcmul_(grouped_exp_avg_sqs, grouped_grads, grouped_grads, 1 - beta2)
 
-    exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sqs)
-    bias_correction_sqrt = [_dispatch_sqrt(bc) for bc in bias_correction2]
-    torch._foreach_div_(exp_avg_sq_sqrt, bias_correction_sqrt)
-    denom = torch._foreach_add(exp_avg_sq_sqrt, eps)
+        exp_avg_sq_sqrt = torch._foreach_sqrt(grouped_exp_avg_sqs)
+        bias_correction_sqrt = [_dispatch_sqrt(bc) for bc in bias_correction2]
+        torch._foreach_div_(exp_avg_sq_sqrt, bias_correction_sqrt)
+        denom = torch._foreach_add(exp_avg_sq_sqrt, eps)
 
-    step_size_grads = _stack_if_compiling([(lr * (1. - mu) / (1. - _get_value(mu_product))) * -1
-                                           for mu_product, mu in zip(mu_products, mus)])
-    step_size_expavg = _stack_if_compiling([(lr * mu_next / (1. - _get_value(mu_product) * mu_next)) * -1
-                                            for mu_product, mu_next in zip(mu_products, mu_nexts)])
+        step_size_grads = _stack_if_compiling([(lr * (1. - mu) / (1. - _get_value(mu_product))) * -1
+                                               for mu_product, mu in zip(grouped_mu_products, mus)])
+        step_size_expavg = _stack_if_compiling([(lr * mu_next / (1. - _get_value(mu_product) * mu_next)) * -1
+                                                for mu_product, mu_next in zip(grouped_mu_products, mu_nexts)])
 
-    torch._foreach_addcdiv_(params, grads, denom, step_size_grads)
-    torch._foreach_addcdiv_(params, exp_avgs, denom, step_size_expavg)
+        torch._foreach_addcdiv_(grouped_params, grouped_grads, denom, step_size_grads)
+        torch._foreach_addcdiv_(grouped_params, grouped_exp_avgs, denom, step_size_expavg)
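The loop above relies on `_group_tensors_by_device_and_dtype` to split the flat state lists into buckets whose tensors share a device and dtype, because each `torch._foreach_*` call expects homogeneous inputs. A conceptual illustration of that grouping (an illustrative re-implementation keyed on the first list, not the helper PyTorch actually uses):

```python
from collections import defaultdict
from typing import Dict, List, Tuple
import torch

def group_by_device_and_dtype_sketch(
    tensorlists: List[List[torch.Tensor]],
) -> Dict[Tuple[torch.device, torch.dtype], List[List[torch.Tensor]]]:
    # Bucket the parallel tensor lists by the (device, dtype) of the tensors in
    # the first list, so each bucket can be handed to _foreach_* kernels at once.
    grouped: Dict[Tuple[torch.device, torch.dtype], List[List[torch.Tensor]]] = \
        defaultdict(lambda: [[] for _ in tensorlists])
    for i, ref in enumerate(tensorlists[0]):
        bucket = grouped[(ref.device, ref.dtype)]
        for j, tensors in enumerate(tensorlists):
            bucket[j].append(tensors[i])
    return dict(grouped)
```

Each bucket keeps the six lists (params, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps) aligned index-for-index, which is why the diff can unpack them directly in the `for` loop.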