
Commit 059994d

janeyx99 authored and pytorchmergebot committed
Migrate load_state_dict hook tests to OptimizerInfo (#119310)
Pull Request resolved: #119310
Approved by: https://github.com/albanD
ghstack dependencies: #119283, #119288, #119299, #119308
1 parent 0320e62 commit 059994d
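
The tests being migrated exercise the optimizer load_state_dict hook API (register_load_state_dict_pre_hook / register_load_state_dict_post_hook); the new versions below run the same checks across the optimizers in optim_db via the @optims decorator. As a reference, here is a minimal sketch of that hook API using plain SGD and illustrative values; it is not code from this commit:

import torch
from torch.optim import SGD

param = torch.rand(2, 3, requires_grad=True)
opt = SGD([param], lr=0.001)

def pre_hook(optimizer, state_dict):
    # Pre-hooks run before the state dict is applied; they may mutate
    # state_dict in place or return a replacement dict.
    state_dict["param_groups"][0]["lr"] = 0.002  # illustrative value

def post_hook(optimizer):
    # Post-hooks run after loading and receive only the optimizer.
    optimizer.state["ran_post_hook"] = True

opt.register_load_state_dict_pre_hook(pre_hook)    # prepend=False by default
opt.register_load_state_dict_post_hook(post_hook)
opt.load_state_dict(opt.state_dict())
assert opt.param_groups[0]["lr"] == 0.002
assert opt.state["ran_post_hook"]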

File tree

2 files changed: +82, -57 lines

test/optim/test_optim.py

Lines changed: 1 addition & 57 deletions
@@ -3,12 +3,11 @@
 import unittest
 import functools
 import itertools
-from copy import deepcopy
 
 import torch
 from torch.nn import Parameter
 from torch.optim import (
-    Adadelta, Adagrad, Adam, Adamax, AdamW, ASGD, NAdam, RAdam, RMSprop, Rprop, SGD, SparseAdam, Optimizer
+    Adadelta, Adagrad, Adam, Adamax, AdamW, ASGD, NAdam, RAdam, RMSprop, Rprop, SGD, SparseAdam
 )
 from torch.optim.lr_scheduler import (
     StepLR,
@@ -28,7 +27,6 @@
 
 
 from torch.testing._internal.common_cuda import TEST_CUDA
-from typing import Dict, Any
 from unittest.mock import patch
 
 # load_tests from common_utils is used to automatically filter tests for
@@ -747,60 +745,6 @@ def test_fused_optimizer_does_not_step_if_foundinf(self):
             maximize=False,
         )
 
-    @staticmethod
-    def _load_state_dict_pre_hook1(optimizer: Optimizer, state_dict: Dict[str, Any]) -> None:
-        state_dict["param_groups"][0]["lr"] = 0.002
-
-    @staticmethod
-    def _load_state_dict_pre_hook2(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]:
-        # The typical use case for returning a state dict is to drastically modify the state dict.
-        # I will simulate by simply making a deep copy and ensuring that my_state_dict still gets used
-        my_state_dict = deepcopy(state_dict)
-        my_state_dict["param_groups"][0]["lr"] = 0.003
-        return my_state_dict
-
-    @staticmethod
-    def _load_state_dict_post_hook(optimizer: Optimizer) -> None:
-        optimizer.state["ran_load_state_dict_pre_hook2"] = optimizer.param_groups[0]["lr"] == 0.003
-        optimizer.state["ran_load_state_dict_post_hook"] = True
-
-    def test_load_state_dict_pre_hook_and_prepend(self):
-        param = torch.rand(2, 3, requires_grad=True)
-        param.grad = torch.rand(2, 3, requires_grad=True)
-        opt = SGD([param], lr=0.001)
-        state_dict = opt.state_dict()
-
-        # usually one would have a new opt instance here, but it's all the same here
-        opt.register_load_state_dict_pre_hook(self._load_state_dict_pre_hook1)
-        opt.load_state_dict(state_dict)
-        self.assertEqual(opt.param_groups[0]["lr"], 0.002)
-
-        opt.register_load_state_dict_pre_hook(self._load_state_dict_pre_hook2, prepend=True)
-        opt.load_state_dict(state_dict)
-        # If prepend were False would be 0.003 but since prepend is True, the other hook overrides
-        self.assertEqual(opt.param_groups[0]["lr"], 0.002)
-
-    def test_load_state_dict_post_hook(self):
-        param = torch.rand(2, 3, requires_grad=True)
-        param.grad = torch.rand(2, 3, requires_grad=True)
-        opt = SGD([param], lr=0.001)
-
-        opt.register_load_state_dict_post_hook(self._load_state_dict_post_hook)
-        opt.load_state_dict(opt.state_dict())
-        self.assertFalse(opt.state["ran_load_state_dict_pre_hook2"])
-        self.assertTrue(opt.state["ran_load_state_dict_post_hook"])
-
-    def test_load_state_dict_pre_post_hook(self):
-        param = torch.rand(2, 3, requires_grad=True)
-        param.grad = torch.rand(2, 3, requires_grad=True)
-        opt = SGD([param], lr=0.001)
-
-        opt.register_load_state_dict_pre_hook(self._load_state_dict_pre_hook2)
-        opt.register_load_state_dict_post_hook(self._load_state_dict_post_hook)
-        opt.load_state_dict(opt.state_dict())
-        self.assertTrue(opt.state["ran_load_state_dict_pre_hook2"])
-        self.assertTrue(opt.state["ran_load_state_dict_post_hook"])
-
 
 def _diff_fn(p, grad, opt_differentiable_state, opt_class, kwargs, *ignored):
     # Ignored is the list of values in `opt_differentiable_state`, we do this

test/test_optim.py

Lines changed: 81 additions & 0 deletions
@@ -988,6 +988,87 @@ def test_state_dict_pre_post_hook(self, device, dtype, optim_info):
             self.assertTrue(state_dict["ran_state_dict_pre_hook"])
 
 
+    @staticmethod
+    def _load_state_dict_pre_hook1(optimizer: Optimizer, state_dict: Dict[str, Any]) -> None:
+        state_dict["param_groups"][0]["lr"] = 0.002
+
+
+    @staticmethod
+    def _load_state_dict_pre_hook2(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]:
+        # The typical use case for returning a state dict is to drastically modify the state dict.
+        # I will simulate by simply making a deep copy and ensuring that my_state_dict still gets used
+        my_state_dict = deepcopy(state_dict)
+        my_state_dict["param_groups"][0]["lr"] = 0.003
+        return my_state_dict
+
+
+    @staticmethod
+    def _load_state_dict_post_hook(optimizer: Optimizer) -> None:
+        optimizer.state["ran_load_state_dict_pre_hook2"] = optimizer.param_groups[0]["lr"] == 0.003
+        optimizer.state["ran_load_state_dict_post_hook"] = True
+
+
+    @optims(optim_db, dtypes=[torch.float32])
+    def test_load_state_dict_pre_hook_and_prepend(self, device, dtype, optim_info):
+        optim_cls = optim_info.optim_cls
+        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
+        for optim_input in all_optim_inputs:
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+
+            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
+            optim = optim_cls([param], **optim_input.kwargs)
+            state_dict = optim.state_dict()
+
+            # usually one would have a new optim instance here, but it's all the same here
+            optim.register_load_state_dict_pre_hook(self.__class__._load_state_dict_pre_hook1)
+            optim.load_state_dict(state_dict)
+            self.assertEqual(optim.param_groups[0]["lr"], 0.002)
+
+            optim.register_load_state_dict_pre_hook(self.__class__._load_state_dict_pre_hook2, prepend=True)
+            optim.load_state_dict(state_dict)
+            # If prepend were False would be 0.003 but since prepend is True, the other hook overrides
+            self.assertEqual(optim.param_groups[0]["lr"], 0.002)
+
+
+    @optims(optim_db, dtypes=[torch.float32])
+    def test_load_state_dict_post_hook(self, device, dtype, optim_info):
+        optim_cls = optim_info.optim_cls
+        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
+        for optim_input in all_optim_inputs:
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+
+            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
+            optim = optim_cls([param], **optim_input.kwargs)
+
+            optim.register_load_state_dict_post_hook(self.__class__._load_state_dict_post_hook)
+            optim.load_state_dict(optim.state_dict())
+            self.assertFalse(optim.state["ran_load_state_dict_pre_hook2"])
+            self.assertTrue(optim.state["ran_load_state_dict_post_hook"])
+
+
+    @optims(optim_db, dtypes=[torch.float32])
+    def test_load_state_dict_pre_post_hook(self, device, dtype, optim_info):
+        optim_cls = optim_info.optim_cls
+        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
+        for optim_input in all_optim_inputs:
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+
+            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
+            optim = optim_cls([param], **optim_input.kwargs)
+
+            optim.register_load_state_dict_pre_hook(self.__class__._load_state_dict_pre_hook2)
+            optim.register_load_state_dict_post_hook(self.__class__._load_state_dict_post_hook)
+            optim.load_state_dict(optim.state_dict())
+            self.assertTrue(optim.state["ran_load_state_dict_pre_hook2"])
+            self.assertTrue(optim.state["ran_load_state_dict_post_hook"])
+
+
     @optims(optim_db, dtypes=[torch.float32])
     def test_step_post_hook(self, device, dtype, optim_info):
         def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
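
A note on the prepend semantics that test_load_state_dict_pre_hook_and_prepend above relies on: registering a pre-hook with prepend=True places it at the front of the hook list, and a dict returned by a pre-hook is what subsequent hooks (and the load itself) operate on. A standalone sketch with plain SGD and made-up learning rates, not part of this commit:

from copy import deepcopy

import torch
from torch.optim import SGD

def set_lr_002(optimizer, state_dict):
    # Mutates the incoming state dict in place.
    state_dict["param_groups"][0]["lr"] = 0.002

def set_lr_003(optimizer, state_dict):
    # Returns a modified copy; the returned dict replaces the original for
    # later hooks and for loading.
    new_sd = deepcopy(state_dict)
    new_sd["param_groups"][0]["lr"] = 0.003
    return new_sd

param = torch.rand(2, 3, requires_grad=True)
opt = SGD([param], lr=0.001)
opt.register_load_state_dict_pre_hook(set_lr_002)                # registered first, runs second
opt.register_load_state_dict_pre_hook(set_lr_003, prepend=True)  # registered second, runs first
opt.load_state_dict(opt.state_dict())
# set_lr_003 runs first and returns a copy with lr=0.003; set_lr_002 then
# overwrites that copy with lr=0.002, so 0.002 is what gets loaded.
assert opt.param_groups[0]["lr"] == 0.002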
