     skipIfRocm,
     skipIfTorchDynamo
 )
-from torch.testing._internal.common_cuda import TEST_MULTIGPU
+from torch.testing._internal.common_cuda import TEST_MULTIGPU, TEST_CUDA
 from typing import Dict, Any, Tuple
 from torch.optim.optimizer import register_optimizer_step_pre_hook, register_optimizer_step_post_hook
+from unittest.mock import patch
 
 # load_tests from common_utils is used to automatically filter tests for
 # sharding on sandcastle. This line silences flake warnings
@@ -252,21 +253,26 @@ def fn_base(optimizer, weight, bias):
         )
 
         # Make sure that optimizers that support maximize can load older models
-        state_dict = optimizer.state_dict()
-        if "maximize" in state_dict["param_groups"][0]:
-            for group in state_dict["param_groups"]:
+        old_state_dict = deepcopy(optimizer.state_dict())
+        state_dict_no_maximize = deepcopy(optimizer.state_dict())
+        if "maximize" in state_dict_no_maximize["param_groups"][0]:
+            for group in state_dict_no_maximize["param_groups"]:
                 del group["maximize"]
-            optimizer.load_state_dict(state_dict)
+            optimizer.load_state_dict(state_dict_no_maximize)
             # Make sure we can still step
             optimizer.step()
+            # Undo these changes before proceeding!
+            optimizer.load_state_dict(old_state_dict)
         # Make sure that optimizers that support foreach can load older models
-        state_dict = optimizer.state_dict()
-        if "foreach" in state_dict["param_groups"][0]:
-            for group in state_dict["param_groups"]:
+        state_dict_no_foreach = deepcopy(optimizer.state_dict())
+        if "foreach" in state_dict_no_foreach["param_groups"][0]:
+            for group in state_dict_no_foreach["param_groups"]:
                 del group["foreach"]
-            optimizer.load_state_dict(state_dict)
+            optimizer.load_state_dict(state_dict_no_foreach)
             # Make sure we can still step
             optimizer.step()
+            # Undo these changes before proceeding!
+            optimizer.load_state_dict(old_state_dict)
 
         # Make sure that loading optimizers with step not wrapped in tensor can work
         state_dict = optimizer.state_dict()
@@ -4535,5 +4541,39 @@ def test_radam(self):
         )
 
 
+    @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
+    def test_defaults_changed_to_foreach(self):
+        from torch.optim import (adam, adamw, nadam, sgd, radam, rmsprop, rprop,
+                                 asgd, adamax, adadelta, adagrad)
+        multi_optims = ((optim.Adam, adam, "_multi_tensor_adam"),
+                        (optim.AdamW, adamw, "_multi_tensor_adamw"),
+                        (optim.NAdam, nadam, "_multi_tensor_nadam"),
+                        (optim.SGD, sgd, "_multi_tensor_sgd"),
+                        (optim.RAdam, radam, "_multi_tensor_radam"),
+                        (optim.RMSprop, rmsprop, "_multi_tensor_rmsprop"),
+                        (optim.Rprop, rprop, "_multi_tensor_rprop"),
+                        (optim.ASGD, asgd, "_multi_tensor_asgd"),
+                        (optim.Adamax, adamax, "_multi_tensor_adamax"),
+                        (optim.Adadelta, adadelta, "_multi_tensor_adadelta"),
+                        (optim.Adagrad, adagrad, "_multi_tensor_adagrad"),)
+
+        model = torch.nn.Linear(5, 5)
+        model.to(dtype=torch.float64, device="cuda")
+        input = torch.rand(2, 5, dtype=torch.float64, device="cuda")
+
+        for opt, mod, func in multi_optims:
+            defaults = {}
+            if opt == optim.SGD:
+                defaults["lr"] = 1e-2
+            optimizer = opt(model.parameters(), **defaults)
+            optimizer.zero_grad()
+            output = model(input)
+            loss = output.sum()
+            loss.backward()
+            with patch.object(mod, func) as mocked_foreach_impl:
+                optimizer.step()
+                self.assertTrue(mocked_foreach_impl.called)
+
+
 if __name__ == "__main__":
     run_tests()