import torch
from torch import Tensor
-from .optimizer import Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt, _stack_if_compiling
+from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt,
+                        _stack_if_compiling, _default_to_foreach)
from typing import List, Optional
+from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

__all__ = ["AdamW", "adamw"]

@@ -60,8 +62,10 @@ class AdamW(Optimizer):
            (default: False)
        maximize (bool, optional): maximize the params based on the objective, instead of
            minimizing (default: False)
-        foreach (bool, optional): whether foreach implementation of optimizer
-            is used (default: None)
+        foreach (bool, optional): whether foreach implementation of optimizer is used.
+            If unspecified by the user (so foreach is None), we will try to use foreach
+            over the for-loop implementation on CUDA, since it is usually significantly
+            more performant. (default: None)
        capturable (bool, optional): whether this instance is safe to capture in a CUDA graph.
            Passing True can impair ungraphed performance, so if you don't intend to
            graph capture this instance, leave it False (default: False)
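
As a quick illustration of the `foreach` behaviour documented above (the model and hyperparameters below are made up for the example, not part of the patch): leaving `foreach=None` lets the optimizer choose the implementation, and on CUDA it will now prefer the multi-tensor path, while `foreach=False` forces the plain for-loop implementation.

```python
import torch

model = torch.nn.Linear(10, 2)  # hypothetical model

# foreach=None (the default): the optimizer chooses, preferring foreach on CUDA.
# foreach=False: always use the single-tensor for-loop implementation.
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2, foreach=None)
```
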
@@ -223,7 +227,7 @@ def adamw(
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
-    foreach: bool = None,
+    foreach: Optional[bool] = None,
    capturable: bool = False,
    differentiable: bool = False,
    *,
@@ -245,9 +249,11 @@ def adamw(
245249 "API has changed, `state_steps` argument must contain a list of singleton tensors"
246250 )
247251
252+ # Respect when the user inputs False/True for foreach.
248253 if foreach is None :
249- # Placeholder for more complex foreach logic to be added when value is not set
250- foreach = False
254+ foreach = _default_to_foreach (
255+ [params , grads , exp_avgs , exp_avg_sqs , max_exp_avg_sqs , state_steps ],
256+ differentiable = differentiable )
251257
252258 if foreach and torch .jit .is_scripting ():
253259 raise RuntimeError ("torch.jit.script not supported with foreach optimizers" )
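
The defaulting is delegated to `_default_to_foreach`. Purely as a hedged sketch of the heuristic the new docstring describes (this is not the real helper, just an illustration of the decision it makes):

```python
from typing import List
import torch

def default_to_foreach_sketch(tensorlists: List[List[torch.Tensor]],
                              differentiable: bool = False) -> bool:
    # Sketch only: skip foreach when autograd through the optimizer step is
    # requested; otherwise use it when every optimizer state tensor lives on
    # CUDA, where the fused multi-tensor kernels are usually much faster.
    if differentiable:
        return False
    return all(t.is_cuda for tensors in tensorlists for t in tensors)
```
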
@@ -401,92 +407,97 @@ def _multi_tensor_adamw(
            p.is_cuda and step.is_cuda for p, step in zip(params, state_steps)
        ), "If capturable=True, params and state_steps must be CUDA tensors."

-    if maximize:
-        grads = torch._foreach_neg(tuple(grads))  # type: ignore[assignment]
-
    assert not differentiable, "_foreach ops don't support autograd"

-    grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in grads]
-    exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_avgs]
-    exp_avg_sqs = [
-        torch.view_as_real(x) if torch.is_complex(x) else x for x in exp_avg_sqs
-    ]
-    params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in params]
+    grouped_tensors = _group_tensors_by_device_and_dtype([
+        params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
+    for (device_params, device_grads, device_exp_avgs, device_exp_avg_sqs,
+         device_max_exp_avg_sqs, device_state_steps) in grouped_tensors.values():
+        if maximize:
+            device_grads = torch._foreach_neg(tuple(device_grads))  # type: ignore[assignment]

-    # update steps
-    torch._foreach_add_(state_steps, 1)

-    # Perform stepweight decay
-    torch._foreach_mul_(params, 1 - lr * weight_decay)
+        device_grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_grads]
+        device_exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_exp_avgs]
+        device_exp_avg_sqs = [
+            torch.view_as_real(x) if torch.is_complex(x) else x for x in device_exp_avg_sqs
+        ]
+        device_params = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_params]

-    # Decay the first and second moment running average coefficient
-    torch._foreach_mul_(exp_avgs, beta1)
-    torch._foreach_add_(exp_avgs, grads, alpha=1 - beta1)
+        # update steps
+        torch._foreach_add_(device_state_steps, 1)

-    torch._foreach_mul_(exp_avg_sqs, beta2)
-    torch._foreach_addcmul_(exp_avg_sqs, grads, grads, 1 - beta2)
+        # Perform stepweight decay
+        torch._foreach_mul_(device_params, 1 - lr * weight_decay)

-    if capturable:
-        # TODO: use foreach_pow if/when foreach_pow is added
-        bias_correction1 = [torch.pow(beta1, step) for step in state_steps]
-        bias_correction2 = [torch.pow(beta2, step) for step in state_steps]
-        # foreach_sub doesn't allow a scalar as the first arg
-        torch._foreach_sub_(bias_correction1, 1)
-        torch._foreach_sub_(bias_correction2, 1)
-        torch._foreach_neg_(bias_correction1)
-        torch._foreach_neg_(bias_correction2)
-
-        # foreach_div doesn't allow a scalar as the first arg
-        step_size = torch._foreach_div(bias_correction1, lr)
-        torch._foreach_reciprocal_(step_size)
-        torch._foreach_neg_(step_size)
-
-        bias_correction2_sqrt = torch._foreach_sqrt(bias_correction2)
-
-        if amsgrad:
-            # Maintains the maximum of all 2nd moment running avg. till now
-            torch._foreach_maximum_(max_exp_avg_sqs, exp_avg_sqs)
-
-            # Use the max. for normalizing running avg. of gradient
-            max_exp_avg_sq_sqrt = torch._foreach_sqrt(max_exp_avg_sqs)
-            # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
-            # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
-            torch._foreach_div_(
-                max_exp_avg_sq_sqrt,
-                torch._foreach_mul(bias_correction2_sqrt, step_size),
-            )
-            eps_over_step_size = torch._foreach_div(step_size, eps)
-            torch._foreach_reciprocal_(eps_over_step_size)
-            denom = torch._foreach_add(max_exp_avg_sq_sqrt, eps_over_step_size)
-        else:
-            exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sqs)
-            torch._foreach_div_(
-                exp_avg_sq_sqrt, torch._foreach_mul(bias_correction2_sqrt, step_size)
-            )
-            eps_over_step_size = torch._foreach_div(step_size, eps)
-            torch._foreach_reciprocal_(eps_over_step_size)
-            denom = torch._foreach_add(exp_avg_sq_sqrt, eps_over_step_size)
+        # Decay the first and second moment running average coefficient
+        torch._foreach_mul_(device_exp_avgs, beta1)
+        torch._foreach_add_(device_exp_avgs, device_grads, alpha=1 - beta1)

-        torch._foreach_addcdiv_(params, exp_avgs, denom)
-    else:
-        bias_correction1 = [1 - beta1 ** _get_value(step) for step in state_steps]
-        bias_correction2 = [1 - beta2 ** _get_value(step) for step in state_steps]
+        torch._foreach_mul_(device_exp_avg_sqs, beta2)
+        torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads, 1 - beta2)

-        step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])
+        if capturable:
+            # TODO: use foreach_pow if/when foreach_pow is added
+            bias_correction1 = [torch.pow(beta1, step) for step in device_state_steps]
+            bias_correction2 = [torch.pow(beta2, step) for step in device_state_steps]
+            # foreach_sub doesn't allow a scalar as the first arg
+            torch._foreach_sub_(bias_correction1, 1)
+            torch._foreach_sub_(bias_correction2, 1)
+            torch._foreach_neg_(bias_correction1)
+            torch._foreach_neg_(bias_correction2)
+
+            # foreach_div doesn't allow a scalar as the first arg
+            step_size = torch._foreach_div(bias_correction1, lr)
+            torch._foreach_reciprocal_(step_size)
+            torch._foreach_neg_(step_size)
+
+            bias_correction2_sqrt = torch._foreach_sqrt(bias_correction2)

-        bias_correction2_sqrt = [_dispatch_sqrt(bc) for bc in bias_correction2]
+            if amsgrad:
+                # Maintains the maximum of all 2nd moment running avg. till now
+                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)

-        if amsgrad:
-            # Maintains the maximum of all 2nd moment running avg. till now
-            torch._foreach_maximum_(max_exp_avg_sqs, exp_avg_sqs)
+                # Use the max. for normalizing running avg. of gradient
+                max_exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
+                # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
+                # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
+                torch._foreach_div_(
+                    max_exp_avg_sq_sqrt,
+                    torch._foreach_mul(bias_correction2_sqrt, step_size),
+                )
+                eps_over_step_size = torch._foreach_div(step_size, eps)
+                torch._foreach_reciprocal_(eps_over_step_size)
+                denom = torch._foreach_add(max_exp_avg_sq_sqrt, eps_over_step_size)
+            else:
+                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
+                torch._foreach_div_(
+                    exp_avg_sq_sqrt, torch._foreach_mul(bias_correction2_sqrt, step_size)
+                )
+                eps_over_step_size = torch._foreach_div(step_size, eps)
+                torch._foreach_reciprocal_(eps_over_step_size)
+                denom = torch._foreach_add(exp_avg_sq_sqrt, eps_over_step_size)

-            # Use the max. for normalizing running avg. of gradient
-            max_exp_avg_sq_sqrt = torch._foreach_sqrt(max_exp_avg_sqs)
-            torch._foreach_div_(max_exp_avg_sq_sqrt, bias_correction2_sqrt)
-            denom = torch._foreach_add(max_exp_avg_sq_sqrt, eps)
+            torch._foreach_addcdiv_(device_params, device_exp_avgs, denom)
        else:
-            exp_avg_sq_sqrt = torch._foreach_sqrt(exp_avg_sqs)
-            torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
-            denom = torch._foreach_add(exp_avg_sq_sqrt, eps)
+            bias_correction1 = [1 - beta1 ** _get_value(step) for step in device_state_steps]
+            bias_correction2 = [1 - beta2 ** _get_value(step) for step in device_state_steps]
+
+            step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])
+
+            bias_correction2_sqrt = [_dispatch_sqrt(bc) for bc in bias_correction2]
+
+            if amsgrad:
+                # Maintains the maximum of all 2nd moment running avg. till now
+                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)
+
+                # Use the max. for normalizing running avg. of gradient
+                max_exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
+                torch._foreach_div_(max_exp_avg_sq_sqrt, bias_correction2_sqrt)
+                denom = torch._foreach_add(max_exp_avg_sq_sqrt, eps)
+            else:
+                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
+                torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
+                denom = torch._foreach_add(exp_avg_sq_sqrt, eps)

-    torch._foreach_addcdiv_(params, exp_avgs, denom, step_size)
+            torch._foreach_addcdiv_(device_params, device_exp_avgs, denom, step_size)
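
For readers unfamiliar with `_group_tensors_by_device_and_dtype`: the point of the new loop is that each `torch._foreach_*` call only receives tensors that share a device (and dtype), which is what the fused multi-tensor kernels expect. A self-contained sketch of that grouping idea, as an illustration rather than the private helper's actual signature, could look like:

```python
from collections import defaultdict
from typing import Dict, List, Tuple
import torch

def group_by_device_and_dtype(
    *tensorlists: List[torch.Tensor],
) -> Dict[Tuple[torch.device, torch.dtype], List[List[torch.Tensor]]]:
    # Bucket parallel tensor lists by the (device, dtype) of the leading list,
    # so each bucket can be handed to torch._foreach_* ops as in the loop above.
    grouped = defaultdict(lambda: [[] for _ in tensorlists])
    for entries in zip(*tensorlists):
        key = (entries[0].device, entries[0].dtype)
        for bucket, tensor in zip(grouped[key], entries):
            bucket.append(tensor)
    return grouped
```

Each value of the returned mapping would then be unpacked into per-device lists, exactly as the patch does with `grouped_tensors.values()`.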