
Commit 0865964

[optim] _actually_ default to foreach (#95862)
* [optim] include nn.Parameter as foreach supported (#95811)

  This PR comes from the realization that models were NOT actually getting the foreach defaulting that our documentation has claimed for months now: model parameters are nn.Parameter instances, which the old `type(p) == torch.Tensor` check excluded. Big oops.

  Pull Request resolved: #95811
  Approved by: https://github.com/albanD

* [optim] Widen the cases for defaulting to foreach (#95820)

  Big-oops correction continued, with a test added this time to verify that the defaulting behaves as expected. The key realization is that the grouping for foreach already assumes the non-param tensorlists match the params in dtype and device, so requiring that _all_ tensors be on CUDA was too narrow. The main tensors that check excluded were state_steps, which are sometimes CPU tensors; since foreach _can_ handle CPU tensors, this should not introduce breakage.

  Pull Request resolved: #95820
  Approved by: https://github.com/albanD
1 parent f18ac1b commit 0865964
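
The practical effect of the two PRs above: constructing an optimizer over CUDA model parameters (which are nn.Parameter, not plain torch.Tensor) with foreach left unset now routes to the multi-tensor implementation, and passing foreach=False opts out. A minimal sketch, not part of the commit, assuming a CUDA build of PyTorch and mirroring the patch-based check added in test_optim.py below; the variable names are illustrative only:

    # Illustrative sketch: verify that the foreach path is picked by default on CUDA.
    from unittest.mock import patch

    import torch
    from torch.optim import adam  # functional module patched below, as in the new test

    model = torch.nn.Linear(5, 5).cuda()
    opt = torch.optim.Adam(model.parameters())  # foreach left as None
    model(torch.rand(2, 5, device="cuda")).sum().backward()

    with patch.object(adam, "_multi_tensor_adam") as mocked:
        opt.step()
    assert mocked.called  # params are nn.Parameters on CUDA, so foreach is chosen

    # Explicit opt-out keeps the single-tensor implementation.
    opt_single = torch.optim.Adam(model.parameters(), foreach=False)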

14 files changed: +69 −42 lines


test/distributed/fsdp/test_fsdp_use_orig_params.py

Lines changed: 4 additions & 2 deletions
@@ -504,7 +504,7 @@ def _get_fsdp_models_and_optims(
             fsdp_kwargs=fsdp_kwargs,
             deterministic=True,
         )
-        optim = torch.optim.Adam(fsdp_model.parameters(), lr=LR)
+        optim = torch.optim.Adam(fsdp_model.parameters(), foreach=False, lr=LR)
         fsdp_kwargs["use_orig_params"] = True
         fsdp_model_orig_params = TransformerWithSharedParams.init(
             self.process_group,
@@ -513,7 +513,9 @@ def _get_fsdp_models_and_optims(
             fsdp_kwargs=fsdp_kwargs,
             deterministic=True,
         )
-        optim_orig_params = torch.optim.Adam(fsdp_model_orig_params.parameters(), lr=LR)
+        optim_orig_params = torch.optim.Adam(
+            fsdp_model_orig_params.parameters(), foreach=False, lr=LR
+        )
         return fsdp_model, optim, fsdp_model_orig_params, optim_orig_params

     def _check_fsdp_parameter_parity(self, fsdp1: FSDP, fsdp2: FSDP) -> None:

test/test_optim.py

Lines changed: 49 additions & 9 deletions
@@ -46,9 +46,10 @@
     skipIfRocm,
     skipIfTorchDynamo
 )
-from torch.testing._internal.common_cuda import TEST_MULTIGPU
+from torch.testing._internal.common_cuda import TEST_MULTIGPU, TEST_CUDA
 from typing import Dict, Any, Tuple
 from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
+from unittest.mock import patch

 # load_tests from common_utils is used to automatically filter tests for
 # sharding on sandcastle. This line silences flake warnings
@@ -252,21 +253,26 @@ def fn_base(optimizer, weight, bias):
         )

         # Make sure that optimizers that support maximize can load older models
-        state_dict = optimizer.state_dict()
-        if "maximize" in state_dict["param_groups"][0]:
-            for group in state_dict["param_groups"]:
+        old_state_dict = deepcopy(optimizer.state_dict())
+        state_dict_no_maximize = deepcopy(optimizer.state_dict())
+        if "maximize" in state_dict_no_maximize["param_groups"][0]:
+            for group in state_dict_no_maximize["param_groups"]:
                 del group["maximize"]
-            optimizer.load_state_dict(state_dict)
+            optimizer.load_state_dict(state_dict_no_maximize)
             # Make sure we can still step
             optimizer.step()
+            # Undo these changes before proceeding!
+            optimizer.load_state_dict(old_state_dict)
         # Make sure that optimizers that support foreach can load older models
-        state_dict = optimizer.state_dict()
-        if "foreach" in state_dict["param_groups"][0]:
-            for group in state_dict["param_groups"]:
+        state_dict_no_foreach = deepcopy(optimizer.state_dict())
+        if "foreach" in state_dict_no_foreach["param_groups"][0]:
+            for group in state_dict_no_foreach["param_groups"]:
                 del group["foreach"]
-            optimizer.load_state_dict(state_dict)
+            optimizer.load_state_dict(state_dict_no_foreach)
             # Make sure we can still step
             optimizer.step()
+            # Undo these changes before proceeding!
+            optimizer.load_state_dict(old_state_dict)

         # Make sure that loading optimizers with step not wrapped in tensor can work
         state_dict = optimizer.state_dict()
@@ -4535,5 +4541,39 @@ def test_radam(self):
         )


+    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
+    def test_defaults_changed_to_foreach(self):
+        from torch.optim import (adam, adamw, nadam, sgd, radam, rmsprop, rprop,
+                                 asgd, adamax, adadelta, adagrad)
+        multi_optims = ((optim.Adam, adam, "_multi_tensor_adam"),
+                        (optim.AdamW, adamw, "_multi_tensor_adamw"),
+                        (optim.NAdam, nadam, "_multi_tensor_nadam"),
+                        (optim.SGD, sgd, "_multi_tensor_sgd"),
+                        (optim.RAdam, radam, "_multi_tensor_radam"),
+                        (optim.RMSprop, rmsprop, "_multi_tensor_rmsprop"),
+                        (optim.Rprop, rprop, "_multi_tensor_rprop"),
+                        (optim.ASGD, asgd, "_multi_tensor_asgd"),
+                        (optim.Adamax, adamax, "_multi_tensor_adamax"),
+                        (optim.Adadelta, adadelta, "_multi_tensor_adadelta"),
+                        (optim.Adagrad, adagrad, "_multi_tensor_adagrad"),)
+
+        model = torch.nn.Linear(5, 5)
+        model.to(dtype=torch.float64, device="cuda")
+        input = torch.rand(2, 5, dtype=torch.float64, device="cuda")
+
+        for opt, mod, func in multi_optims:
+            defaults = {}
+            if opt == optim.SGD:
+                defaults["lr"] = 1e-2
+            optimizer = opt(model.parameters(), **defaults)
+            optimizer.zero_grad()
+            output = model(input)
+            loss = output.sum()
+            loss.backward()
+            with patch.object(mod, func) as mocked_foreach_impl:
+                optimizer.step()
+                self.assertTrue(mocked_foreach_impl.called)
+
+
 if __name__ == "__main__":
     run_tests()

torch/optim/adadelta.py

Lines changed: 1 addition & 2 deletions
@@ -193,8 +193,7 @@ def adadelta(

     # We still respect when the user inputs False for foreach.
     if foreach is None:
-        _, foreach = _default_to_fused_or_foreach([params, grads, square_avgs, acc_deltas],
-                                                  differentiable, use_fused=False)
+        _, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)

     if foreach and torch.jit.is_scripting():
         raise RuntimeError("torch.jit.script not supported with foreach optimizers")

torch/optim/adagrad.py

Lines changed: 1 addition & 2 deletions
@@ -210,8 +210,7 @@ def adagrad(
     )

     if foreach is None:
-        _, foreach = _default_to_fused_or_foreach([params, grads, state_sums, state_steps],
-                                                  differentiable, use_fused=False)
+        _, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)

     if foreach and torch.jit.is_scripting():
         raise RuntimeError("torch.jit.script not supported with foreach optimizers")

torch/optim/adam.py

Lines changed: 1 addition & 3 deletions
@@ -259,9 +259,7 @@ def adam(params: List[Tensor],
     # and pass False to use_fused. This is not a mistake--we want to give the fused impl
     # bake-in time before making it the default, even if it is typically faster.
     if fused is None and foreach is None:
-        _, foreach = _default_to_fused_or_foreach(
-            [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps],
-            differentiable, use_fused=False)
+        _, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)
     if fused is None:
         fused = False
     if foreach is None:

torch/optim/adamax.py

Lines changed: 1 addition & 2 deletions
@@ -206,8 +206,7 @@ def adamax(
     )

     if foreach is None:
-        _, foreach = _default_to_fused_or_foreach([params, grads, exp_avgs, exp_infs, state_steps],
-                                                  differentiable, use_fused=False)
+        _, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)

     if foreach and torch.jit.is_scripting():
         raise RuntimeError("torch.jit.script not supported with foreach optimizers")

torch/optim/adamw.py

Lines changed: 1 addition & 3 deletions
@@ -300,9 +300,7 @@ def adamw(
     # and pass False to use_fused. This is not a mistake--we want to give the fused impl
     # bake-in time before making it the default, even if it is typically faster.
     if fused is None and foreach is None:
-        _, foreach = _default_to_fused_or_foreach(
-            [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps],
-            differentiable, use_fused=False)
+        _, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)
     if fused is None:
         fused = False
     if foreach is None:

torch/optim/asgd.py

Lines changed: 1 addition & 2 deletions
@@ -185,8 +185,7 @@ def asgd(
     """

     if foreach is None:
-        _, foreach = _default_to_fused_or_foreach([params, grads, axs, mus, etas, state_steps],
-                                                  differentiable, use_fused=False)
+        _, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)

     if foreach and torch.jit.is_scripting():
         raise RuntimeError("torch.jit.script not supported with foreach optimizers")

torch/optim/nadam.py

Lines changed: 1 addition & 2 deletions
@@ -187,8 +187,7 @@ def nadam(params: List[Tensor],
         raise RuntimeError("API has changed, `mu_products` argument must contain a list of singleton tensors")

     if foreach is None:
-        _, foreach = _default_to_fused_or_foreach([params, grads, exp_avgs, exp_avg_sqs, mu_products, state_steps],
-                                                  differentiable, use_fused=False)
+        _, foreach = _default_to_fused_or_foreach(params, differentiable, use_fused=False)

     if foreach and torch.jit.is_scripting():
         raise RuntimeError('torch.jit.script not supported with foreach optimizers')

torch/optim/optimizer.py

Lines changed: 5 additions & 7 deletions
@@ -15,6 +15,7 @@
 __all__ = ['Optimizer', 'register_optimizer_step_pre_hook', 'register_optimizer_step_post_hook']
 _global_optimizer_pre_hooks: Dict[int, Callable] = OrderedDict()
 _global_optimizer_post_hooks: Dict[int, Callable] = OrderedDict()
+_foreach_supported_types = [torch.Tensor, torch.nn.parameter.Parameter]

 class _RequiredParameter:
     """Singleton class representing a required parameter for an Optimizer."""
@@ -56,23 +57,20 @@ def _dispatch_sqrt(x: float):  # float annotation is needed because of torchscript

 # For any optimizer with a faster implementation, we attempt to default to the
 # fastest + stablest whenever possible. For foreach, the requirements are to have
-# native tensors all on CUDA. For fused, there's currently the additional requirement
+# native params all on CUDA. For fused, there's currently the additional requirement
 # that the tensors' dtypes must be floating point. Neither alternative supports
 # torch.jit.script nor differentiable, so we fall back to the single tensor
 # implementation in those cases.
-def _default_to_fused_or_foreach(tensorlists: List[List[torch.Tensor]],
+def _default_to_fused_or_foreach(params: List[torch.Tensor],
                                  differentiable: bool,
                                  use_fused: bool = False) -> Tuple[bool, bool]:
     if torch.jit.is_scripting() or differentiable:
         return False, False
-    all_tensors = []
-    for tensorlist in tensorlists:
-        all_tensors.extend(tensorlist)
     fused = use_fused and all(
-        p is None or (type(p) == torch.Tensor and p.is_cuda and torch.is_floating_point(p)) for p in all_tensors
+        p is None or (type(p) in _foreach_supported_types and p.is_cuda and torch.is_floating_point(p)) for p in params
     )
     foreach = not fused and all(
-        p is None or (type(p) == torch.Tensor and p.is_cuda) for p in all_tensors
+        p is None or (type(p) in _foreach_supported_types and p.is_cuda) for p in params
     )
     return fused, foreach
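
For reference, a small sketch (not part of the diff) of how the revised helper behaves under the widened rules, assuming a CUDA device is available. It calls the private _default_to_fused_or_foreach directly, which is an internal API and may change:

    import torch
    from torch.optim.optimizer import _default_to_fused_or_foreach

    model = torch.nn.Linear(5, 5).cuda()

    # nn.Parameter is now in _foreach_supported_types, and only the params are
    # inspected, so CPU state_steps no longer block the foreach default.
    fused, foreach = _default_to_fused_or_foreach(
        list(model.parameters()), differentiable=False, use_fused=False)
    print(fused, foreach)  # expected on a CUDA build: False True

    # differentiable=True still falls back to the single-tensor path.
    fused, foreach = _default_to_fused_or_foreach(
        list(model.parameters()), differentiable=True, use_fused=False)
    print(fused, foreach)  # expected: False False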
