from collections import defaultdict
from typing import cast, List, Optional, Dict, Tuple
-import warnings
-import itertools

import torch
from torch import Tensor
@@ -110,9 +108,9 @@ class Adam(Optimizer):
        fused (bool, optional): whether the fused implementation (CUDA only) is used.
            Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
            are supported. Since the fused implementation is usually significantly faster than
-            the for-loop implementation, we default to using it whenever possible (all
-            parameters are on CUDA and are of a supported type. Else, we fall back to the
-            for-loop implementation. (default: True)
+            the for-loop implementation, we try to use it whenever possible (all parameters
+            are on CUDA and are of a supported type). Else, we continue with the for-loop
+            implementation. (default: None)

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
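
To make the documented behaviour concrete, here is a minimal usage sketch, assuming a CUDA-enabled build with this change applied; the toy model, sizes, and learning rate are illustrative and not part of the patch.

import torch

model = torch.nn.Linear(4, 2)
if torch.cuda.is_available():
    model = model.cuda()

# Leaving `fused` unset (None) lets the optimizer decide at step time: the fused
# CUDA kernel is used only when every parameter is a floating-point CUDA tensor
# and differentiable=False; otherwise the for-loop implementation is used.
opt_auto = torch.optim.Adam(model.parameters(), lr=1e-3)

# Passing fused=True is an explicit opt-in that is validated eagerly in __init__
# (see the hunk below) rather than being silently downgraded with a warning.
if torch.cuda.is_available():
    opt_fused = torch.optim.Adam(model.parameters(), lr=1e-3, fused=True)
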
@@ -123,7 +121,7 @@ class Adam(Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False, *, foreach: Optional[bool] = None,
                 maximize: bool = False, capturable: bool = False,
-                 differentiable: bool = False, fused: bool = True):
+                 differentiable: bool = False, fused: Optional[bool] = None):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
@@ -135,42 +133,25 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

-        def all_params(params, lambda_fn):
-            if isinstance(params, Tensor):
-                return lambda_fn(params)
-            if isinstance(params, dict):
-                return all_params(params.values(), lambda_fn)
-            # should be an iterable, unless it sets a default, in which case it's not relevant 🤷🏻♀️
-            try:
-                return all([all_params(p, lambda_fn) for p in params])
-            except TypeError:
-                return True
-
-        params, params_copy = itertools.tee(params)
-
-        # The fused implementation is fastest but is only available when the parameters are floats on CUDA.
-        # The fused implementation is also not differentiable. We default back to for-loop impl in both cases.
-        if fused:
-            if differentiable:
-                fused = False
-                warnings.warn("`fused` cannot be `differentiable`, falling back to for-loop implementation")
-            elif not all_params(params_copy, lambda p: p.is_cuda and torch.is_floating_point(p)):
-                fused = False
-                warnings.warn("FusedAdam requires all the params to be CUDA, floating point. "
-                              "Falling back to for-loop implementation")
-
        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad,
                        maximize=maximize, foreach=foreach, capturable=capturable,
                        differentiable=differentiable, fused=fused)
        super(Adam, self).__init__(params, defaults)

        if fused:
+            if differentiable:
+                raise RuntimeError("`fused` cannot be `differentiable`")
+            self._step_supports_amp_scaling = True
            # TODO(crcrpar): [low prec params & their higher prec copy]
            # Support AMP with FP16/BF16 model params which would need
            # higher prec copy of params to do update math in higher prec to
            # alleviate the loss of information.
-            self._step_supports_amp_scaling = True
+            if not all(
+                p.is_cuda and torch.is_floating_point(p)
+                for pg in self.param_groups for p in pg['params']
+            ):
+                raise RuntimeError("FusedAdam requires all the params to be CUDA, floating point")

    def __setstate__(self, state):
        super().__setstate__(state)
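
A short sketch of what the stricter validation above means for callers, assuming this patch is applied; the CPU-only model is made up for illustration.

import torch

cpu_model = torch.nn.Linear(4, 2)  # parameters stay on CPU

try:
    # Requesting the fused kernel for non-CUDA parameters now raises in __init__
    # instead of warning and falling back to the for-loop path as the removed
    # code did.
    torch.optim.Adam(cpu_model.parameters(), fused=True)
except RuntimeError as err:
    print(err)  # FusedAdam requires all the params to be CUDA, floating point
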
@@ -311,7 +292,7 @@ def adam(params: List[Tensor],
         foreach: Optional[bool] = None,
         capturable: bool = False,
         differentiable: bool = False,
-         fused: bool = False,
+         fused: Optional[bool] = None,
         grad_scale: Optional[_MultiDeviceReplicator] = None,
         found_inf: Optional[_MultiDeviceReplicator] = None,
         *,
@@ -326,6 +307,17 @@ def adam(params: List[Tensor],
    See :class:`~torch.optim.Adam` for details.
    """

+    # We try to use the fused implementation whenever we can since it is fastest.
+    # It's only available when the tensors are floats on the same CUDA device
+    # and when differentiable=False.
+    # We still respect when the user inputs False for fused.
+    if fused is None:
+        if not differentiable and all(
+            p.is_cuda and torch.is_floating_point(p)
+            for p in params + grads + exp_avgs + exp_avg_sqs + max_exp_avg_sqs + state_steps
+        ):
+            fused = True
+
    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")

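For reference, the opt-in condition added to the functional adam() can be summarized by a standalone predicate like the one below; the helper name _would_use_fused is invented here for illustration and does not exist in torch.

from typing import List

import torch
from torch import Tensor


def _would_use_fused(tensors: List[Tensor], differentiable: bool) -> bool:
    # Mirrors the new `fused is None` branch: opt into the fused kernel only when
    # autograd through the optimizer step is not requested and every param, grad,
    # and state tensor is a floating-point tensor on CUDA.
    return not differentiable and all(
        t.is_cuda and torch.is_floating_point(t) for t in tensors
    )


print(_would_use_fused([torch.zeros(2)], differentiable=False))  # False: CPU tensor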