 
 __all__: List[str] = []
 
-# WeakTensorKeyDictionary to store relevant meta-data for the Tensor/Parameter
-# without changing its lifetime.
-# NOTE: Alternative is to add the meta-data as an attribute to the tensor,
-# but that will serialize the meta-data if Tensor is serialized.
-param_to_optim_hook_handle_map = torch.utils.weak.WeakTensorKeyDictionary()
-param_to_acc_grad_map = torch.utils.weak.WeakTensorKeyDictionary()
 
 @no_type_check
 def _apply_optimizer_in_backward(
@@ -50,12 +44,19 @@ def _apply_optimizer_in_backward_to_param(param: torch.nn.Parameter) -> None:
         # this parameter is ready (has been accumulated into .grad field)
 
         # Don't create a new acc_grad if we already have one
-        # i.e. for shared parameters or attaching multiple optimizers to a param.
-        if param not in param_to_acc_grad_map:
-            param_to_acc_grad_map[param] = param.view_as(param).grad_fn.next_functions[0][0]
+        # i.e. for shared parameters or attaching multiple optimizers to a param.
+        if not hasattr(param, "_acc_grad"):
+            acc_grad = param.view_as(param).grad_fn.next_functions[0][0]
+        else:
+            acc_grad = param._acc_grad
 
         optimizer = optimizer_class([param], **optimizer_kwargs)
 
+        # Keep the grad accumulator around for the lifetime of the Tensor,
+        # store it on the param to avoid an uncollectable ref-cycle
+        if not hasattr(param, "_acc_grad"):
+            param._acc_grad = acc_grad  # type: ignore[attr-defined]
+
         if not hasattr(param, "_in_backward_optimizers"):
             param._in_backward_optimizers = []  # type: ignore[attr-defined]
             # TODO: investigate whether we really need these attributes.
@@ -72,10 +73,10 @@ def optimizer_hook(*_unused) -> None:
 
             param.grad = None
 
-        handle = param_to_acc_grad_map[param].register_hook(optimizer_hook)  # type: ignore[attr-defined]
-        if param not in param_to_optim_hook_handle_map:
-            param_to_optim_hook_handle_map[param] = []
-        param_to_optim_hook_handle_map[param].append(handle)
+        handle = param._acc_grad.register_hook(optimizer_hook)  # type: ignore[attr-defined]
+        if not hasattr(param, '_optimizer_hook_handles'):
+            param._optimizer_hook_handles = []  # type: ignore[attr-defined]
+        param._optimizer_hook_handles.append(handle)  # type: ignore[attr-defined]
 
     for param in params:
         _apply_optimizer_in_backward_to_param(param)
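
For context, here is a minimal usage sketch of the function modified above. It is not part of the diff: the import path and keyword names are taken from the signature shown here and from torch.distributed.optim, and may differ across PyTorch versions. Each parameter gets its own optimizer instance, stepped from a hook on that parameter's AccumulateGrad node during backward.

import torch
from torch.distributed.optim import _apply_optimizer_in_backward

model = torch.nn.Linear(10, 10)

# Attach a per-parameter SGD; the registered hook steps the optimizer as soon
# as each gradient has been accumulated and then sets param.grad back to None.
_apply_optimizer_in_backward(
    optimizer_class=torch.optim.SGD,
    params=model.parameters(),
    optimizer_kwargs={"lr": 0.01},
)

loss = model(torch.randn(4, 10)).sum()
loss.backward()  # optimizers run inside backward; no separate optimizer.step()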