Skip to content

Commit 8b3c4bc

Browse files
samdowpytorchmergebot
authored andcommitted
[stateless] add weight tying support (#90477)
Pull Request resolved: #90477 Approved by: https://github.com/zou3519
1 parent e03ac0e commit 8b3c4bc

File tree

3 files changed

+202
-24
lines changed

3 files changed

+202
-24
lines changed

test/test_stateless.py

Lines changed: 76 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,17 @@ def __init__(self):
2222
def forward(self, x):
2323
return self.l1(x) + self.buffer
2424

25+
class MockTiedModule(torch.nn.Module):
26+
def __init__(self):
27+
super().__init__()
28+
self.l1 = torch.nn.Linear(1, 1)
29+
self.tied_bias = self.l1.bias
30+
self.register_buffer('buffer', torch.ones(1))
31+
self.register_buffer('tied_buffer', self.buffer)
32+
33+
def forward(self, x):
34+
return self.l1(x) + self.tied_bias + self.buffer + self.tied_buffer
35+
2536

2637
class TestStatelessFunctionalAPI(TestCase):
2738
def _run_call_with_mock_module(self, module, functional_call, device='cpu', prefix=''):
@@ -156,7 +167,7 @@ def test_circular_references(self, functional_call):
156167
'l1.m.buffer': buffer}
157168
prev_weight = module.l1.weight.clone()
158169
prev_buffer = module.buffer.clone()
159-
res = functional_call(module, parameters, x)
170+
res = functional_call(module, parameters, x, tie_weights=False)
160171
self.assertEqual(x, res)
161172
# check that the weights remain unmodified and were correctly accesed
162173
cur_weight = module.l1.weight
@@ -217,6 +228,46 @@ def test_tied_weights_warns(self, functional_call):
217228
module = MockModule()
218229
module.tied_bias = module.l1.bias
219230
module.register_buffer("tied_buffer", module.buffer)
231+
232+
@parametrize("functional_call", [
233+
subtest(torch.func.functional_call, "torch_func"),
234+
subtest(stateless.functional_call, "stateless")
235+
])
236+
def test_reparamertize_tie_weights(self, functional_call):
237+
module = MockTiedModule()
238+
weight = torch.tensor([[2.0]],)
239+
bias = torch.tensor([5.0])
240+
buffer = torch.tensor([3.0])
241+
242+
parameters = {'l1.weight': weight,
243+
'l1.bias': bias,
244+
'buffer': buffer}
245+
x = torch.randn(1, 1)
246+
out = functional_call(module, parameters, x, tie_weights=True)
247+
self.assertEqual(out, x * weight + bias + bias + buffer + buffer)
248+
249+
250+
@parametrize("functional_call", [
251+
subtest(torch.func.functional_call, "torch_func"),
252+
subtest(stateless.functional_call, "stateless")
253+
])
254+
def test_reparamertize_tie_some_weights(self, functional_call):
255+
module = MockTiedModule()
256+
weight = torch.tensor([[2.0]],)
257+
buffer = torch.tensor([3.0])
258+
259+
parameters = {'l1.weight': weight,
260+
'buffer': buffer}
261+
x = torch.randn(1, 1)
262+
out = stateless.functional_call(module, parameters, x, tie_weights=True)
263+
self.assertEqual(out, x * 2. + module.l1.bias + module.tied_bias + buffer + buffer)
264+
265+
@parametrize("functional_call", [
266+
subtest(torch.func.functional_call, "torch_func"),
267+
subtest(stateless.functional_call, "stateless")
268+
])
269+
def test_tied_weights_errors(self, functional_call):
270+
module = MockTiedModule()
220271
weight = torch.tensor([[1.0]],)
221272
bias = torch.tensor([0.0])
222273
buffer = torch.tensor([0.0])
@@ -225,23 +276,41 @@ def test_tied_weights_warns(self, functional_call):
225276
'l1.bias': bias,
226277
'buffer': buffer}
227278
x = torch.randn(1, 1)
228-
self.assertNotWarn(lambda: functional_call(module, parameters, x))
279+
self.assertNotWarn(lambda: functional_call(module, parameters, x, tie_weights=True))
229280

230281
# if tied values are the same tensors, shouldn't warn
231282
parameters['tied_bias'] = bias
232283
parameters['tied_buffer'] = buffer
233-
self.assertNotWarn(lambda: functional_call(module, parameters, x))
284+
self.assertNotWarn(lambda: functional_call(module, parameters, x, tie_weights=True))
234285
del parameters['tied_bias']
235286
del parameters['tied_buffer']
236287

237-
with self.assertWarnsOnceRegex(UserWarning, "functional_call was passed multiple values"):
288+
with self.assertRaisesRegex(ValueError, "functional_call got values for both (l1.bias|tied_bias)"):
238289
parameters['tied_bias'] = torch.tensor([5.0])
239-
functional_call(module, parameters, x)
290+
functional_call(module, parameters, x, tie_weights=True)
240291
del parameters['tied_bias']
241292

242-
with self.assertWarnsOnceRegex(UserWarning, "functional_call was passed multiple values"):
293+
with self.assertRaisesRegex(ValueError, "functional_call got values for both (buffer|tied_buffer)"):
243294
parameters['tied_buffer'] = torch.tensor([5.0])
244-
functional_call(module, parameters, x)
295+
functional_call(module, parameters, x, tie_weights=True)
296+
297+
298+
def test_tied_weights_no_error_without_flag(self):
299+
module = MockTiedModule()
300+
weight = torch.tensor([[1.0]],)
301+
bias = torch.tensor([0.0])
302+
buffer = torch.tensor([0.0])
303+
304+
parameters = {'l1.weight': weight,
305+
'l1.bias': bias,
306+
'buffer': buffer}
307+
x = torch.randn(1, 1)
308+
self.assertNotWarn(lambda: stateless.functional_call(module, parameters, x, tie_weights=False))
309+
parameters['tied_bias'] = torch.tensor([5.0])
310+
self.assertNotWarn(lambda: stateless.functional_call(module, parameters, x, tie_weights=False))
311+
del parameters['tied_bias']
312+
parameters['tied_buffer'] = torch.tensor([5.0])
313+
self.assertNotWarn(lambda: stateless.functional_call(module, parameters, x, tie_weights=False))
245314

246315
@parametrize("functional_call", [
247316
subtest(torch.func.functional_call, "torch_func"),

torch/_functorch/functional_call.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ def functional_call(
1212
parameter_and_buffer_dicts: Union[Dict[str, Tensor], Tuple[Dict[str, Tensor], ...]],
1313
args: Union[Any, Tuple],
1414
kwargs: Dict[str, Any] = None,
15+
*,
16+
tie_weights: bool = True,
1517
):
1618
r"""Performs a functional call on the module by replacing the module parameters
1719
and buffers with the provided ones.
@@ -36,6 +38,21 @@ def functional_call(
3638
>>> print(mod.foo) # tensor(0.)
3739
>>> print(a['foo']) # tensor(1.)
3840
41+
.. note:: If the module has tied weights, whether or not functional_call respects the tying is determined by the
42+
tie_weights flag.
43+
44+
Example::
45+
46+
>>> a = {'foo': torch.zeros(())}
47+
>>> # xdoctest: +SKIP
48+
>>> mod = Foo() # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied
49+
>>> print(mod.foo) # tensor(1.)
50+
>>> mod(torch.zeros(())) # tensor(2.)
51+
>>> functional_call(mod, a, torch.zeros(())) # tensor(0.) since it will change self.foo_tied too
52+
>>> functional_call(mod, a, torch.zeros(()), tie_weights=False) # tensor(1.)--self.foo_tied is not updated
53+
>>> new_a = {'foo', torch.zeros(()), 'foo_tied': torch.zeros(())}
54+
>>> functional_call(mod, new_a, torch.zeros()) # tensor(0.)
55+
3956
An example of passing mutliple dictionaries
4057
4158
.. code-block:: python
@@ -88,6 +105,10 @@ def compute_loss(params, x, t):
88105
be used together
89106
args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument.
90107
kwargs (dict): keyword arguments to be passed to the module call
108+
tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be treated as
109+
tied in the reparamaterized version. Therefore, if True and different values are passed for the tied
110+
paramaters and buffers, it will error. If False, it will not respect the originally tied parameters and
111+
buffers unless the values passed for both weights are the same. Default: True.
91112
92113
Returns:
93114
Any: the result of calling ``module``.
@@ -102,7 +123,7 @@ def compute_loss(params, x, t):
102123

103124
parameters_and_buffers = {k: v for d in parameter_and_buffer_dicts for k, v in d.items()}
104125

105-
return nn.utils.stateless.functional_call(module, parameters_and_buffers, args, kwargs)
126+
return nn.utils.stateless.functional_call(module, parameters_and_buffers, args, kwargs, tie_weights=tie_weights)
106127

107128

108129
@exposed_in("torch.func")

torch/nn/utils/stateless.py

Lines changed: 104 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
import warnings
21
import contextlib
3-
from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
2+
from typing import Any, Callable, Dict, Iterator, List, Tuple, Union, Set, Optional
43

54
import torch
65
from torch import Tensor
@@ -37,21 +36,84 @@ def _setattr(self, name: str, value: Any) -> None:
3736
module._orig_class = cls
3837

3938

40-
def _check_tied_val_already_replaced(old_val, new_val, replaced_tensors_map):
41-
if old_val not in replaced_tensors_map:
42-
replaced_tensors_map[old_val] = new_val
43-
elif replaced_tensors_map[old_val] is not new_val:
44-
warnings.warn("functional_call was passed multiple values for tied weights. "
45-
"This behavior is deprecated and will be an error in future versions")
39+
def _create_tied_weights_map(module: 'torch.nn.Module', params_and_buffers: Dict[str, Tensor]) -> Dict[str, str]:
40+
"""
41+
_create_tied_weights_map(module: Module, params_and_buffers: Dict[str, Tensor]) -> Dict[str, str]
42+
43+
Creates a weight map of {tied_name: name_given_by_user} for all weights where one of their tied weights is passed
44+
45+
ex: Foo() has self.foo and self.tied_foo, which are tied. If a user passed {'foo': ...} as the reparamaterization,
46+
this would return {'tied_foo': 'foo'}. Similarly if a user passed {'tied_foo': ...}, this returns
47+
{'tied_foo': 'foo'}.
48+
49+
ex: If there aren't any tied weights and the user passed values for every parameter and buffer, this will return a
50+
map where every name maps to an empty set: {'l1.weight': set(), 'l1.bias': set(), ...}
4651
52+
ex: The map only contains values that a user is reparamaterizing. For example, if module = nn.Linear(...) and the
53+
user only passed a new value for 'bias', this looks returns: {'bias': set()}
4754
48-
def _create_swap_params(params_and_buffers, replaced_tensors_map):
49-
def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Tensor) -> None:
55+
This is useful because we will start by reparamaterizing all the keys of params_and_buffers, then all the key from
56+
this returned dictionary.
57+
"""
58+
59+
# The basic algorithm looks like:
60+
# - index all weights by their original tensor value to find tied weights
61+
# - when we encounter a weight not used by the user, we save it in a set (second element in the tuple)
62+
# - when we run into a weight used by the user, we save that separate from the set as the first element in the tuple
63+
# - ending map looks like {tensor: (name_given_by_user, set(all_tied_names)}
64+
# - then loop through the values of this map (name_given_by_user and set(all_tied_names))
65+
# - for each element of all_tied_names, add {tied_name: name_given_by_user} to a new map
66+
67+
names = params_and_buffers.keys()
68+
weight_to_name_and_tied_names: Dict[torch.Tensor, Tuple[Optional[str], Set[str]]] = {}
69+
70+
# create a map keyed by tensor value so that tied weights get mapped to the same key. The value is the interesting
71+
# part at the end it's (used_name, (tied_names)).
72+
# For example, in the first example where there's tied weights self.foo and self.tied_foo and the user passes a
73+
# value for self.foo, this will return {torch.Tensor(...): ('foo', set('tied_foo'))}
74+
def add_to_name_map(n: str, t: torch.Tensor):
75+
# if the tensor hasn't been seen before, add it to the map
76+
if t not in weight_to_name_and_tied_names:
77+
weight_to_name_and_tied_names[t] = (n, set()) if n in names else (None, {n})
78+
return
79+
80+
# if the name is not used by the user, we add it to the tied set
81+
if n not in names:
82+
weight_to_name_and_tied_names[t][1].add(n)
83+
return
84+
85+
# check that the user didn't pass two different tensors for the same tied weight
86+
first_seen_name = weight_to_name_and_tied_names[t][0]
87+
88+
# if they didn't pass multiple names for tied weights or used the same tensor, we set the used name
89+
if first_seen_name is None or params_and_buffers[n] is params_and_buffers[first_seen_name]:
90+
weight_to_name_and_tied_names[t] = (n, weight_to_name_and_tied_names[t][1])
91+
return
92+
93+
raise ValueError(f"functional_call got values for both {n} and {first_seen_name}, which are tied. " +
94+
"Consider using tie_weights=False")
95+
96+
tensor: Tensor
97+
for name, tensor in module.named_parameters(remove_duplicate=False):
98+
add_to_name_map(name, tensor)
99+
100+
for name, tensor in module.named_buffers(remove_duplicate=False):
101+
add_to_name_map(name, tensor)
102+
103+
# make {tied_name: name_given_by_user} from pairs of (name_given_by_user, set(all_tied_names))
104+
tied_weights_to_given_name = {}
105+
for name_given_by_user, tied_names in weight_to_name_and_tied_names.values():
106+
if name_given_by_user is None: # no mapping was passed for this tensor, use original tensor
107+
continue
108+
for tied_name in tied_names:
109+
tied_weights_to_given_name[tied_name] = name_given_by_user
110+
return tied_weights_to_given_name
111+
112+
113+
def _create_swap_params(params_and_buffers):
114+
def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Optional[Tensor]) -> None:
50115
# Changes the module class to get a new __getattr__ dunder method
51116
# that looks for the reparametrized tensor
52-
if hasattr(module, tensor_name):
53-
old_val = getattr(module, tensor_name)
54-
_check_tied_val_already_replaced(old_val, tensor, replaced_tensors_map)
55117
if hasattr(module, "_attr_to_path"):
56118
module._attr_to_path[tensor_name] = full_path
57119
else:
@@ -72,12 +134,17 @@ def _remove_swap(module, name: str, full_path: str) -> None:
72134
def _reparametrize_module(
73135
module: 'torch.nn.Module',
74136
parameters_and_buffers: Dict[str, Tensor],
137+
tie_weights: bool = False,
75138
) -> Iterator[None]:
76-
orig_tensors_to_replacements: Dict[Tensor, Tensor] = {}
139+
tied_weights_map = _create_tied_weights_map(module, parameters_and_buffers) if tie_weights else {}
77140
for name, tensor in parameters_and_buffers.items():
78141
_apply_func_submodules(
79-
_create_swap_params(parameters_and_buffers, orig_tensors_to_replacements),
142+
_create_swap_params(parameters_and_buffers),
80143
module, name.split("."), name, (tensor,))
144+
for tied_name, user_given_name in tied_weights_map.items():
145+
_apply_func_submodules(
146+
_create_swap_params(parameters_and_buffers),
147+
module, tied_name.split("."), user_given_name, (None,))
81148
try:
82149
yield
83150
finally:
@@ -105,6 +172,8 @@ def functional_call(
105172
parameters_and_buffers: Dict[str, Tensor],
106173
args: Union[Any, Tuple],
107174
kwargs: Dict[str, Any] = None,
175+
*,
176+
tie_weights: bool = True,
108177
):
109178
r"""Performs a functional call on the module by replacing the module parameters
110179
and buffers with the provided ones.
@@ -128,12 +197,31 @@ def functional_call(
128197
>>> print(mod.foo) # tensor(0.)
129198
>>> print(a['foo']) # tensor(1.)
130199
200+
.. note:: If the module has tied weights, whether or not functional_call respects the tying is determined by the
201+
tie_weights flag.
202+
203+
Example::
204+
205+
>>> a = {'foo': torch.zeros(())}
206+
>>> # xdoctest: +SKIP
207+
>>> mod = Foo() # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied
208+
>>> print(mod.foo) # tensor(1.)
209+
>>> mod(torch.zeros(())) # tensor(2.)
210+
>>> functional_call(mod, a, torch.zeros(())) # tensor(0.) since it will change self.foo_tied too
211+
>>> functional_call(mod, a, torch.zeros(()), tie_weights=False) # tensor(1.)--self.foo_tied is not updated
212+
>>> new_a = {'foo', torch.zeros(()), 'foo_tied': torch.zeros(())}
213+
>>> functional_call(mod, new_a, torch.zeros()) # tensor(0.)
214+
131215
Args:
132216
module (torch.nn.Module): the module to call
133217
parameters_and_buffers (dict of str and Tensor): the parameters that will be used in
134218
the module call.
135219
args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument.
136220
kwargs (dict): keyword arguments to be passed to the module call
221+
tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be treated as
222+
tied in the reparamaterized version. Therefore, if True and different values are passed for the tied
223+
paramaters and buffers, it will error. If False, it will not respect the originally tied parameters and
224+
buffers unless the values passed for both weights are the same. Default: True.
137225
138226
Returns:
139227
Any: the result of calling ``module``.
@@ -151,7 +239,7 @@ def functional_call(
151239
raise RuntimeError("The stateless API can't be used with Jitted modules")
152240
if kwargs is None:
153241
kwargs = {}
154-
with _reparametrize_module(module, parameters_and_buffers):
242+
with _reparametrize_module(module, parameters_and_buffers, tie_weights):
155243
if isinstance(args, tuple):
156244
out = module(*args, **kwargs)
157245
else:

0 commit comments

Comments
 (0)