Skip to content

Commit d940db0

Browse files
committed
Update on "Allow Module forward-pre and forward hooks to take kwargs"
closes #35643 This PR is mostly borrowed from #82042. Thanks Padarn for implementing the first version and debugging into the errors. Based on the discussion in #82042 this PR adds a with_kwargs argument to register_forward_pre_hook and register_forward_hook methods. When the arg is set to true, the provided hook must accept kwargs args. Under the hook, this PR adds a `_forward_pre_hooks_with_kwargs` and a `_forward_hook_with_kwargs` set to keep track of which hooks accept kwargs. Differential Revision: [D41431111](https://our.internmc.facebook.com/intern/diff/D41431111) cc albanD mruberry jbschlosser walterddr kshitij12345 saketh-are [ghstack-poisoned]
2 parents 8df224b + 87ef676 commit d940db0

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

torch/nn/modules/module.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1495,11 +1495,13 @@ def _call_impl(self, *args, **kwargs):
14951495
if hook_id in self._forward_pre_hooks_with_kwargs:
14961496
result = hook(self, args, kwargs) # type: ignore[misc]
14971497
if result is not None:
1498-
assert isinstance(result, tuple) and len(result) == 2, (
1499-
"forward pre-hook must return None or a tuple of "
1500-
f"(new_args, new_kwargs), but got {result}."
1501-
)
1502-
args, kwargs = result
1498+
if isinstance(result, tuple) and len(result) == 2:
1499+
args, kwargs = result
1500+
else:
1501+
raise RuntimeError(
1502+
"forward pre-hook must return None or a tuple "
1503+
f"of (new_args, new_kwargs), but got {result}."
1504+
)
15031505
else:
15041506
result = hook(self, args)
15051507
if result is not None:

torch/utils/hooks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def remove(self) -> None:
3737

3838
if self.extra_dict_ref is not None:
3939
extra_dict = self.extra_dict_ref()
40-
if extra_dict is not None:
40+
if extra_dict is not None and self.id in extra_dict:
4141
del extra_dict[self.id]
4242

4343
def __getstate__(self):

0 commit comments

Comments
 (0)