Commit f5d1857

mrshenli authored and pytorchmergebot committed
Allow Module forward-pre and forward hooks to take kwargs (#89389)
closes #35643

This PR is mostly borrowed from #82042. Thanks @Padarn for implementing the first version and debugging the errors.

Based on the discussion in #82042, this PR adds a `with_kwargs` argument to the `register_forward_pre_hook` and `register_forward_hook` methods. When the argument is set to `True`, the registered hook must accept keyword arguments. Under the hood, this PR adds a `_forward_pre_hooks_with_kwargs` and a `_forward_hooks_with_kwargs` set to keep track of which hooks accept kwargs.

Differential Revision: [D41431111](https://our.internmc.facebook.com/intern/diff/D41431111)

Pull Request resolved: #89389
Approved by: https://github.com/soulitzer
1 parent 4935b59 · commit f5d1857
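For orientation before the diff: a minimal sketch of how the new argument is used, based on the commit message and the tests below. The `ToyScale` module and both hook functions here are illustrative stand-ins, not code from this PR:

```python
import torch
import torch.nn as nn


class ToyScale(nn.Module):
    # Illustrative module whose forward takes a keyword argument.
    def forward(self, x: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        return x * scale


def pre_hook(module, args, kwargs):
    # With with_kwargs=True, a forward-pre hook receives (module, args, kwargs)
    # and may return an (args, kwargs) tuple to rewrite the call.
    kwargs["scale"] = kwargs.get("scale", 1.0) * 2
    return args, kwargs


def post_hook(module, args, kwargs, output):
    # With with_kwargs=True, a forward hook receives (module, args, kwargs, output)
    # and may return a replacement output.
    return output + kwargs.get("scale", 1.0)


model = ToyScale()
model.register_forward_pre_hook(pre_hook, with_kwargs=True)
model.register_forward_hook(post_hook, with_kwargs=True)

x = torch.ones(2)
out = model(x, scale=3.0)  # pre-hook doubles scale to 6.0; post-hook adds it
assert torch.allclose(out, x * 6.0 + 6.0)
```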

File tree

5 files changed (+335, -47 lines)

test/nn/test_module_hooks.py

Lines changed: 185 additions & 10 deletions
@@ -9,7 +9,7 @@
 import torch.nn as nn
 
 from functools import partial
-from typing import List, Tuple
+from typing import Any, Dict, List, Tuple
 
 
 class Net(nn.Module):
@@ -87,8 +87,64 @@ def full_backward_pre_hook(
     self.assertEqual(len(grad_input), 1)
 
 
-class TestModuleHooks(TestCase):
+class KwargModel(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.net1 = Net()
+        self.net2 = Net()
+
+    def forward(
+        self, x: torch.Tensor, bias: torch.Tensor = None
+    ) -> torch.Tensor:
+        if bias is not None:
+            x = x + bias
+        return x
+
+    def internal_forward_hook(
+        self,
+        module: nn.Module,
+        args: Tuple[torch.Tensor],
+        kwargs: Dict[str, Any],
+        out: torch.Tensor,
+    ):
+        return out + kwargs["bias"]
+
+
+def kwarg_forward_pre_hook(
+    self: TestCase,
+    fired_hooks: List[int],
+    expected_module: nn.Module,
+    hook_id: int,
+    module: nn.Module,
+    args: Tuple[torch.Tensor],
+    kwargs: Dict[str, Any],
+) -> Tuple[Any, Any]:
+    fired_hooks.append(hook_id)
+    self.assertEqual(id(module), id(expected_module))
+    self.assertEqual(len(args), 1)
+    kwargs["bias"] = 2 * kwargs["bias"]
+    return args, kwargs
+
+
+def kwarg_forward_hook(
+    self: TestCase,
+    fired_hooks: List[int],
+    expected_module: nn.Module,
+    hook_id: int,
+    module: nn.Module,
+    args: Tuple[torch.Tensor],
+    kwargs: Dict[str, Any],
+    out: torch.Tensor,
+) -> Any:
+    fired_hooks.append(hook_id)
+    self.assertEqual(id(module), id(expected_module))
+    self.assertEqual(len(args), 1)
+
+    out = out + kwargs["bias"]
+    return out
 
+
+class TestModuleHooks(TestCase):
     @skipIfTorchDynamo("Dynamo does not yet capture hooks")
     def test_forward_hooks(self):
         fired_hooks: List[int] = []
@@ -116,11 +172,15 @@ def test_forward_pre_hooks(self):
         model = ToyModel()
         x = torch.randn(10, 10)
         hook = partial(forward_pre_hook, self, fired_hooks, model.net2.seq1)
-        model.net2.seq1.register_forward_pre_hook(partial(hook, 0), prepend=True)
+        model.net2.seq1.register_forward_pre_hook(
+            partial(hook, 0), prepend=True
+        )
         model.net2.seq1.register_forward_pre_hook(partial(hook, 1))
         model.net2.seq1.register_forward_pre_hook(partial(hook, 2))
         model.net2.seq1.register_forward_pre_hook(partial(hook, 3))
-        model.net2.seq1.register_forward_pre_hook(partial(hook, 4), prepend=True)
+        model.net2.seq1.register_forward_pre_hook(
+            partial(hook, 4), prepend=True
+        )
         expected = [4, 0, 1, 2, 3]
 
         self.assertEqual(fired_hooks, [])
@@ -158,8 +218,12 @@ def test_full_backward_pre_hooks(self):
         model = ToyModel()
         x = torch.randn(10, 10)
         hook = partial(full_backward_pre_hook, self, fired_hooks, model.net1)
-        model.net1.register_full_backward_pre_hook(partial(hook, 0), prepend=True)
-        model.net1.register_full_backward_pre_hook(partial(hook, 1), prepend=True)
+        model.net1.register_full_backward_pre_hook(
+            partial(hook, 0), prepend=True
+        )
+        model.net1.register_full_backward_pre_hook(
+            partial(hook, 1), prepend=True
+        )
         model.net1.register_full_backward_pre_hook(partial(hook, 2))
         model.net1.register_full_backward_pre_hook(partial(hook, 3))
         model.net1.register_full_backward_pre_hook(partial(hook, 4))
@@ -178,10 +242,18 @@ def test_mixed_hooks(self):
         fired_hooks: List[int] = []
         model = ToyModel()
         x = torch.randn(10, 10)
-        model.register_forward_pre_hook(partial(forward_pre_hook, self, fired_hooks, model, 0))
-        model.register_forward_hook(partial(forward_hook, self, fired_hooks, model, 1))
-        model.register_full_backward_pre_hook(partial(full_backward_pre_hook, self, fired_hooks, model, 2))
-        model.register_full_backward_hook(partial(full_backward_hook, self, fired_hooks, model, 3))
+        model.register_forward_pre_hook(
+            partial(forward_pre_hook, self, fired_hooks, model, 0)
+        )
+        model.register_forward_hook(
+            partial(forward_hook, self, fired_hooks, model, 1)
+        )
+        model.register_full_backward_pre_hook(
+            partial(full_backward_pre_hook, self, fired_hooks, model, 2)
+        )
+        model.register_full_backward_hook(
+            partial(full_backward_hook, self, fired_hooks, model, 3)
+        )
 
         self.assertEqual(fired_hooks, [])
         out = model(x)
@@ -191,6 +263,109 @@
         model(x).sum().backward()
         self.assertEqual(fired_hooks, [0, 1, 2, 3, 0, 1, 2, 3])
 
+    @skipIfTorchDynamo("Dynamo does not yet capture hooks")
+    def test_kwarg_hooks(self):
+        # 1. test forward pre hook
+        fired_hooks: List[int] = []
+        x: torch.Tensor = torch.ones(10, 10)
+        bias: torch.Tensor = torch.ones(10, 10)
+        model = KwargModel()
+        model.register_forward_pre_hook(
+            partial(kwarg_forward_pre_hook, self, fired_hooks, model, 0),
+            with_kwargs=True,
+        )
+
+        # forward-pre: bias' = bias * 2
+        # So, out = x + bias * 2
+        self.assertEqual(fired_hooks, [])
+        out = model(x, bias=bias)
+        self.assertEqual(fired_hooks, [0])
+        self.assertEqual(out, x + 2 * bias, rtol=0, atol=1e-5)
+
+        # 2. test forward pre and forward hooks
+        fired_hooks: List[int] = []
+        x: torch.Tensor = torch.ones(10, 10)
+        bias: torch.Tensor = torch.ones(10, 10)
+        model = KwargModel()
+        model.register_forward_hook(
+            partial(kwarg_forward_hook, self, fired_hooks, model, 1),
+            with_kwargs=True,
+        )
+        model.register_forward_pre_hook(
+            partial(kwarg_forward_pre_hook, self, fired_hooks, model, 0),
+            with_kwargs=True,
+        )
+
+        # forward-pre: bias' = bias * 2
+        # forward: out = x + bias'
+        # forward-post: out = out + bias'
+        # So, out = x + bias * 4
+        self.assertEqual(fired_hooks, [])
+        out = model(x, bias=bias)
+        self.assertEqual(fired_hooks, [0, 1])
+        self.assertEqual(out, x + 4 * bias, rtol=0, atol=1e-5)
+
+        # 3. test nn.Module member method as forward-post hook
+        x: torch.Tensor = torch.ones(10, 10)
+        bias: torch.Tensor = torch.ones(10, 10)
+        model = KwargModel()
+        model.register_forward_hook(
+            model.internal_forward_hook, with_kwargs=True
+        )
+
+        # forward: out = x + bias
+        # forward-post: out = out + bias
+        # So, out = x + bias * 2
+        out = model(x, bias=bias)
+        self.assertEqual(out, x + 2 * bias, rtol=0, atol=1e-5)
+
+
+    @skipIfTorchDynamo("Dynamo does not yet capture hooks")
+    def test_remove_kwarg_hooks(self):
+        # test forward pre and forward hooks
+        fired_hooks: List[int] = []
+        x: torch.Tensor = torch.ones(10, 10)
+        bias: torch.Tensor = torch.ones(10, 10)
+        model = KwargModel()
+        forward_hook_handle = model.register_forward_hook(
+            partial(kwarg_forward_hook, self, fired_hooks, model, 1),
+            with_kwargs=True,
+        )
+        forward_pre_hook_handle = model.register_forward_pre_hook(
+            partial(kwarg_forward_pre_hook, self, fired_hooks, model, 0),
+            with_kwargs=True,
+        )
+
+        # forward-pre: bias' = bias * 2
+        # forward: out = x + bias'
+        # forward-post: out = out + bias'
+        # So, out = x + bias * 4
+        self.assertEqual(fired_hooks, [])
+        out = model(x, bias=bias)
+        self.assertEqual(fired_hooks, [0, 1])
+        self.assertEqual(out, x + 4 * bias, rtol=0, atol=1e-5)
+
+        # forward-pre: bias' = bias * 2
+        # forward: out = x + bias'
+        # So, out = x + bias * 2
+        forward_hook_handle.remove()
+        out = model(x, bias=bias)
+        self.assertEqual(fired_hooks, [0, 1, 0])
+        self.assertEqual(out, x + 2 * bias, rtol=0, atol=1e-5)
+        self.assertFalse(
+            forward_hook_handle.id in model._forward_hooks_with_kwargs
+        )
+
+        # forward: out = x + bias
+        # So, out = x + bias
+        forward_pre_hook_handle.remove()
+        out = model(x, bias=bias)
+        self.assertEqual(fired_hooks, [0, 1, 0])
+        self.assertEqual(out, x + bias, rtol=0, atol=1e-5)
+        self.assertFalse(
+            forward_pre_hook_handle.id in model._forward_pre_hooks_with_kwargs
+        )
+
 
 if __name__ == "__main__":
     run_tests()
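As `test_remove_kwarg_hooks` above exercises, removing a kwargs-aware hook through its `RemovableHandle` also clears the module's internal with-kwargs bookkeeping. A condensed sketch of that pattern; the no-op lambda hook is a placeholder, and `_forward_hooks_with_kwargs` is private state, shown here only because the test asserts on it:

```python
import torch.nn as nn

model = nn.Linear(4, 4)
handle = model.register_forward_hook(
    lambda module, args, kwargs, output: output,  # no-op placeholder hook
    with_kwargs=True,
)

# Registering with with_kwargs=True records the handle id in the module's
# internal bookkeeping, mirroring what the test above asserts on.
assert handle.id in model._forward_hooks_with_kwargs

handle.remove()  # detaches the hook and clears the bookkeeping entry
assert handle.id not in model._forward_hooks_with_kwargs
```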

torch/distributed/nn/api/remote_module.py

Lines changed: 4 additions & 0 deletions
@@ -64,7 +64,9 @@
     "_backward_pre_hooks",
     "_is_full_backward_hook",
     "_forward_hooks",
+    "_forward_hooks_with_kwargs",
     "_forward_pre_hooks",
+    "_forward_pre_hooks_with_kwargs",
     "_state_dict_hooks",
     "_load_state_dict_pre_hooks",
     "_load_state_dict_post_hooks",
@@ -365,13 +367,15 @@ def register_forward_pre_hook(  # type: ignore[return]
         self,
         hook: Callable[..., None],
         prepend: bool = False,
+        with_kwargs: bool = False,
     ) -> RemovableHandle:
         _raise_not_supported(self.register_forward_pre_hook.__name__)
 
     def register_forward_hook(  # type: ignore[return]
         self,
         hook: Callable[..., None],
         prepend: bool = False,
+        with_kwargs: bool = False,
     ) -> RemovableHandle:
         _raise_not_supported(self.register_forward_hook.__name__)

torch/jit/_recursive.py

Lines changed: 2 additions & 0 deletions
@@ -30,7 +30,9 @@
     "_backward_hooks",
     "_backward_pre_hooks",
     "_forward_hooks",
+    "_forward_hooks_with_kwargs",
     "_forward_pre_hooks",
+    "_forward_pre_hooks_with_kwargs",
     "_state_dict_hooks",
     "_load_state_dict_pre_hooks",
     "_load_state_dict_post_hooks",
