Skip to content

Commit a190722

Browse files
author
samdow
committed
[stateless] add weight tying support
ghstack-source-id: 87031f7 Pull Request resolved: #90477
1 parent e9dc8cc commit a190722

File tree

2 files changed

+114
-27
lines changed

2 files changed

+114
-27
lines changed

test/test_stateless.py

Lines changed: 62 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,17 @@ def __init__(self):
2121
def forward(self, x):
2222
return self.l1(x) + self.buffer
2323

24+
class MockTiedModule(torch.nn.Module):
25+
def __init__(self):
26+
super().__init__()
27+
self.l1 = torch.nn.Linear(1, 1)
28+
self.tied_bias = self.l1.bias
29+
self.register_buffer('buffer', torch.ones(1))
30+
self.register_buffer('tied_buffer', self.buffer)
31+
32+
def forward(self, x):
33+
return self.l1(x) + self.tied_bias + self.buffer + self.tied_buffer
34+
2435

2536
class TestStatelessFunctionalAPI(TestCase):
2637
def _run_call_with_mock_module(self, module, device='cpu', prefix=''):
@@ -130,7 +141,7 @@ def test_circular_references(self):
130141
'l1.m.buffer': buffer}
131142
prev_weight = module.l1.weight.clone()
132143
prev_buffer = module.buffer.clone()
133-
res = stateless.functional_call(module, parameters, x)
144+
res = stateless.functional_call(module, parameters, x, tie_weights=False)
134145
self.assertEqual(x, res)
135146
# check that the weights remain unmodified and were correctly accesed
136147
cur_weight = module.l1.weight
@@ -176,10 +187,32 @@ def test_reparamertize_module_fail_reset_to_original(self):
176187
self.assertEqual(orig_sn_weight, module.l1.weight)
177188

178189

179-
def test_tied_weights_warns(self):
180-
module = MockModule()
181-
module.tied_bias = module.l1.bias
182-
module.register_buffer("tied_buffer", module.buffer)
190+
def test_reparamertize_tie_weights(self):
191+
module = MockTiedModule()
192+
weight = torch.tensor([[2.0]],)
193+
bias = torch.tensor([5.0])
194+
buffer = torch.tensor([3.0])
195+
196+
parameters = {'l1.weight': weight,
197+
'l1.bias': bias,
198+
'buffer': buffer}
199+
x = torch.randn(1, 1)
200+
out = stateless.functional_call(module, parameters, x, tie_weights=True)
201+
self.assertEqual(out, x * weight + bias + bias + buffer + buffer)
202+
203+
def test_reparamertize_tie_some_weights(self):
204+
module = MockTiedModule()
205+
weight = torch.tensor([[2.0]],)
206+
buffer = torch.tensor([3.0])
207+
208+
parameters = {'l1.weight': weight,
209+
'buffer': buffer}
210+
x = torch.randn(1, 1)
211+
out = stateless.functional_call(module, parameters, x, tie_weights=True)
212+
self.assertEqual(out, x * 2. + module.l1.bias + module.tied_bias + buffer + buffer)
213+
214+
def test_tied_weights_errors(self):
215+
module = MockTiedModule()
183216
weight = torch.tensor([[1.0]],)
184217
bias = torch.tensor([0.0])
185218
buffer = torch.tensor([0.0])
@@ -188,23 +221,41 @@ def test_tied_weights_warns(self):
188221
'l1.bias': bias,
189222
'buffer': buffer}
190223
x = torch.randn(1, 1)
191-
self.assertNotWarn(lambda: stateless.functional_call(module, parameters, x))
224+
self.assertNotWarn(lambda: stateless.functional_call(module, parameters, x, tie_weights=True))
192225

193226
# if tied values are the same tensors, shouldn't warn
194227
parameters['tied_bias'] = bias
195228
parameters['tied_buffer'] = buffer
196-
self.assertNotWarn(lambda: stateless.functional_call(module, parameters, x))
229+
self.assertNotWarn(lambda: stateless.functional_call(module, parameters, x, tie_weights=True))
197230
del parameters['tied_bias']
198231
del parameters['tied_buffer']
199232

200-
with self.assertWarnsOnceRegex(UserWarning, "functional_call was passed multiple values"):
233+
with self.assertRaisesRegex(ValueError, "functional_call got values for both (l1.bias|tied_bias)"):
201234
parameters['tied_bias'] = torch.tensor([5.0])
202-
stateless.functional_call(module, parameters, x)
235+
stateless.functional_call(module, parameters, x, tie_weights=True)
203236
del parameters['tied_bias']
204237

205-
with self.assertWarnsOnceRegex(UserWarning, "functional_call was passed multiple values"):
238+
with self.assertRaisesRegex(ValueError, "functional_call got values for both (buffer|tied_buffer)"):
206239
parameters['tied_buffer'] = torch.tensor([5.0])
207-
stateless.functional_call(module, parameters, x)
240+
stateless.functional_call(module, parameters, x, tie_weights=True)
241+
242+
243+
def test_tied_weights_no_error_without_kwarg(self):
244+
module = MockTiedModule()
245+
weight = torch.tensor([[1.0]],)
246+
bias = torch.tensor([0.0])
247+
buffer = torch.tensor([0.0])
248+
249+
parameters = {'l1.weight': weight,
250+
'l1.bias': bias,
251+
'buffer': buffer}
252+
x = torch.randn(1, 1)
253+
self.assertNotWarn(lambda: stateless.functional_call(module, parameters, x, tie_weights=False))
254+
parameters['tied_bias'] = torch.tensor([5.0])
255+
self.assertNotWarn(lambda: stateless.functional_call(module, parameters, x, tie_weights=False))
256+
del parameters['tied_bias']
257+
parameters['tied_buffer'] = torch.tensor([5.0])
258+
self.assertNotWarn(lambda: stateless.functional_call(module, parameters, x, tie_weights=False))
208259

209260

210261
def test_setattr(self):

torch/nn/utils/stateless.py

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
import warnings
21
import contextlib
3-
from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
2+
from typing import Any, Callable, Dict, Iterator, List, Tuple, Union, Set, Optional
43

54
import torch
65
from torch import Tensor
@@ -37,21 +36,52 @@ def _setattr(self, name: str, value: Any) -> None:
3736
module._orig_class = cls
3837

3938

40-
def _check_tied_val_already_replaced(old_val, new_val, replaced_tensors_map):
41-
if old_val not in replaced_tensors_map:
42-
replaced_tensors_map[old_val] = new_val
43-
elif replaced_tensors_map[old_val] is not new_val:
44-
warnings.warn("functional_call was passed multiple values for tied weights. "
45-
"This behavior is deprecated and will be an error in future versions")
39+
def _create_tied_weights_map(module, params_and_buffers):
40+
# creates a weight map of {tied_name: name_given_by_user} for all weights where one of their tied weights is passed
41+
#
42+
# The basic algorithm looks like:
43+
# - index all weights by their original tensor value to find tied weights
44+
# - when we encounter a weight not used by the user, we save it in a set (second element in the tuple)
45+
# - when we run into a weight used by the user, we save that separate from the set as the first element in the tuple
46+
# - ending map looks like {tensor: (name_given_by_user, set(all_tied_names)}
47+
# - then loop through the values of this map (name_given_by_user and set(all_tied_names))
48+
# - for each element of all_tied_names, add {tied_name: name_given_by_user} to a new map
49+
50+
names = params_and_buffers.keys()
51+
weight_to_name_and_tied_names: Dict[torch.Tensor, Tuple[Optional[str], Set[str]]] = {}
52+
53+
def add_to_name_map(name, t):
54+
if t in weight_to_name_and_tied_names:
55+
first_seen_name = weight_to_name_and_tied_names[t][0]
56+
if name in names and first_seen_name and params_and_buffers[name] is not params_and_buffers[first_seen_name]:
57+
raise ValueError(f"functional_call got values for both {name} and {first_seen_name}, which are tied.")
58+
elif name in names:
59+
weight_to_name_and_tied_names[t] = (name, weight_to_name_and_tied_names[t][1])
60+
else:
61+
weight_to_name_and_tied_names[t][1].add(name)
62+
else:
63+
weight_to_name_and_tied_names[t] = (name, set()) if name in names else (None, {name})
64+
65+
for name, t in module.named_parameters(remove_duplicate=False):
66+
add_to_name_map(name, t)
67+
68+
for name, t in module.named_buffers(remove_duplicate=False):
69+
add_to_name_map(name, t)
4670

71+
# make {tied_name: name_given_by_user} from pairs of (name_given_by_user, set(all_tied_names))
72+
tied_weights_to_given_name = {}
73+
for name, tied_names in weight_to_name_and_tied_names.values():
74+
if name is None: # no mapping was passed for this tensor, use original tensor
75+
continue
76+
for tied_name in tied_names:
77+
tied_weights_to_given_name[tied_name] = name
78+
return tied_weights_to_given_name
4779

48-
def _create_swap_params(params_and_buffers, replaced_tensors_map):
49-
def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Tensor) -> None:
80+
81+
def _create_swap_params(params_and_buffers):
82+
def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Optional[Tensor]) -> None:
5083
# Changes the module class to get a new __getattr__ dunder method
5184
# that looks for the reparametrized tensor
52-
if hasattr(module, tensor_name):
53-
old_val = getattr(module, tensor_name)
54-
_check_tied_val_already_replaced(old_val, tensor, replaced_tensors_map)
5585
if hasattr(module, "_attr_to_path"):
5686
module._attr_to_path[tensor_name] = full_path
5787
else:
@@ -72,12 +102,17 @@ def _remove_swap(module, name: str, full_path: str) -> None:
72102
def _reparametrize_module(
73103
module: 'torch.nn.Module',
74104
parameters_and_buffers: Dict[str, Tensor],
105+
tie_weights: bool = False,
75106
) -> Iterator[None]:
76-
orig_tensors_to_replacements: Dict[Tensor, Tensor] = {}
107+
tied_weights_map = _create_tied_weights_map(module, parameters_and_buffers) if tie_weights else {}
77108
for name, tensor in parameters_and_buffers.items():
78109
_apply_func_submodules(
79-
_create_swap_params(parameters_and_buffers, orig_tensors_to_replacements),
110+
_create_swap_params(parameters_and_buffers),
80111
module, name.split("."), name, (tensor,))
112+
for tied_name, user_given_name in tied_weights_map.items():
113+
_apply_func_submodules(
114+
_create_swap_params(parameters_and_buffers),
115+
module, tied_name.split("."), user_given_name, (None,))
81116
try:
82117
yield
83118
finally:
@@ -105,6 +140,7 @@ def functional_call(
105140
parameters_and_buffers: Dict[str, Tensor],
106141
args: Union[Any, Tuple],
107142
kwargs: Dict[str, Any] = None,
143+
tie_weights: bool = False,
108144
):
109145
r"""Performs a functional call on the module by replacing the module parameters
110146
and buffers with the provided ones.
@@ -151,7 +187,7 @@ def functional_call(
151187
raise RuntimeError("The stateless API can't be used with Jitted modules")
152188
if kwargs is None:
153189
kwargs = {}
154-
with _reparametrize_module(module, parameters_and_buffers):
190+
with _reparametrize_module(module, parameters_and_buffers, tie_weights):
155191
if isinstance(args, tuple):
156192
out = module(*args, **kwargs)
157193
else:

0 commit comments

Comments
 (0)