Description
🚀 The feature, motivation and pitch
As far as I can tell, there is no hook that runs when a parameter is registered. This may simply be because it was never needed before.
The motivation for wanting this feature now is to wrap every parameter being registered in a function call that can read or modify its state. Without this hook, we would have to wrap each parameter initialization manually, which is harder to read and maintain than a hook that can easily be enabled or disabled. Moreover, it's not even possible to do this for parameters defined in modules like torch.nn.Linear, where we have no control over the implementation.
An example of what this proposal could look like is to have something like the following in module.py:
    ...
    _global_parameter_registration_hooks: Dict[int, Callable] = OrderedDict()

    def register_module_parameter_registration_hook(hook: Callable[..., None]) -> RemovableHandle:
        ...

    ...

    class Module:
        ...
        def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
            ...
            # Iterate over the hook callables (the dict values), not the integer keys.
            for func in _global_parameter_registration_hooks.values():
                param = func(param)
            self._parameters[name] = param
            ...
    ...

Also, while we're at it, we could do the same for register_buffer and add_module.
Would there be any reason this hook could not or should not be added?
Alternatives
Note that the wrapping function we want to hook into parameter registration cannot simply be called after all the parameters are registered, by retrieving them via model.parameters(). It needs to be called on a per-parameter basis, before the next parameter is registered.
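The ordering constraint can be illustrated with a small, torch-free sketch. `TinyModule` and `wrap` are hypothetical names used only for illustration; the event log shows that a registration-time hook fires before the next parameter is registered, whereas a post-hoc pass over the finished model runs only after every registration has completed.

```python
from typing import Callable, Dict, List

events: List[str] = []  # records the order in which things happen


class TinyModule:
    """Hypothetical toy container; not torch.nn.Module."""

    def __init__(self, hook: Callable[[str, float], float]) -> None:
        self._parameters: Dict[str, float] = {}
        self._hook = hook

    def register_parameter(self, name: str, value: float) -> None:
        events.append(f"register {name}")
        # The hook fires here, before any later parameter exists.
        self._parameters[name] = self._hook(name, value)


def wrap(name: str, value: float) -> float:
    """Hypothetical wrapping function; here it only logs the call."""
    events.append(f"hook {name}")
    return value


model = TinyModule(wrap)
model.register_parameter("weight", 1.0)
model.register_parameter("bias", 0.0)

# A post-hoc pass can only start once both registrations are done.
for name in model._parameters:
    events.append(f"post-hoc {name}")
```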
Additional context
No response
cc @albanD @mruberry @jbschlosser @walterddr @kshitij12345 @saketh-are