-import warnings
 import contextlib
-from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
+from typing import Any, Callable, Dict, Iterator, List, Tuple, Union, Set, Optional

 import torch
 from torch import Tensor
@@ -37,21 +36,84 @@ def _setattr(self, name: str, value: Any) -> None:
     module._orig_class = cls


-def _check_tied_val_already_replaced(old_val, new_val, replaced_tensors_map):
-    if old_val not in replaced_tensors_map:
-        replaced_tensors_map[old_val] = new_val
-    elif replaced_tensors_map[old_val] is not new_val:
-        warnings.warn("functional_call was passed multiple values for tied weights. "
-                      "This behavior is deprecated and will be an error in future versions")
+def _create_tied_weights_map(module: 'torch.nn.Module', params_and_buffers: Dict[str, Tensor]) -> Dict[str, str]:
+    """
+    _create_tied_weights_map(module: Module, params_and_buffers: Dict[str, Tensor]) -> Dict[str, str]
+
+    Creates a weight map of {tied_name: name_given_by_user} for every weight tied to a name the user passed.
+
+    ex: Foo() has self.foo and self.tied_foo, which are tied. If a user passes {'foo': ...} as the reparametrization,
+    this returns {'tied_foo': 'foo'}. Similarly, if a user passes {'tied_foo': ...}, this returns {'foo': 'tied_foo'}.
+
+    ex: If there aren't any tied weights, there are no tied names to remap, so this returns an empty map: {}
+
+    ex: The map only contains entries for weights reachable through a tied name the user is reparametrizing. For
+    example, if module = nn.Linear(...) and the user only passed a new value for 'bias', this returns {} since
+    neither 'weight' nor 'bias' is tied.
+
+    This is useful because we will start by reparametrizing all the keys of params_and_buffers, then all the keys of
+    this returned dictionary.
+    """
+
+    # The basic algorithm looks like:
+    # - index all weights by their original tensor value so that tied weights land on the same key
+    # - when we encounter a name the user didn't pass, save it in a set (the second element of the tuple)
+    # - when we encounter a name the user did pass, save it separately as the first element of the tuple
+    # - the resulting map looks like {tensor: (name_given_by_user, set_of_all_tied_names)}
+    # - then loop through the values of this map (name_given_by_user and set_of_all_tied_names)
+    # - for each element of set_of_all_tied_names, add {tied_name: name_given_by_user} to a new map
+
+    names = params_and_buffers.keys()
+    weight_to_name_and_tied_names: Dict[torch.Tensor, Tuple[Optional[str], Set[str]]] = {}
+
+    # Create a map keyed by tensor value so that tied weights get mapped to the same key. The value is the
+    # interesting part: at the end it is a tuple of (used_name, set_of_tied_names).
+    # For example, with tied weights self.foo and self.tied_foo and a user-passed value for self.foo, this
+    # produces {torch.Tensor(...): ('foo', {'tied_foo'})}
+    def add_to_name_map(n: str, t: torch.Tensor):
+        # if the tensor hasn't been seen before, add it to the map
+        if t not in weight_to_name_and_tied_names:
+            weight_to_name_and_tied_names[t] = (n, set()) if n in names else (None, {n})
+            return
+
+        # if the user didn't pass a value for this name, add it to the tied set
+        if n not in names:
+            weight_to_name_and_tied_names[t][1].add(n)
+            return
+
+        # check that the user didn't pass two different tensors for the same tied weight
+        first_seen_name = weight_to_name_and_tied_names[t][0]
+
+        # if no other name was passed for this weight, or the same tensor was passed for both names, record the name
+        if first_seen_name is None or params_and_buffers[n] is params_and_buffers[first_seen_name]:
+            weight_to_name_and_tied_names[t] = (n, weight_to_name_and_tied_names[t][1])
+            return
+
+        raise ValueError(f"functional_call got values for both {n} and {first_seen_name}, which are tied. "
+                         "Consider using tie_weights=False")
+
+    tensor: Tensor
+    for name, tensor in module.named_parameters(remove_duplicate=False):
+        add_to_name_map(name, tensor)
+
+    for name, tensor in module.named_buffers(remove_duplicate=False):
+        add_to_name_map(name, tensor)
+
+    # make {tied_name: name_given_by_user} from the pairs of (name_given_by_user, set_of_all_tied_names)
+    tied_weights_to_given_name = {}
+    for name_given_by_user, tied_names in weight_to_name_and_tied_names.values():
+        if name_given_by_user is None:  # no value was passed for this tensor, keep the original
+            continue
+        for tied_name in tied_names:
+            tied_weights_to_given_name[tied_name] = name_given_by_user
+    return tied_weights_to_given_name
+
+
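To make the mapping concrete, here is a small runnable check of the docstring's first example. The `Tied` module is hypothetical, and calling the private helper directly is for illustration only:

import torch
from torch import nn
from torch.nn.utils import stateless

class Tied(nn.Module):  # hypothetical module with tied weights
    def __init__(self):
        super().__init__()
        self.foo = nn.Parameter(torch.ones(()))
        self.tied_foo = self.foo  # both names share one Parameter

mod = Tied()
# the user passes a value for 'foo'; 'tied_foo' gets rerouted to it
print(stateless._create_tied_weights_map(mod, {'foo': torch.zeros(())}))
# {'tied_foo': 'foo'}

# passing two *different* tensors for the tied pair hits the ValueError above
conflicting = {'foo': torch.zeros(()), 'tied_foo': torch.ones(())}
# stateless._create_tied_weights_map(mod, conflicting)  # raises ValueError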
+def _create_swap_params(params_and_buffers):
+    def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Optional[Tensor]) -> None:
         # Changes the module class to get a new __getattr__ dunder method
         # that looks for the reparametrized tensor
-        if hasattr(module, tensor_name):
-            old_val = getattr(module, tensor_name)
-            _check_tied_val_already_replaced(old_val, tensor, replaced_tensors_map)
         if hasattr(module, "_attr_to_path"):
             module._attr_to_path[tensor_name] = full_path
         else:
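The comment above alludes to the class-swapping trick this file relies on. A minimal sketch of the idea (the helper name is hypothetical; the real implementation differs in detail): `__getattr__` only fires when ordinary lookup fails, and since `nn.Module` keeps parameters in `self._parameters` rather than the instance `__dict__`, a subclass `__getattr__` can intercept parameter access and reroute names recorded in `_attr_to_path`:

import torch
from torch import nn

def swap_class_for_reparametrization(module, parameters_and_buffers):
    # hypothetical sketch: replace the module's class with a subclass whose
    # __getattr__ reroutes names found in _attr_to_path to the user's tensors
    cls = module.__class__

    def __getattr__(self, name):
        path = self.__dict__["_attr_to_path"].get(name)
        if path is not None:
            return parameters_and_buffers[path]
        return cls.__getattr__(self, name)  # fall back to nn.Module's lookup

    module.__dict__["_attr_to_path"] = {}  # bypass nn.Module.__setattr__
    module.__class__ = type("Swapped" + cls.__name__, (cls,), {"__getattr__": __getattr__})

lin = nn.Linear(2, 2)
swap_class_for_reparametrization(lin, {"weight": torch.zeros(2, 2)})
lin._attr_to_path["weight"] = "weight"
print(lin.weight)  # the zeros tensor, not the original Parameter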
@@ -72,12 +134,17 @@ def _remove_swap(module, name: str, full_path: str) -> None:
 def _reparametrize_module(
     module: 'torch.nn.Module',
     parameters_and_buffers: Dict[str, Tensor],
+    tie_weights: bool = False,
 ) -> Iterator[None]:
-    orig_tensors_to_replacements: Dict[Tensor, Tensor] = {}
+    tied_weights_map = _create_tied_weights_map(module, parameters_and_buffers) if tie_weights else {}
     for name, tensor in parameters_and_buffers.items():
         _apply_func_submodules(
-            _create_swap_params(parameters_and_buffers, orig_tensors_to_replacements),
+            _create_swap_params(parameters_and_buffers),
             module, name.split("."), name, (tensor,))
+    for tied_name, user_given_name in tied_weights_map.items():
+        _apply_func_submodules(
+            _create_swap_params(parameters_and_buffers),
+            module, tied_name.split("."), user_given_name, (None,))
     try:
         yield
     finally:
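`_reparametrize_module` yields and is typed `Iterator[None]`, so it is presumably wrapped with `contextlib.contextmanager` (the decorator sits outside this hunk). Under that assumption, a hedged usage sketch of the new `tie_weights` parameter on this private helper:

import torch
from torch import nn
from torch.nn.utils import stateless

m = nn.Linear(3, 1)
new_tensors = {"weight": torch.ones(1, 3), "bias": torch.zeros(1)}
with stateless._reparametrize_module(m, new_tensors, tie_weights=True):
    print(m(torch.ones(3)))  # computed with the new tensors: tensor([3.], ...)
# on exit the swap is undone and the original parameters are visible again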
@@ -105,6 +172,8 @@ def functional_call(
     parameters_and_buffers: Dict[str, Tensor],
     args: Union[Any, Tuple],
     kwargs: Dict[str, Any] = None,
+    *,
+    tie_weights: bool = True,
 ):
     r"""Performs a functional call on the module by replacing the module parameters
     and buffers with the provided ones.
@@ -128,12 +197,31 @@ def functional_call(
         >>> print(mod.foo)  # tensor(0.)
         >>> print(a['foo'])  # tensor(1.)

+    .. note:: If the module has tied weights, whether functional_call respects the tying is determined by the
+        ``tie_weights`` flag.
+
+    Example::
+
+        >>> a = {'foo': torch.zeros(())}
+        >>> # xdoctest: +SKIP
+        >>> mod = Foo()  # has both self.foo and self.foo_tied, which are tied. Returns x + self.foo + self.foo_tied
+        >>> print(mod.foo)  # tensor(1.)
+        >>> mod(torch.zeros(()))  # tensor(2.)
+        >>> functional_call(mod, a, torch.zeros(()))  # tensor(0.) since it changes self.foo_tied too
+        >>> functional_call(mod, a, torch.zeros(()), tie_weights=False)  # tensor(1.)--self.foo_tied is not updated
+        >>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())}
+        >>> functional_call(mod, new_a, torch.zeros(()))  # tensor(0.)
+
     Args:
         module (torch.nn.Module): the module to call
         parameters_and_buffers (dict of str and Tensor): the parameters that will be used in
             the module call.
         args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument.
         kwargs (dict): keyword arguments to be passed to the module call
+        tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be
+            treated as tied in the reparametrized version. Therefore, if True and different values are passed for
+            the tied parameters and buffers, it will error. If False, it will not respect the originally tied
+            parameters and buffers unless the values passed for both weights are the same. Default: True.

     Returns:
         Any: the result of calling ``module``.
@@ -151,7 +239,7 @@ def functional_call(
         raise RuntimeError("The stateless API can't be used with Jitted modules")
     if kwargs is None:
         kwargs = {}
-    with _reparametrize_module(module, parameters_and_buffers):
+    with _reparametrize_module(module, parameters_and_buffers, tie_weights):
         if isinstance(args, tuple):
             out = module(*args, **kwargs)
         else:
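The docstring example above leaves `Foo` undefined; a hedged sketch of a module matching the described behavior (one parameter exposed under two tied names, with forward returning `x + self.foo + self.foo_tied`):

import torch
from torch import nn
from torch.nn.utils.stateless import functional_call

class Foo(nn.Module):  # hypothetical module from the docstring example
    def __init__(self):
        super().__init__()
        self.foo = nn.Parameter(torch.ones(()))
        self.foo_tied = self.foo  # tied: both names share one Parameter

    def forward(self, x):
        return x + self.foo + self.foo_tied

mod = Foo()
a = {'foo': torch.zeros(())}
print(functional_call(mod, a, torch.zeros(())))                     # tensor(0.)
print(functional_call(mod, a, torch.zeros(()), tie_weights=False))  # tensor(1.)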