@@ -1,6 +1,5 @@
-import warnings
 import contextlib
-from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
+from typing import Any, Callable, Dict, Iterator, List, Tuple, Union, Set, Optional
 
 import torch
 from torch import Tensor
@@ -37,21 +36,52 @@ def _setattr(self, name: str, value: Any) -> None:
     module._orig_class = cls
 
 
-def _check_tied_val_already_replaced(old_val, new_val, replaced_tensors_map):
-    if old_val not in replaced_tensors_map:
-        replaced_tensors_map[old_val] = new_val
-    elif replaced_tensors_map[old_val] is not new_val:
-        warnings.warn("functional_call was passed multiple values for tied weights. "
-                      "This behavior is deprecated and will be an error in future versions")
+def _create_tied_weights_map(module, params_and_buffers):
+    # creates a weight map of {tied_name: name_given_by_user} for all weights where one of their tied weights is passed
+    #
+    # The basic algorithm looks like:
+    # - index all weights by their original tensor value to find tied weights
+    # - when we encounter a weight not used by the user, we save it in a set (second element in the tuple)
+    # - when we run into a weight used by the user, we save it separately from the set as the first element in the tuple
+    # - the resulting map looks like {tensor: (name_given_by_user, set(all_tied_names))}
+    # - then loop through the values of this map (name_given_by_user and set(all_tied_names))
+    # - for each element of all_tied_names, add {tied_name: name_given_by_user} to a new map
+
+    names = params_and_buffers.keys()
+    weight_to_name_and_tied_names: Dict[torch.Tensor, Tuple[Optional[str], Set[str]]] = {}
+
+    def add_to_name_map(name, t):
+        if t in weight_to_name_and_tied_names:
+            first_seen_name = weight_to_name_and_tied_names[t][0]
+            if name in names and first_seen_name and params_and_buffers[name] is not params_and_buffers[first_seen_name]:
+                raise ValueError(f"functional_call got values for both {name} and {first_seen_name}, which are tied.")
+            elif name in names:
+                weight_to_name_and_tied_names[t] = (name, weight_to_name_and_tied_names[t][1])
+            else:
+                weight_to_name_and_tied_names[t][1].add(name)
+        else:
+            weight_to_name_and_tied_names[t] = (name, set()) if name in names else (None, {name})
+
+    for name, t in module.named_parameters(remove_duplicate=False):
+        add_to_name_map(name, t)
+
+    for name, t in module.named_buffers(remove_duplicate=False):
+        add_to_name_map(name, t)
 
+    # make {tied_name: name_given_by_user} from pairs of (name_given_by_user, set(all_tied_names))
+    tied_weights_to_given_name = {}
+    for name, tied_names in weight_to_name_and_tied_names.values():
+        if name is None:  # no value was passed for this tensor, so the original tensor is kept
+            continue
+        for tied_name in tied_names:
+            tied_weights_to_given_name[tied_name] = name
+    return tied_weights_to_given_name
 
-def _create_swap_params(params_and_buffers, replaced_tensors_map):
-    def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Tensor) -> None:
+
+def _create_swap_params(params_and_buffers):
+    def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Optional[Tensor]) -> None:
         # Changes the module class to get a new __getattr__ dunder method
         # that looks for the reparametrized tensor
-        if hasattr(module, tensor_name):
-            old_val = getattr(module, tensor_name)
-            _check_tied_val_already_replaced(old_val, tensor, replaced_tensors_map)
         if hasattr(module, "_attr_to_path"):
            module._attr_to_path[tensor_name] = full_path
         else:
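
To make the mapping concrete, here is a minimal sketch (not part of the commit) of what _create_tied_weights_map returns for a module with tied weights; the TiedModel class below is a hypothetical example:

    import torch
    from torch import nn

    class TiedModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = nn.Embedding(4, 3)
            self.decoder = nn.Linear(3, 4, bias=False)
            self.decoder.weight = self.embed.weight  # tie the two parameters

        def forward(self, x):
            return self.decoder(self.embed(x))

    model = TiedModel()
    # The user passes a value only for "embed.weight". Indexing by tensor value
    # finds that "decoder.weight" shares the same tensor, so the helper maps the
    # tied name back to the user-given name:
    #   _create_tied_weights_map(model, {"embed.weight": torch.zeros(4, 3)})
    #   -> {"decoder.weight": "embed.weight"}
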
@@ -72,12 +102,17 @@ def _remove_swap(module, name: str, full_path: str) -> None:
 def _reparametrize_module(
     module: 'torch.nn.Module',
     parameters_and_buffers: Dict[str, Tensor],
+    tie_weights: bool = False,
 ) -> Iterator[None]:
-    orig_tensors_to_replacements: Dict[Tensor, Tensor] = {}
+    tied_weights_map = _create_tied_weights_map(module, parameters_and_buffers) if tie_weights else {}
     for name, tensor in parameters_and_buffers.items():
         _apply_func_submodules(
-            _create_swap_params(parameters_and_buffers, orig_tensors_to_replacements),
+            _create_swap_params(parameters_and_buffers),
             module, name.split("."), name, (tensor,))
+    for tied_name, user_given_name in tied_weights_map.items():
+        _apply_func_submodules(
+            _create_swap_params(parameters_and_buffers),
+            module, tied_name.split("."), user_given_name, (None,))
     try:
         yield
     finally:
@@ -105,6 +140,7 @@ def functional_call(
     parameters_and_buffers: Dict[str, Tensor],
     args: Union[Any, Tuple],
     kwargs: Dict[str, Any] = None,
+    tie_weights: bool = False,
 ):
     r"""Performs a functional call on the module by replacing the module parameters
     and buffers with the provided ones.
@@ -151,7 +187,7 @@ def functional_call(
         raise RuntimeError("The stateless API can't be used with Jitted modules")
     if kwargs is None:
         kwargs = {}
-    with _reparametrize_module(module, parameters_and_buffers):
+    with _reparametrize_module(module, parameters_and_buffers, tie_weights):
         if isinstance(args, tuple):
             out = module(*args, **kwargs)
         else:
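
A minimal usage sketch of the new flag, reusing the hypothetical TiedModel from the earlier sketch: with tie_weights=True, a value passed for one name in a set of tied weights is applied to all of them, and passing conflicting values for tied names raises a ValueError.

    import torch
    from torch.nn.utils.stateless import functional_call

    model = TiedModel()  # hypothetical tied-weight module from the sketch above
    x = torch.tensor([0, 1, 2])
    new_weight = torch.randn(4, 3)

    # "decoder.weight" is reparametrized too, since it is tied to "embed.weight":
    out = functional_call(model, {"embed.weight": new_weight}, (x,), tie_weights=True)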