@@ -988,6 +988,92 @@ def test_state_dict_pre_post_hook(self, device, dtype, optim_info):
         self.assertTrue(state_dict["ran_state_dict_pre_hook"])
 
 
+    @staticmethod
+    def _load_state_dict_pre_hook1(optimizer: Optimizer, state_dict: Dict[str, Any]) -> None:
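+        # Hooks may mutate state_dict in place (returning None, as here) or return a new dict to be loaded instead.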
+        state_dict["param_groups"][0]["lr"] = 0.002
+
+
+    @staticmethod
+    def _load_state_dict_pre_hook2(optimizer: Optimizer, state_dict: Dict[str, Any]) -> Dict[str, Any]:
+        # The typical use case for returning a state dict is to heavily rework it. We simulate that with
+        # a deep copy and later verify that the returned my_state_dict is the one actually used.
+        my_state_dict = deepcopy(state_dict)
+        my_state_dict["param_groups"][0]["lr"] = 0.003
+        return my_state_dict
+
+
+    @staticmethod
+    def _load_state_dict_post_hook(optimizer: Optimizer) -> None:
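+        # Post-hooks receive only the optimizer; use its state to record whether pre_hook2's lr=0.003 was loaded.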
+        optimizer.state["ran_load_state_dict_pre_hook2"] = optimizer.param_groups[0]["lr"] == 0.003
+        optimizer.state["ran_load_state_dict_post_hook"] = True
+
+
+    @optims(optim_db, dtypes=[torch.float32])
+    def test_load_state_dict_pre_hook_and_prepend(self, device, dtype, optim_info):
+        optim_cls = optim_info.optim_cls
+        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
+        for optim_input in all_optim_inputs:
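+            # Some optimizers only support capturable=True via their foreach implementation; skip unsupported combos.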
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+
+            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
+            optim = optim_cls([param], **optim_input.kwargs)
+            state_dict = optim.state_dict()
+
+            # usually one would load into a fresh optimizer instance, but reusing the same one exercises the hooks identically
+            optim.register_load_state_dict_pre_hook(self.__class__._load_state_dict_pre_hook1)
+            optim.load_state_dict(state_dict)
+            self.assertEqual(optim.param_groups[0]["lr"], 0.002)
+
+            optim.register_load_state_dict_pre_hook(self.__class__._load_state_dict_pre_hook2, prepend=True)
+            optim.load_state_dict(state_dict)
+            # prepend=True makes pre_hook2 run first; pre_hook1 then overwrites its lr=0.003 back to 0.002
+            self.assertEqual(optim.param_groups[0]["lr"], 0.002)
+
+
+    @optims(optim_db, dtypes=[torch.float32])
+    def test_load_state_dict_post_hook(self, device, dtype, optim_info):
+        optim_cls = optim_info.optim_cls
+        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
+        for optim_input in all_optim_inputs:
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+
+            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
+            optim = optim_cls([param], **optim_input.kwargs)
+
+            optim.register_load_state_dict_post_hook(self.__class__._load_state_dict_post_hook)
+            optim.load_state_dict(optim.state_dict())
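+            # pre_hook2 was never registered in this test, so lr is never set to 0.003 and the pre-hook flag stays False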
+            self.assertFalse(optim.state["ran_load_state_dict_pre_hook2"])
+            self.assertTrue(optim.state["ran_load_state_dict_post_hook"])
+
+
+    @optims(optim_db, dtypes=[torch.float32])
+    def test_load_state_dict_pre_post_hook(self, device, dtype, optim_info):
+        optim_cls = optim_info.optim_cls
+        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info)
+        for optim_input in all_optim_inputs:
+            if (optim_info.only_supports_capturable_on_foreach and optim_input.kwargs.get("capturable", False)
+                    and not optim_input.kwargs.get("foreach", False)):
+                continue
+
+            param = torch.rand(2, 3, device=device, dtype=dtype, requires_grad=True)
+            optim = optim_cls([param], **optim_input.kwargs)
+
+            optim.register_load_state_dict_pre_hook(self.__class__._load_state_dict_pre_hook2)
+            optim.register_load_state_dict_post_hook(self.__class__._load_state_dict_post_hook)
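+            # pre_hook2's returned dict (lr=0.003) is what gets loaded, so the post-hook records both flags as True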
+            optim.load_state_dict(optim.state_dict())
+            self.assertTrue(optim.state["ran_load_state_dict_pre_hook2"])
+            self.assertTrue(optim.state["ran_load_state_dict_post_hook"])
+
+
     @optims(optim_db, dtypes=[torch.float32])
     def test_step_post_hook(self, device, dtype, optim_info):
         def post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):