@@ -47,7 +47,11 @@ class Adadelta(Optimizer):
         lr (float, optional): coefficient that scale delta before it is applied
             to the parameters (default: 1.0)
         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
-        foreach (bool, optional): whether foreach implementation of optimizer is used (default: None)
+        foreach (bool, optional): whether foreach implementation of optimizer is used.
+            Since the foreach implementation is usually significantly faster than
+            the for-loop implementation on CUDA, we try to use it whenever possible
+            (all parameters are on CUDA). Else, we continue with the for-loop
+            implementation. (default: None)
         maximize (bool, optional): maximize the params based on the objective, instead of
             minimizing (default: False)

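For context, the flag documented above is supplied at construction time. A minimal usage sketch (the toy model and tensors are illustrative; `foreach=None` simply restates the default):

```python
import torch

# Toy module purely for illustration. With foreach=None (the default), the
# optimizer picks the implementation itself: the foreach path when every
# tensor lives on CUDA, the for-loop path otherwise. Pass True/False to force one.
model = torch.nn.Linear(10, 1)
opt = torch.optim.Adadelta(model.parameters(), lr=1.0, weight_decay=0.0, foreach=None)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
opt.step()
```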
@@ -174,7 +178,7 @@ def adadelta(
     acc_deltas: List[Tensor],
     # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
     # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
-    foreach: bool = None,
+    foreach: Optional[bool] = None,
     differentiable: bool = False,
     *,
     lr: float,
@@ -188,9 +192,19 @@ def adadelta(
     See :class:`~torch.optim.Adadelta` for details.
     """

+    # We try to use the foreach implementation on CUDA whenever possible since
+    # it is faster than the for-loop implementation. However, the foreach
+    # implementation is not differentiable, so we must check differentiable=False.
+    # We still respect when the user inputs False for foreach.
     if foreach is None:
-        # Placeholder for more complex foreach logic to be added when value is not set
-        foreach = False
+        all_tensors = []
+        all_tensors.extend(params)
+        all_tensors.extend(grads)
+        all_tensors.extend(square_avgs)
+        all_tensors.extend(acc_deltas)
+        foreach = not torch.jit.is_scripting() and not differentiable and all(
+            p.is_cuda for p in all_tensors
+        )

     if foreach and torch.jit.is_scripting():
         raise RuntimeError("torch.jit.script not supported with foreach optimizers")
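The default-selection rule added in this hunk can be read as a standalone predicate. The sketch below only illustrates that rule; the helper name `_use_foreach` is hypothetical and not part of this change:

```python
import torch
from typing import List

def _use_foreach(params: List[torch.Tensor],
                 grads: List[torch.Tensor],
                 square_avgs: List[torch.Tensor],
                 acc_deltas: List[torch.Tensor],
                 differentiable: bool) -> bool:
    # Same rule as above: prefer the (faster) foreach path only when we are not
    # scripting, not differentiating through the optimizer step, and every
    # tensor involved is on CUDA.
    all_tensors = params + grads + square_avgs + acc_deltas
    return (not torch.jit.is_scripting()
            and not differentiable
            and all(t.is_cuda for t in all_tensors))
```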