Commit c292b78

Move logic to adam() and not the constructor
1 parent 288a778 commit c292b78

1 file changed: +24 −32 lines

torch/optim/adam.py

@@ -1,7 +1,5 @@
 from collections import defaultdict
 from typing import cast, List, Optional, Dict, Tuple
-import warnings
-import itertools
 
 import torch
 from torch import Tensor
@@ -110,9 +108,9 @@ class Adam(Optimizer):
         fused (bool, optional): whether the fused implementation (CUDA only) is used.
             Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
             are supported. Since the fused implementation is usually significantly faster than
-            the for-loop implementation, we default to using it whenever possible (all
-            parameters are on CUDA and are of a supported type. Else, we fall back to the
-            for-loop implementation. (default: True)
+            the for-loop implementation, we try to use it whenever possible (all parameters
+            are on CUDA and are of a supported type). Else, we continue with the for-loop
+            implementation. (default: False)
 
     .. _Adam\: A Method for Stochastic Optimization:
         https://arxiv.org/abs/1412.6980
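A quick usage sketch of the semantics this docstring change describes (illustrative only, not part of the diff; it assumes a CUDA device is available): leaving `fused` unset defers the decision to `adam()` at step time, while an explicit value from the caller is always respected.

import torch

model = torch.nn.Linear(8, 8).cuda()  # CUDA float32 parameters

opt_auto = torch.optim.Adam(model.parameters())                  # fused left unset: adam() decides per step
opt_loop = torch.optim.Adam(model.parameters(), fused=False)     # explicitly stay on the for-loop/foreach path
opt_fused = torch.optim.Adam(model.parameters(), fused=True)     # explicitly request the fused CUDA kernel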
@@ -123,7 +121,7 @@ class Adam(Optimizer):
     def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                  weight_decay=0, amsgrad=False, *, foreach: Optional[bool] = None,
                  maximize: bool = False, capturable: bool = False,
-                 differentiable: bool = False, fused: bool = True):
+                 differentiable: bool = False, fused: Optional[bool] = None):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
         if not 0.0 <= eps:
@@ -135,42 +133,25 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
         if not 0.0 <= weight_decay:
             raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
 
-        def all_params(params, lambda_fn):
-            if isinstance(params, Tensor):
-                return lambda_fn(params)
-            if isinstance(params, dict):
-                return all_params(params.values(), lambda_fn)
-            # should be an iterable, unless it sets a default, in which case it's not relevant 🤷🏻‍♀️
-            try:
-                return all([all_params(p, lambda_fn) for p in params])
-            except TypeError:
-                return True
-
-        params, params_copy = itertools.tee(params)
-
-        # The fused implementation is fastest but is only available when the parameters are floats on CUDA.
-        # The fused implementation is also not differentiable. We default back to for-loop impl in both cases.
-        if fused:
-            if differentiable:
-                fused = False
-                warnings.warn("`fused` cannot be `differentiable`, falling back to for-loop implementation")
-            elif not all_params(params_copy, lambda p: p.is_cuda and torch.is_floating_point(p)):
-                fused = False
-                warnings.warn("FusedAdam requires all the params to be CUDA, floating point. "
-                              "Falling back to for-loop implementation")
-
         defaults = dict(lr=lr, betas=betas, eps=eps,
                         weight_decay=weight_decay, amsgrad=amsgrad,
                         maximize=maximize, foreach=foreach, capturable=capturable,
                         differentiable=differentiable, fused=fused)
         super(Adam, self).__init__(params, defaults)
 
         if fused:
+            if differentiable:
+                raise RuntimeError("`fused` cannot be `differentiable`")
+            self._step_supports_amp_scaling = True
             # TODO(crcrpar): [low prec params & their higher prec copy]
             # Suppor AMP with FP16/BF16 model params which would need
             # higher prec copy of params to do update math in higher prec to
             # alleviate the loss of information.
-            self._step_supports_amp_scaling = True
+            if not all(
+                p.is_cuda and torch.is_floating_point(p)
+                for pg in self.param_groups for p in pg['params']
+            ):
+                raise RuntimeError("FusedAdam requires all the params to be CUDA, floating point")
 
     def __setstate__(self, state):
         super().__setstate__(state)
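To illustrate the stricter constructor behavior this hunk introduces (a sketch, not part of the diff): an explicit `fused=True` now raises a `RuntimeError` when any parameter is not a floating-point CUDA tensor (or when `differentiable=True`), instead of warning and silently falling back to the for-loop implementation.

import torch

cpu_param = torch.nn.Parameter(torch.zeros(4))  # CPU tensor, so the fused path is unavailable
try:
    torch.optim.Adam([cpu_param], fused=True)
except RuntimeError as err:
    print(err)  # "FusedAdam requires all the params to be CUDA, floating point"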
@@ -311,7 +292,7 @@ def adam(params: List[Tensor],
          foreach: Optional[bool] = None,
          capturable: bool = False,
          differentiable: bool = False,
-         fused: bool = False,
+         fused: Optional[bool] = None,
          grad_scale: Optional[_MultiDeviceReplicator] = None,
          found_inf: Optional[_MultiDeviceReplicator] = None,
          *,
@@ -326,6 +307,17 @@ def adam(params: List[Tensor],
     See :class:`~torch.optim.Adam` for details.
     """
 
+    # We try to use the fused implementation whenever we can since it is fastest.
+    # It's only available when the tensors are floats on the same CUDA device
+    # and when differentiable=False.
+    # We still respect when the user inputs False for fused.
+    if fused is None:
+        if not differentiable and all(
+            p.is_cuda and torch.is_floating_point(p)
+            for p in params + grads + exp_avgs + exp_avg_sqs + max_exp_avg_sqs + state_steps
+        ):
+            fused = True
+
     if not all(isinstance(t, torch.Tensor) for t in state_steps):
         raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")
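The auto-selection rule added above can be summarized with a small standalone helper (a hypothetical sketch mirroring the new block; `resolve_fused` is not part of the commit): `fused=None` means "decide automatically", and an explicit choice from the caller is left untouched.

from typing import List, Optional

import torch

def resolve_fused(tensors: List[torch.Tensor],
                  fused: Optional[bool],
                  differentiable: bool) -> bool:
    # None means "pick for me": use the fused kernel only when every tensor is a
    # floating-point CUDA tensor and autograd through the step is not requested.
    if fused is None:
        return (not differentiable
                and all(t.is_cuda and torch.is_floating_point(t) for t in tensors))
    # An explicit True/False from the caller is respected as-is.
    return fused

# e.g. resolve_fused([torch.zeros(2)], None, False) -> False, since the tensor lives on CPU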

0 commit comments