[stateless] add weight tying support #90477
Changes from all commits: 2d09ab5, de5a347, f273977, 9053d12, 5358159, b0500be, d07d26f
```diff
@@ -1,6 +1,5 @@
-import warnings
 import contextlib
-from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
+from typing import Any, Callable, Dict, Iterator, List, Tuple, Union, Set, Optional

 import torch
 from torch import Tensor
```
```diff
@@ -37,21 +36,84 @@ def _setattr(self, name: str, value: Any) -> None:
     module._orig_class = cls


-def _check_tied_val_already_replaced(old_val, new_val, replaced_tensors_map):
-    if old_val not in replaced_tensors_map:
-        replaced_tensors_map[old_val] = new_val
-    elif replaced_tensors_map[old_val] is not new_val:
-        warnings.warn("functional_call was passed multiple values for tied weights. "
-                      "This behavior is deprecated and will be an error in future versions")
+def _create_tied_weights_map(module: 'torch.nn.Module', params_and_buffers: Dict[str, Tensor]) -> Dict[str, str]:
+    """
+    _create_tied_weights_map(module: Module, params_and_buffers: Dict[str, Tensor]) -> Dict[str, str]
+
+    Creates a weight map of {tied_name: name_given_by_user} for all weights where one of their tied names is
+    passed by the user.
+
+    ex: Foo() has self.foo and self.tied_foo, which are tied. If a user passed {'foo': ...} as the reparametrization,
+    this would return {'tied_foo': 'foo'}. Similarly, if a user passed {'tied_foo': ...}, this returns
+    {'foo': 'tied_foo'}.
+
+    ex: If there aren't any tied weights and the user passed values for every parameter and buffer, this returns an
+    empty map: {}
+
+    ex: The map only contains entries for tied weights that the user is reparametrizing. For example, if
+    module = nn.Linear(...) and the user only passed a new value for 'bias', this also returns: {}
+
+    This is useful because we will start by reparametrizing all the keys of params_and_buffers, then all the keys
+    from this returned dictionary.
+    """
+
+    # The basic algorithm looks like:
+    # - index all weights by their original tensor value to find tied weights
+    # - when we encounter a weight not used by the user, we save it in a set (second element in the tuple)
+    # - when we run into a weight used by the user, we save that separately from the set, as the first element in
+    #   the tuple
+    # - the ending map looks like {tensor: (name_given_by_user, set(all_tied_names))}
+    # - then loop through the values of this map (name_given_by_user and set(all_tied_names))
+    # - for each element of all_tied_names, add {tied_name: name_given_by_user} to a new map
+
+    names = params_and_buffers.keys()
+    weight_to_name_and_tied_names: Dict[torch.Tensor, Tuple[Optional[str], Set[str]]] = {}
+
+    # Create a map keyed by tensor value so that tied weights get mapped to the same key. The value is the
+    # interesting part: at the end it's (used_name, set(tied_names)).
+    # For example, in the first example where there are tied weights self.foo and self.tied_foo and the user passes
+    # a value for self.foo, this builds {torch.Tensor(...): ('foo', {'tied_foo'})}
+    def add_to_name_map(n: str, t: torch.Tensor):
+        # if the tensor hasn't been seen before, add it to the map
+        if t not in weight_to_name_and_tied_names:
+            weight_to_name_and_tied_names[t] = (n, set()) if n in names else (None, {n})
```
> **Collaborator:** *(comment truncated; asks about the complexity of the `n in names` check)*
>
> **Author:** O(1) according to the internet: https://stackoverflow.com/questions/17539367/python-dictionary-keys-in-complexity
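A quick way to sanity-check that claim (a standalone demonstration, not part of the PR): `dict.keys()` returns a view backed by the dict's hash table, so membership testing is a hash lookup whose cost does not grow with the size of the dict.

```python
import timeit

d = {f"k{i}": i for i in range(1_000_000)}
names = d.keys()  # a view over the dict's hash table

# O(1) hash lookup: the timing stays flat regardless of dict size
print(timeit.timeit(lambda: "k999999" in names, number=100_000))

# contrast: building and scanning a list is O(n) per check
print(timeit.timeit(lambda: "k999999" in list(d), number=100))
```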
```diff
+            return
+
+        # if the name is not used by the user, we add it to the tied set
+        if n not in names:
+            weight_to_name_and_tied_names[t][1].add(n)
+            return
+
+        # check that the user didn't pass two different tensors for the same tied weight
+        first_seen_name = weight_to_name_and_tied_names[t][0]
+
+        # if they didn't pass multiple names for tied weights or used the same tensor, we set the used name
+        if first_seen_name is None or params_and_buffers[n] is params_and_buffers[first_seen_name]:
+            weight_to_name_and_tied_names[t] = (n, weight_to_name_and_tied_names[t][1])
+            return
+
+        raise ValueError(f"functional_call got values for both {n} and {first_seen_name}, which are tied. " +
+                         "Consider using tie_weights=False")
+
+    tensor: Tensor
+    for name, tensor in module.named_parameters(remove_duplicate=False):
+        add_to_name_map(name, tensor)
+
+    for name, tensor in module.named_buffers(remove_duplicate=False):
+        add_to_name_map(name, tensor)
+
+    # make {tied_name: name_given_by_user} from pairs of (name_given_by_user, set(all_tied_names))
+    tied_weights_to_given_name = {}
+    for name_given_by_user, tied_names in weight_to_name_and_tied_names.values():
+        if name_given_by_user is None:  # no mapping was passed for this tensor, use the original tensor
+            continue
+        for tied_name in tied_names:
+            tied_weights_to_given_name[tied_name] = name_given_by_user
+    return tied_weights_to_given_name
+
+
-def _create_swap_params(params_and_buffers, replaced_tensors_map):
-    def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Tensor) -> None:
+def _create_swap_params(params_and_buffers):
+    def _swap_parameters(module, tensor_name: str, full_path: str, tensor: Optional[Tensor]) -> None:
```
> **Contributor:** What is `tensor` being used for?
>
> **Author:** It's unused, but previously it was the original weight. Since it was harder to get for tied weights, I passed None and had to update the signature to match that. I can add a patch on top of this to remove that parameter since it appears unused.
>
> **Contributor:** SGTM
```diff
         # Changes the module class to get a new __getattr__ dunder method
         # that looks for the reparametrized tensor
-        if hasattr(module, tensor_name):
-            old_val = getattr(module, tensor_name)
-            _check_tied_val_already_replaced(old_val, tensor, replaced_tensors_map)
         if hasattr(module, "_attr_to_path"):
             module._attr_to_path[tensor_name] = full_path
         else:
```
|
```diff
@@ -72,12 +134,17 @@ def _remove_swap(module, name: str, full_path: str) -> None:
 def _reparametrize_module(
     module: 'torch.nn.Module',
     parameters_and_buffers: Dict[str, Tensor],
+    tie_weights: bool = False,
 ) -> Iterator[None]:
-    orig_tensors_to_replacements: Dict[Tensor, Tensor] = {}
+    tied_weights_map = _create_tied_weights_map(module, parameters_and_buffers) if tie_weights else {}
     for name, tensor in parameters_and_buffers.items():
         _apply_func_submodules(
-            _create_swap_params(parameters_and_buffers, orig_tensors_to_replacements),
+            _create_swap_params(parameters_and_buffers),
             module, name.split("."), name, (tensor,))
+    for tied_name, user_given_name in tied_weights_map.items():
+        _apply_func_submodules(
+            _create_swap_params(parameters_and_buffers),
+            module, tied_name.split("."), user_given_name, (None,))
     try:
         yield
     finally:
```
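To see the two swap loops at work, here is a sketch that drives the context manager directly (illustration only, assuming a build with this PR applied; `_reparametrize_module` is a private helper, and `TiedFoo` is the hypothetical tied module from the sketch above).

```python
import torch
from torch.nn.utils.stateless import _reparametrize_module  # private helper

mod = TiedFoo()  # hypothetical tied module defined in the earlier sketch
with _reparametrize_module(mod, {'foo': torch.zeros(())}, True):
    # the first loop swaps 'foo'; the second points 'tied_foo' at the same value
    print(mod.foo, mod.tied_foo)  # tensor(0.) tensor(0.)
print(mod.foo, mod.tied_foo)      # originals restored: tensor(1.) tensor(1.)
```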
```diff
@@ -105,6 +172,8 @@ def functional_call(
     parameters_and_buffers: Dict[str, Tensor],
     args: Union[Any, Tuple],
     kwargs: Dict[str, Any] = None,
+    *,
+    tie_weights: bool = True,
```
> **Contributor:** This change is technically BC-breaking (does it matter?). What's our deprecation plan for nn.utils.stateless.functional_call?
>
> **Author:** Yep, this is BC-breaking on a beta feature. This is from earlier talks with @albanD where we figured the default behavior should be what most people expect (which, from the make_functional requests, seems to be to have the tied weights get changed together). Since we are moving it to …
>
> **Collaborator:** Both can work really.
```diff
 ):
     r"""Performs a functional call on the module by replacing the module parameters
     and buffers with the provided ones.
```
```diff
@@ -128,12 +197,31 @@ def functional_call(
     >>> print(mod.foo)  # tensor(0.)
     >>> print(a['foo'])  # tensor(1.)

+    .. note:: If the module has tied weights, whether or not functional_call respects the tying is determined by the
+        tie_weights flag.
+
+        Example::
+
+            >>> a = {'foo': torch.zeros(())}
+            >>> # xdoctest: +SKIP
+            >>> mod = Foo()  # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied
+            >>> print(mod.foo)  # tensor(1.)
+            >>> mod(torch.zeros(()))  # tensor(2.)
+            >>> functional_call(mod, a, torch.zeros(()))  # tensor(0.) since it will change self.foo_tied too
+            >>> functional_call(mod, a, torch.zeros(()), tie_weights=False)  # tensor(1.)--self.foo_tied is not updated
+            >>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())}
+            >>> functional_call(mod, new_a, torch.zeros(()))  # tensor(0.)
+
     Args:
         module (torch.nn.Module): the module to call
         parameters_and_buffers (dict of str and Tensor): the parameters that will be used in
             the module call.
         args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument.
         kwargs (dict): keyword arguments to be passed to the module call
+        tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be treated
+            as tied in the reparametrized version. Therefore, if True and different values are passed for the tied
+            parameters and buffers, it will error. If False, it will not respect the originally tied parameters and
+            buffers unless the values passed for both weights are the same. Default: True.

     Returns:
         Any: the result of calling ``module``.
```
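A runnable version of the docstring example (assuming a build with this PR applied; `TiedFoo` is the hypothetical tied module from the earlier sketch, with `foo` initialized to 1).

```python
import torch
from torch.nn.utils.stateless import functional_call

mod = TiedFoo()  # hypothetical tied module defined in the earlier sketch
a = {'foo': torch.zeros(())}

print(functional_call(mod, a, torch.zeros(())))                     # tensor(0.): 'tied_foo' follows 'foo'
print(functional_call(mod, a, torch.zeros(()), tie_weights=False))  # tensor(1.): 'tied_foo' keeps its value

# with tie_weights=True (the default), conflicting values for tied names raise
try:
    functional_call(mod, {'foo': torch.zeros(()), 'tied_foo': torch.ones(())}, torch.zeros(()))
except ValueError as err:
    print(err)  # got values for both tied_foo and foo ... Consider using tie_weights=False
```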
```diff
@@ -151,7 +239,7 @@ def functional_call(
         raise RuntimeError("The stateless API can't be used with Jitted modules")
     if kwargs is None:
         kwargs = {}
-    with _reparametrize_module(module, parameters_and_buffers):
+    with _reparametrize_module(module, parameters_and_buffers, tie_weights):
         if isinstance(args, tuple):
             out = module(*args, **kwargs)
         else:
```