
Commit d376550

janeyx99 authored and pytorchmergebot committed
[optim][adadelta] default to foreach when CUDA + differentiable=False (#91896)
Following up to #90865 and #92048.

Pull Request resolved: #91896
Approved by: https://github.com/albanD
1 parent cb67d94 commit d376550
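
In user terms, this change means Adadelta now selects the multi-tensor foreach path automatically when it can. A minimal sketch of the effect (assuming a CUDA-capable PyTorch build; the model here is purely illustrative):

    import torch

    model = torch.nn.Linear(10, 10).cuda()

    # foreach defaults to None; after this commit, Adadelta resolves it to
    # True because every parameter is on CUDA and differentiable defaults
    # to False.
    opt = torch.optim.Adadelta(model.parameters(), lr=1.0)

    # Passing foreach=False explicitly still selects the single-tensor
    # for-loop implementation.
    opt_loop = torch.optim.Adadelta(model.parameters(), lr=1.0, foreach=False)

An explicit foreach=False is still respected, as is differentiable=True; both fall back to the for-loop implementation.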

File tree

1 file changed (+18, -4 lines)


torch/optim/adadelta.py

Lines changed: 18 additions & 4 deletions
@@ -47,7 +47,11 @@ class Adadelta(Optimizer):
         lr (float, optional): coefficient that scale delta before it is applied
             to the parameters (default: 1.0)
         weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
-        foreach (bool, optional): whether foreach implementation of optimizer is used (default: None)
+        foreach (bool, optional): whether foreach implementation of optimizer is used.
+            Since the foreach implementation is usually significantly faster than
+            the for-loop implementation on CUDA, we try to use it whenever possible
+            (all parameters are on CUDA). Else, we continue with the for-loop
+            implementation. (default: None)
         maximize (bool, optional): maximize the params based on the objective, instead of
             minimizing (default: False)

@@ -174,7 +178,7 @@ def adadelta(
     acc_deltas: List[Tensor],
     # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
     # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
-    foreach: bool = None,
+    foreach: Optional[bool] = None,
     differentiable: bool = False,
     *,
     lr: float,

@@ -188,9 +192,19 @@ def adadelta(
     See :class:`~torch.optim.Adadelta` for details.
     """

+    # We try to use the foreach implementation on CUDA whenever possible since
+    # it is faster than the for-loop implementation. However, the foreach
+    # implementation is not differentiable, so we must check differentiable=False.
+    # We still respect when the user inputs False for foreach.
     if foreach is None:
-        # Placeholder for more complex foreach logic to be added when value is not set
-        foreach = False
+        all_tensors = []
+        all_tensors.extend(params)
+        all_tensors.extend(grads)
+        all_tensors.extend(square_avgs)
+        all_tensors.extend(acc_deltas)
+        foreach = not torch.jit.is_scripting() and not differentiable and all(
+            p.is_cuda for p in all_tensors
+        )

     if foreach and torch.jit.is_scripting():
         raise RuntimeError("torch.jit.script not supported with foreach optimizers")
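
For illustration, here is the selection rule from the last hunk pulled out into a standalone helper (a sketch only; the name _default_to_foreach is hypothetical and not part of the PyTorch API):

    from typing import List

    import torch
    from torch import Tensor

    def _default_to_foreach(tensor_lists: List[List[Tensor]], differentiable: bool) -> bool:
        # Same rule as the hunk above: default to foreach only when not
        # running under torch.jit.script, not differentiating through the
        # optimizer step, and every tensor lives on CUDA.
        all_tensors = [t for tensors in tensor_lists for t in tensors]
        return (
            not torch.jit.is_scripting()
            and not differentiable
            and all(t.is_cuda for t in all_tensors)
        )

One subtlety of this rule: all() over an empty sequence is True in Python, so an optimizer with no parameters would also resolve foreach to True, which is harmless since there is nothing to update.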
