83 changes: 76 additions & 7 deletions test/test_stateless.py
@@ -22,6 +22,17 @@ def __init__(self):
def forward(self, x):
return self.l1(x) + self.buffer

class MockTiedModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.l1 = torch.nn.Linear(1, 1)
self.tied_bias = self.l1.bias
self.register_buffer('buffer', torch.ones(1))
self.register_buffer('tied_buffer', self.buffer)

def forward(self, x):
return self.l1(x) + self.tied_bias + self.buffer + self.tied_buffer


class TestStatelessFunctionalAPI(TestCase):
def _run_call_with_mock_module(self, module, functional_call, device='cpu', prefix=''):
@@ -156,7 +167,7 @@ def test_circular_references(self, functional_call):
'l1.m.buffer': buffer}
prev_weight = module.l1.weight.clone()
prev_buffer = module.buffer.clone()
res = functional_call(module, parameters, x)
res = functional_call(module, parameters, x, tie_weights=False)
self.assertEqual(x, res)
# check that the weights remain unmodified and were correctly accessed
cur_weight = module.l1.weight
@@ -217,6 +228,46 @@ def test_tied_weights_warns(self, functional_call):
module = MockModule()
module.tied_bias = module.l1.bias
module.register_buffer("tied_buffer", module.buffer)

@parametrize("functional_call", [
subtest(torch.func.functional_call, "torch_func"),
subtest(stateless.functional_call, "stateless")
])
def test_reparametrize_tie_weights(self, functional_call):
module = MockTiedModule()
weight = torch.tensor([[2.0]],)
bias = torch.tensor([5.0])
buffer = torch.tensor([3.0])

parameters = {'l1.weight': weight,
'l1.bias': bias,
'buffer': buffer}
x = torch.randn(1, 1)
out = functional_call(module, parameters, x, tie_weights=True)
self.assertEqual(out, x * weight + bias + bias + buffer + buffer)


@parametrize("functional_call", [
subtest(torch.func.functional_call, "torch_func"),
subtest(stateless.functional_call, "stateless")
])
def test_reparametrize_tie_some_weights(self, functional_call):
module = MockTiedModule()
weight = torch.tensor([[2.0]],)
buffer = torch.tensor([3.0])

parameters = {'l1.weight': weight,
'buffer': buffer}
x = torch.randn(1, 1)
out = functional_call(module, parameters, x, tie_weights=True)
self.assertEqual(out, x * 2. + module.l1.bias + module.tied_bias + buffer + buffer)

@parametrize("functional_call", [
subtest(torch.func.functional_call, "torch_func"),
subtest(stateless.functional_call, "stateless")
])
def test_tied_weights_errors(self, functional_call):
module = MockTiedModule()
weight = torch.tensor([[1.0]],)
bias = torch.tensor([0.0])
buffer = torch.tensor([0.0])
@@ -225,23 +276,41 @@ def test_tied_weights_warns(self, functional_call):
'l1.bias': bias,
'buffer': buffer}
x = torch.randn(1, 1)
self.assertNotWarn(lambda: functional_call(module, parameters, x))
self.assertNotWarn(lambda: functional_call(module, parameters, x, tie_weights=True))

# if tied values are the same tensors, shouldn't warn
parameters['tied_bias'] = bias
parameters['tied_buffer'] = buffer
self.assertNotWarn(lambda: functional_call(module, parameters, x))
self.assertNotWarn(lambda: functional_call(module, parameters, x, tie_weights=True))
del parameters['tied_bias']
del parameters['tied_buffer']

with self.assertWarnsOnceRegex(UserWarning, "functional_call was passed multiple values"):
with self.assertRaisesRegex(ValueError, "functional_call got values for both (l1.bias|tied_bias)"):
parameters['tied_bias'] = torch.tensor([5.0])
functional_call(module, parameters, x)
functional_call(module, parameters, x, tie_weights=True)
del parameters['tied_bias']

with self.assertWarnsOnceRegex(UserWarning, "functional_call was passed multiple values"):
with self.assertRaisesRegex(ValueError, "functional_call got values for both (buffer|tied_buffer)"):
parameters['tied_buffer'] = torch.tensor([5.0])
functional_call(module, parameters, x)
functional_call(module, parameters, x, tie_weights=True)


def test_tied_weights_no_error_without_flag(self):
module = MockTiedModule()
weight = torch.tensor([[1.0]],)
bias = torch.tensor([0.0])
buffer = torch.tensor([0.0])

parameters = {'l1.weight': weight,
'l1.bias': bias,
'buffer': buffer}
x = torch.randn(1, 1)
self.assertNotWarn(lambda: stateless.functional_call(module, parameters, x, tie_weights=False))
parameters['tied_bias'] = torch.tensor([5.0])
self.assertNotWarn(lambda: stateless.functional_call(module, parameters, x, tie_weights=False))
del parameters['tied_bias']
parameters['tied_buffer'] = torch.tensor([5.0])
self.assertNotWarn(lambda: stateless.functional_call(module, parameters, x, tie_weights=False))

@parametrize("functional_call", [
subtest(torch.func.functional_call, "torch_func"),
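A minimal usage sketch (not part of the diff) of how the new flag behaves, assuming the semantics exercised by the tests above; TiedLinear is a hypothetical stand-in for MockTiedModule:

import torch
from torch.func import functional_call

class TiedLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(1, 1)
        self.tied_bias = self.l1.bias  # alias of l1.bias, i.e. a tied parameter

    def forward(self, x):
        return self.l1(x) + self.tied_bias

mod = TiedLinear()
params = {'l1.weight': torch.ones(1, 1), 'l1.bias': torch.zeros(1)}
x = torch.ones(1, 1)

# tie_weights=True (the default): tied_bias follows the value supplied for l1.bias
out_tied = functional_call(mod, params, x, tie_weights=True)     # 1*1 + 0 + 0 = 1
# tie_weights=False: tied_bias keeps the module's original bias value
out_untied = functional_call(mod, params, x, tie_weights=False)  # 1*1 + 0 + mod.l1.bias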
23 changes: 22 additions & 1 deletion torch/_functorch/functional_call.py
@@ -12,6 +12,8 @@ def functional_call(
parameter_and_buffer_dicts: Union[Dict[str, Tensor], Tuple[Dict[str, Tensor], ...]],
args: Union[Any, Tuple],
kwargs: Dict[str, Any] = None,
*,
tie_weights: bool = True,
):
r"""Performs a functional call on the module by replacing the module parameters
and buffers with the provided ones.
@@ -36,6 +38,21 @@
>>> print(mod.foo) # tensor(0.)
>>> print(a['foo']) # tensor(1.)

.. note:: If the module has tied weights, whether or not functional_call respects the tying is determined by the
tie_weights flag.

Example::

>>> a = {'foo': torch.zeros(())}
>>> # xdoctest: +SKIP
>>> mod = Foo() # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied
>>> print(mod.foo) # tensor(1.)
>>> mod(torch.zeros(())) # tensor(2.)
>>> functional_call(mod, a, torch.zeros(())) # tensor(0.) since it will change self.foo_tied too
>>> functional_call(mod, a, torch.zeros(()), tie_weights=False) # tensor(1.) since self.foo_tied is not updated
>>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())}
>>> functional_call(mod, new_a, torch.zeros(())) # tensor(0.)

An example of passing multiple dictionaries

.. code-block:: python
@@ -88,6 +105,10 @@ def compute_loss(params, x, t):
be used together
args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument.
kwargs (dict): keyword arguments to be passed to the module call
tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be treated as
tied in the reparametrized version. Therefore, if True and different values are passed for the tied
parameters and buffers, it will error. If False, it will not respect the originally tied parameters and
buffers unless the values passed for both weights are the same. Default: True.

Returns:
Any: the result of calling ``module``.
@@ -102,7 +123,7 @@ def compute_loss(params, x, t):

parameters_and_buffers = {k: v for d in parameter_and_buffer_dicts for k, v in d.items()}

return nn.utils.stateless.functional_call(module, parameters_and_buffers, args, kwargs)
return nn.utils.stateless.functional_call(module, parameters_and_buffers, args, kwargs, tie_weights=tie_weights)


@exposed_in("torch.func")
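For reference, a sketch (not from the diff) of the error path this forwarding enables, assuming a module like MockTiedModule from the tests where tied_bias aliases l1.bias; with tie_weights=True, conflicting values for tied names are rejected before the module runs:

import torch
from torch.func import functional_call

class Tied(torch.nn.Module):  # hypothetical module with a tied parameter
    def __init__(self):
        super().__init__()
        self.l1 = torch.nn.Linear(1, 1)
        self.tied_bias = self.l1.bias

    def forward(self, x):
        return self.l1(x) + self.tied_bias

mod = Tied()
params = {'l1.bias': torch.zeros(1), 'tied_bias': torch.ones(1)}  # different values for tied weights
try:
    functional_call(mod, params, torch.randn(1, 1))  # tie_weights defaults to True
except ValueError as e:
    print(e)  # functional_call got values for both ... , which are tied

# Opting out with tie_weights=False keeps the old untied behavior and accepts both values:
out = functional_call(mod, params, torch.randn(1, 1), tie_weights=False)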
120 changes: 104 additions & 16 deletions torch/nn/utils/stateless.py
@@ -1,6 +1,5 @@
import warnings
import contextlib
from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
from typing import Any, Callable, Dict, Iterator, List, Tuple, Union, Set, Optional

import torch
from torch import Tensor
@@ -37,21 +36,84 @@ def _setattr(self, name: str, value: Any) -> None:
module._orig_class = cls


def _check_tied_val_already_replaced(old_val, new_val, replaced_tensors_map):
if old_val not in replaced_tensors_map:
replaced_tensors_map[old_val] = new_val
elif replaced_tensors_map[old_val] is not new_val:
warnings.warn("functional_call was passed multiple values for tied weights. "
"This behavior is deprecated and will be an error in future versions")
def _create_tied_weights_map(module: 'torch.nn.Module', params_and_buffers: Dict[str, Tensor]) -> Dict[str, str]:
"""
_create_tied_weights_map(module: Module, params_and_buffers: Dict[str, Tensor]) -> Dict[str, str]

Creates a map of {tied_name: name_given_by_user} for every weight whose tied counterpart was passed by the user.

ex: Foo() has self.foo and self.tied_foo, which are tied. If a user passes {'foo': ...} as the reparametrization,
this returns {'tied_foo': 'foo'}. Similarly, if a user passes {'tied_foo': ...}, this returns
{'foo': 'tied_foo'}.

ex: If the module has no tied weights (e.g. module = nn.Linear(...)), this returns an empty map {}, regardless
of which parameters and buffers the user passed, since there are no tied names to remap.

def _create_swap_params(params_and_buffers, replaced_tensors_map):
def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Tensor) -> None:
This is useful because we will start by reparametrizing all the keys of params_and_buffers, then all the keys from
this returned dictionary.
"""

# The basic algorithm looks like:
# - index all weights by their original tensor value to find tied weights
# - when we encounter a weight not used by the user, we save it in a set (second element in the tuple)
# - when we run into a weight used by the user, we save that separate from the set as the first element in the tuple
# - ending map looks like {tensor: (name_given_by_user, set(all_tied_names))}
# - then loop through the values of this map (name_given_by_user and set(all_tied_names))
# - for each element of all_tied_names, add {tied_name: name_given_by_user} to a new map

names = params_and_buffers.keys()
weight_to_name_and_tied_names: Dict[torch.Tensor, Tuple[Optional[str], Set[str]]] = {}

# create a map keyed by tensor value so that tied weights get mapped to the same key. The value is the interesting
# part: at the end it's (used_name, set(tied_names)).
# For example, with tied weights self.foo and self.tied_foo where the user passes a value for self.foo, this map
# ends up as {torch.Tensor(...): ('foo', {'tied_foo'})}
def add_to_name_map(n: str, t: torch.Tensor):
# if the tensor hasn't been seen before, add it to the map
if t not in weight_to_name_and_tied_names:
weight_to_name_and_tied_names[t] = (n, set()) if n in names else (None, {n})
Review comment (Collaborator): for n in names, I am not sure what exact object the dict_keys object is. What is the complexity of this lookup?
return

# if the name is not used by the user, we add it to the tied set
if n not in names:
weight_to_name_and_tied_names[t][1].add(n)
return

# check that the user didn't pass two different tensors for the same tied weight
first_seen_name = weight_to_name_and_tied_names[t][0]

# if they didn't pass multiple names for tied weights or used the same tensor, we set the used name
if first_seen_name is None or params_and_buffers[n] is params_and_buffers[first_seen_name]:
weight_to_name_and_tied_names[t] = (n, weight_to_name_and_tied_names[t][1])
return

raise ValueError(f"functional_call got values for both {n} and {first_seen_name}, which are tied. " +
"Consider using tie_weights=False")

tensor: Tensor
for name, tensor in module.named_parameters(remove_duplicate=False):
add_to_name_map(name, tensor)

for name, tensor in module.named_buffers(remove_duplicate=False):
add_to_name_map(name, tensor)

# make {tied_name: name_given_by_user} from pairs of (name_given_by_user, set(all_tied_names))
tied_weights_to_given_name = {}
for name_given_by_user, tied_names in weight_to_name_and_tied_names.values():
if name_given_by_user is None: # no mapping was passed for this tensor, use original tensor
continue
for tied_name in tied_names:
tied_weights_to_given_name[tied_name] = name_given_by_user
return tied_weights_to_given_name


def _create_swap_params(params_and_buffers):
def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Optional[Tensor]) -> None:
Review comment (Contributor): What is tensor being used for?
Reply (samdow, Jan 9, 2023): It's unused, but previously it was the original weight. Since it was harder to get for tied weights, I passed None and had to update the signature to match. I can add a patch on top of this to remove that parameter since it appears unused.
Reply (Contributor): SGTM
# Changes the module class to get a new __getattr__ dunder method
# that looks for the reparametrized tensor
if hasattr(module, tensor_name):
old_val = getattr(module, tensor_name)
_check_tied_val_already_replaced(old_val, tensor, replaced_tensors_map)
if hasattr(module, "_attr_to_path"):
module._attr_to_path[tensor_name] = full_path
else:
@@ -72,12 +134,17 @@ def _remove_swap(module, name: str, full_path: str) -> None:
def _reparametrize_module(
module: 'torch.nn.Module',
parameters_and_buffers: Dict[str, Tensor],
tie_weights: bool = False,
) -> Iterator[None]:
orig_tensors_to_replacements: Dict[Tensor, Tensor] = {}
tied_weights_map = _create_tied_weights_map(module, parameters_and_buffers) if tie_weights else {}
for name, tensor in parameters_and_buffers.items():
_apply_func_submodules(
_create_swap_params(parameters_and_buffers, orig_tensors_to_replacements),
_create_swap_params(parameters_and_buffers),
module, name.split("."), name, (tensor,))
for tied_name, user_given_name in tied_weights_map.items():
_apply_func_submodules(
_create_swap_params(parameters_and_buffers),
module, tied_name.split("."), user_given_name, (None,))
try:
yield
finally:
@@ -105,6 +172,8 @@ def functional_call(
parameters_and_buffers: Dict[str, Tensor],
args: Union[Any, Tuple],
kwargs: Dict[str, Any] = None,
*,
tie_weights: bool = True,
Review comment (Contributor): This change is technically BC-breaking (does it matter?). What's our deprecation plan for nn.utils.stateless.functional_call? Easiest thing to do seems like:
  • nn.utils.stateless.functional_call should retain the same behavior as before (tie_weights=False?)
  • we deprecate nn.utils.stateless.functional_call in the next version of PyTorch (so, now on master)
  • we introduce a new torch.func.functional_call (probably needs a better name) to replace it in the next version of PyTorch that has our preferred default (tie_weights=True)
Reply (Contributor Author): Yep, this is BC-breaking on a beta feature. This is from earlier talks with @albanD where we figured the default behavior should be what most people expect (which, from the make_functional requests, seems to be that tied weights get changed together). Since we are moving it to torch.func anyway, I'm fine with changing the default back to match the old behavior for now and breaking it when we do the move, unless @albanD has other thoughts?
Reply (Collaborator): Both can work really. But it might be simpler to have the torch.func version be an alias to the old version. And yes, from earlier discussions, I think this BC-breaking is ok.
):
r"""Performs a functional call on the module by replacing the module parameters
and buffers with the provided ones.
@@ -128,12 +197,31 @@ def functional_call(
>>> print(mod.foo) # tensor(0.)
>>> print(a['foo']) # tensor(1.)

.. note:: If the module has tied weights, whether or not functional_call respects the tying is determined by the
tie_weights flag.

Example::

>>> a = {'foo': torch.zeros(())}
>>> # xdoctest: +SKIP
>>> mod = Foo() # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied
>>> print(mod.foo) # tensor(1.)
>>> mod(torch.zeros(())) # tensor(2.)
>>> functional_call(mod, a, torch.zeros(())) # tensor(0.) since it will change self.foo_tied too
>>> functional_call(mod, a, torch.zeros(()), tie_weights=False) # tensor(1.) since self.foo_tied is not updated
>>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())}
>>> functional_call(mod, new_a, torch.zeros(())) # tensor(0.)

Args:
module (torch.nn.Module): the module to call
parameters_and_buffers (dict of str and Tensor): the parameters that will be used in
the module call.
args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument.
kwargs (dict): keyword arguments to be passed to the module call
tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be treated as
tied in the reparametrized version. Therefore, if True and different values are passed for the tied
parameters and buffers, it will error. If False, it will not respect the originally tied parameters and
buffers unless the values passed for both weights are the same. Default: True.

Returns:
Any: the result of calling ``module``.
@@ -151,7 +239,7 @@ def functional_call(
raise RuntimeError("The stateless API can't be used with Jitted modules")
if kwargs is None:
kwargs = {}
with _reparametrize_module(module, parameters_and_buffers):
with _reparametrize_module(module, parameters_and_buffers, tie_weights):
if isinstance(args, tuple):
out = module(*args, **kwargs)
else:
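A side note on the review question above about the dict_keys lookup (an assumption made outside the diff, not part of the change): params_and_buffers.keys() returns a dict view backed by the dict's hash table, so n in names is an average-case O(1) membership test rather than a linear scan.

# Membership on a dict view is constant time on average, the same as on the dict itself.
params_and_buffers = {'l1.weight': 0, 'l1.bias': 1, 'buffer': 2}
names = params_and_buffers.keys()  # dict_keys view, no copy
print('l1.bias' in names)          # True
print('tied_bias' in names)        # False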