 from torch import Tensor
 from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value, _stack_if_compiling,
                         _dispatch_sqrt, _default_to_fused_or_foreach, _capturable_doc,
-                        _differentiable_doc, _foreach_doc, _maximize_doc)
+                        _differentiable_doc, _foreach_doc, _fused_doc, _maximize_doc)
 from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
 
 __all__ = ['Adam', 'adam']
@@ -218,28 +218,14 @@ def step(self, closure=None):
         {maximize}
         {capturable}
         {differentiable}
-        fused (bool, optional): whether the fused implementation (CUDA only) is used.
-            Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
-            are supported. Since the fused implementation is usually significantly faster than
-            the for-loop implementation, we try to use it whenever possible (all parameters
-            are on CUDA and are of a supported type). Else, we attempt to use the foreach
-            implementation and lastly fall back to the for-loop implementation. (default: None)
-
-    .. note:: The foreach and fused implementations are typically faster than the for-loop,
-        single-tensor implementation, so we will try to default to them IF the user has
-        not specified either flag (i.e., when foreach = fused = None). For example, if
-        the user specifies True for foreach but nothing for fused, we will run the foreach
-        implementation. If the user specifies False for fused but nothing for foreach, we will
-        run the for-loop implementation. If the user specifies True for both foreach and
-        fused, we will prioritize fused over foreach. We attempt to use the fastest, so the
-        hierarchy goes fused -> foreach -> for-loop.
+        {fused}
     .. _Adam\: A Method for Stochastic Optimization:
         https://arxiv.org/abs/1412.6980
     .. _On the Convergence of Adam and Beyond:
         https://openreview.net/forum?id=ryQu7f-RZ
 
     """.format(foreach=_foreach_doc, maximize=_maximize_doc, capturable=_capturable_doc,
-               differentiable=_differentiable_doc)
+               differentiable=_differentiable_doc, fused=_fused_doc)
 
 
 def adam(params: List[Tensor],
@@ -268,10 +254,14 @@ def adam(params: List[Tensor],
     See :class:`~torch.optim.Adam` for details.
     """
 
+    # Respect when the user inputs False/True for foreach or fused. We only want to change
+    # the default when neither have been user-specified. Note that we default to foreach
+    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
+    # bake-in time before making it the default, even if it is typically faster.
     if fused is None and foreach is None:
-        fused, foreach = _default_to_fused_or_foreach(
+        _, foreach = _default_to_fused_or_foreach(
             [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps],
-            differentiable, has_fused=True)
+            differentiable, use_fused=False)
     if fused is None:
         fused = False
     if foreach is None:
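
For readers skimming the diff, here is a minimal, hypothetical usage sketch (not part of this PR) of how a caller chooses among the fused, foreach, and for-loop Adam implementations after this change. It assumes a PyTorch build in which torch.optim.Adam accepts the foreach and fused keyword arguments and that a CUDA device is available; the model and learning rate are illustrative only.

import torch

# Illustrative model; the fused kernel requires all parameters on CUDA in a supported dtype.
model = torch.nn.Linear(16, 16).cuda()

# Both flags left unset: after this change the optimizer defaults to the foreach
# implementation when eligible, instead of auto-selecting the fused kernel.
opt_default = torch.optim.Adam(model.parameters(), lr=1e-3)

# Explicitly opt in to the fused CUDA implementation.
opt_fused = torch.optim.Adam(model.parameters(), lr=1e-3, fused=True)

# Force the single-tensor for-loop implementation.
opt_forloop = torch.optim.Adam(model.parameters(), lr=1e-3, foreach=False)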