Commit 00eb7b0

[optim] Set defaults to foreach, NOT fused (#95241) (#95415)
Rolling back the default change for Adam and rectifying the docs to reflect that AdamW never defaulted to fused. Since our fused implementations are relatively new, let's give them a longer bake-in time before flipping the switch for every user.

Pull Request resolved: #95241
Approved by: https://github.com/ngimel
1 parent 2180f34 commit 00eb7b0
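
In practice, the change means the fused Adam path is now strictly opt-in. A minimal sketch of the user-facing effect, assuming a CUDA device is available (the toy model and hyperparameters are illustrative, not part of this commit):

```python
import torch

# Hypothetical toy model; any CUDA float parameters would do.
model = torch.nn.Linear(16, 4).cuda()

# With foreach=None and fused=None (the defaults), Adam now selects the
# foreach implementation for CUDA params; fused is no longer picked implicitly.
opt_default = torch.optim.Adam(model.parameters(), lr=1e-3)

# The fused kernels remain available, but only by explicit opt-in.
opt_fused = torch.optim.Adam(model.parameters(), lr=1e-3, fused=True)
```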

12 files changed: +48 −43 lines


torch/optim/adadelta.py
Lines changed: 1 addition & 1 deletion

@@ -194,7 +194,7 @@ def adadelta(
     # We still respect when the user inputs False for foreach.
     if foreach is None:
         _, foreach = _default_to_fused_or_foreach([params, grads, square_avgs, acc_deltas],
-                                                   differentiable, has_fused=False)
+                                                   differentiable, use_fused=False)
 
     if foreach and torch.jit.is_scripting():
         raise RuntimeError("torch.jit.script not supported with foreach optimizers")

torch/optim/adagrad.py
Lines changed: 1 addition & 1 deletion

@@ -211,7 +211,7 @@ def adagrad(
 
     if foreach is None:
         _, foreach = _default_to_fused_or_foreach([params, grads, state_sums, state_steps],
-                                                   differentiable, has_fused=False)
+                                                   differentiable, use_fused=False)
 
     if foreach and torch.jit.is_scripting():
         raise RuntimeError("torch.jit.script not supported with foreach optimizers")

torch/optim/adam.py
Lines changed: 9 additions & 19 deletions

@@ -4,7 +4,7 @@
 from torch import Tensor
 from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _stack_if_compiling,
                         _dispatch_sqrt, _default_to_fused_or_foreach, _capturable_doc,
-                        _differentiable_doc, _foreach_doc, _maximize_doc)
+                        _differentiable_doc, _foreach_doc, _fused_doc, _maximize_doc)
 from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
 
 __all__ = ['Adam', 'adam']
@@ -218,28 +218,14 @@ def step(self, closure=None):
         {maximize}
         {capturable}
         {differentiable}
-        fused (bool, optional): whether the fused implementation (CUDA only) is used.
-            Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
-            are supported. Since the fused implementation is usually significantly faster than
-            the for-loop implementation, we try to use it whenever possible (all parameters
-            are on CUDA and are of a supported type). Else, we attempt to use the foreach
-            implementation and lastly fall back to the for-loop implementation. (default: None)
-
-    .. note:: The foreach and fused implementations are typically faster than the for-loop,
-              single-tensor implementation, so we will try to default to them IF the user has
-              not specified either flag (i.e., when foreach = fused = None). For example, if
-              the user specifies True for foreach but nothing for fused, we will run the foreach
-              implementation. If the user specifies False for fused but nothing for foreach, we will
-              run the for-loop implementation. If the user specifies True for both foreach and
-              fused, we will prioritize fused over foreach. We attempt to use the fastest, so the
-              hierarchy goes fused -> foreach -> for-loop.
+        {fused}
     .. _Adam\: A Method for Stochastic Optimization:
         https://arxiv.org/abs/1412.6980
     .. _On the Convergence of Adam and Beyond:
         https://openreview.net/forum?id=ryQu7f-RZ
 
     """.format(foreach=_foreach_doc, maximize=_maximize_doc, capturable=_capturable_doc,
-               differentiable=_differentiable_doc)
+               differentiable=_differentiable_doc, fused=_fused_doc)
@@ -268,10 +254,14 @@ def adam(params: List[Tensor],
     See :class:`~torch.optim.Adam` for details.
     """
 
+    # Respect when the user inputs False/True for foreach or fused. We only want to change
+    # the default when neither have been user-specified. Note that we default to foreach
+    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
+    # bake-in time before making it the default, even if it is typically faster.
     if fused is None and foreach is None:
-        fused, foreach = _default_to_fused_or_foreach(
+        _, foreach = _default_to_fused_or_foreach(
             [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps],
-            differentiable, has_fused=True)
+            differentiable, use_fused=False)
     if fused is None:
         fused = False
     if foreach is None:

torch/optim/adamax.py
Lines changed: 1 addition & 1 deletion

@@ -207,7 +207,7 @@ def adamax(
 
     if foreach is None:
         _, foreach = _default_to_fused_or_foreach([params, grads, exp_avgs, exp_infs, state_steps],
-                                                   differentiable, has_fused=False)
+                                                   differentiable, use_fused=False)
 
     if foreach and torch.jit.is_scripting():
         raise RuntimeError("torch.jit.script not supported with foreach optimizers")

torch/optim/adamw.py
Lines changed: 9 additions & 11 deletions

@@ -2,7 +2,7 @@
 from torch import Tensor
 from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _dispatch_sqrt,
                         _stack_if_compiling, _capturable_doc, _differentiable_doc, _foreach_doc,
-                        _maximize_doc, _default_to_fused_or_foreach)
+                        _fused_doc, _maximize_doc, _default_to_fused_or_foreach)
 from typing import List, Optional
 from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
 
@@ -248,20 +248,15 @@ def step(self, closure=None):
         {foreach}
         {capturable}
         {differentiable}
-        fused (bool, optional): whether the fused implementation (CUDA only) is used.
-            Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
-            are supported. Since the fused implementation is usually significantly faster than
-            the for-loop implementation, we try to use it whenever possible (all parameters
-            are on CUDA and are of a supported type). Else, we continue with the for-loop
-            implementation. (default: None)
-
+        {fused}
     .. _Decoupled Weight Decay Regularization:
         https://arxiv.org/abs/1711.05101
     .. _On the Convergence of Adam and Beyond:
         https://openreview.net/forum?id=ryQu7f-RZ
 
     """.format(maximize=_maximize_doc,
                foreach=_foreach_doc,
+               fused=_fused_doc,
                capturable=_capturable_doc,
                differentiable=_differentiable_doc)
 
@@ -300,11 +295,14 @@ def adamw(
             "API has changed, `state_steps` argument must contain a list of singleton tensors"
         )
 
-    # Respect when the user inputs False/True for foreach.
+    # Respect when the user inputs False/True for foreach or fused. We only want to change
+    # the default when neither have been user-specified. Note that we default to foreach
+    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
+    # bake-in time before making it the default, even if it is typically faster.
     if fused is None and foreach is None:
-        fused, foreach = _default_to_fused_or_foreach(
+        _, foreach = _default_to_fused_or_foreach(
             [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps],
-            differentiable, has_fused=False)
+            differentiable, use_fused=False)
     if fused is None:
         fused = False
     if foreach is None:

torch/optim/asgd.py
Lines changed: 1 addition & 1 deletion

@@ -186,7 +186,7 @@ def asgd(
 
     if foreach is None:
         _, foreach = _default_to_fused_or_foreach([params, grads, axs, mus, etas, state_steps],
-                                                   differentiable, has_fused=False)
+                                                   differentiable, use_fused=False)
 
     if foreach and torch.jit.is_scripting():
         raise RuntimeError("torch.jit.script not supported with foreach optimizers")

torch/optim/nadam.py
Lines changed: 1 addition & 1 deletion

@@ -188,7 +188,7 @@ def nadam(params: List[Tensor],
 
     if foreach is None:
         _, foreach = _default_to_fused_or_foreach([params, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps],
-                                                   differentiable, has_fused=False)
+                                                   differentiable, use_fused=False)
 
     if foreach and torch.jit.is_scripting():
         raise RuntimeError('torch.jit.script not supported with foreach optimizers')

torch/optim/optimizer.py
Lines changed: 21 additions & 4 deletions

@@ -55,20 +55,20 @@ def _dispatch_sqrt(x: float):  # float annotation is needed because of torchscript
     return math.sqrt(x)
 
 # For any optimizer with a faster implementation, we attempt to default to the
-# fastest whenever possible. For foreach, the requirements are to have native
-# tensors all on CUDA. For fused, there's currently the additional requirement
+# fastest + stablest whenever possible. For foreach, the requirements are to have
+# native tensors all on CUDA. For fused, there's currently the additional requirement
 # that the tensors' dtypes must be floating point. Neither alternative supports
 # torch.jit.script nor differentiable, so we fall back to the single tensor
 # implementation in those cases.
 def _default_to_fused_or_foreach(tensorlists: List[List[torch.Tensor]],
                                  differentiable: bool,
-                                 has_fused: bool = False) -> Tuple[bool, bool]:
+                                 use_fused: bool = False) -> Tuple[bool, bool]:
     if torch.jit.is_scripting() or differentiable:
         return False, False
     all_tensors = []
     for tensorlist in tensorlists:
         all_tensors.extend(tensorlist)
-    fused = has_fused and all(
+    fused = use_fused and all(
         p is None or (type(p) == torch.Tensor and p.is_cuda and torch.is_floating_point(p)) for p in all_tensors
     )
     foreach = not fused and all(
@@ -83,6 +83,23 @@ def _default_to_fused_or_foreach(tensorlists: List[List[torch.Tensor]],
               foreach over the for-loop implementation on CUDA, since it is usually
               significantly more performant. (default: None)"""
 
+_fused_doc = r"""fused (bool, optional): whether the fused implementation (CUDA only) is used.
+    Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
+    are supported. (default: None)
+
+    .. note:: The foreach and fused implementations are typically faster than the for-loop,
+              single-tensor implementation. Thus, if the user has not specified BOTH flags
+              (i.e., when foreach = fused = None), we will attempt defaulting to the foreach
+              implementation when the tensors are all on CUDA. For example, if the user specifies
+              True for fused but nothing for foreach, we will run the fused implementation. If
+              the user specifies False for foreach but nothing for fused (or False for fused but
+              nothing for foreach), we will run the for-loop implementation. If the user specifies
+              True for both foreach and fused, we will prioritize fused over foreach, as it is
+              typically faster. We attempt to use the fastest, so the hierarchy goes fused ->
+              foreach -> for-loop. HOWEVER, since the fused implementation is relatively new,
+              we want to give it sufficient bake-in time, so we default to foreach and NOT
+              fused when the user has not specified either flag."""
+
 _capturable_doc = r"""capturable (bool, optional): whether this instance is safe to
     capture in a CUDA graph. Passing True can impair ungraphed performance,
     so if you don't intend to graph capture this instance, leave it False
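
The selection logic in `_default_to_fused_or_foreach` above is easiest to see by calling the helper directly. A minimal sketch against the signature as of this commit; the function is a private helper, so this is for illustration only and assumes a CUDA device is available:

```python
import torch
from torch.optim.optimizer import _default_to_fused_or_foreach

# Two tensor lists standing in for params and grads; all native CUDA floats.
params = [torch.randn(4, device="cuda")]
grads = [torch.randn(4, device="cuda")]

# With use_fused=False (what the optimizers now pass), foreach is selected.
fused, foreach = _default_to_fused_or_foreach([params, grads],
                                              differentiable=False,
                                              use_fused=False)
print(fused, foreach)  # False True

# Passing use_fused=True lets fused win when all tensors qualify, matching
# the fused -> foreach -> for-loop hierarchy described in _fused_doc.
fused, foreach = _default_to_fused_or_foreach([params, grads],
                                              differentiable=False,
                                              use_fused=True)
print(fused, foreach)  # True False
```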

torch/optim/radam.py
Lines changed: 1 addition & 1 deletion

@@ -210,7 +210,7 @@ def radam(
 
     if foreach is None:
         _, foreach = _default_to_fused_or_foreach([params, grads, exp_avgs, exp_avg_sqs, state_steps],
-                                                   differentiable, has_fused=False)
+                                                   differentiable, use_fused=False)
 
     if foreach and torch.jit.is_scripting():
         raise RuntimeError("torch.jit.script not supported with foreach optimizers")

torch/optim/rmsprop.py
Lines changed: 1 addition & 1 deletion

@@ -221,7 +221,7 @@ def rmsprop(
 
     if foreach is None:
         _, foreach = _default_to_fused_or_foreach([params, grads, square_avgs, grad_avgs, momentum_buffer_list],
-                                                   differentiable, has_fused=False)
+                                                   differentiable, use_fused=False)
 
     if foreach and torch.jit.is_scripting():
         raise RuntimeError("torch.jit.script not supported with foreach optimizers")
