Skip to content

Commit 577c04c

Browse files
jerryzh168 authored and facebook-github-bot committed
add mutation support for forward_pre_hook and forward_hook (#22285)
Summary: Pull Request resolved: #22285 Previously forward hooks are expected to return None, this PR adds the support to overwrite input and output in `forward_pre_hook` and `forward_hook`, this is used to implement inserting quant/dequant function calls around forward functions. Differential Revision: D16022491 fbshipit-source-id: 02340080745f22c8ea8a2f80c2c08e3a88e37253
1 parent f7421b8 commit 577c04c

File tree

2 files changed

+36
-26
lines changed

2 files changed

+36
-26
lines changed

test/test_nn.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -713,30 +713,12 @@ def test_hook_fail(self):
713713
module = nn.Sigmoid()
714714
input = torch.randn(5, 5, requires_grad=True)
715715

716-
def fw_fail1(self, input, output):
717-
return output
718-
719-
def fw_fail2(self, input, output):
720-
return input
721-
722716
def bw_fail1(self, grad_input, grad_output):
723717
return grad_input[:-1]
724718

725719
def bw_fail2(self, grad_input, grad_output):
726720
return grad_input + (torch.randn(2, 2),)
727721

728-
with module.register_forward_hook(fw_fail1):
729-
with self.assertRaises(RuntimeError) as err:
730-
module(input)
731-
self.assertIn("fw_fail", err.exception.args[0])
732-
self.assertIn("didn't return None", err.exception.args[0])
733-
734-
with module.register_forward_hook(fw_fail2):
735-
with self.assertRaises(RuntimeError) as err:
736-
module(input)
737-
self.assertIn("fw_fail2", err.exception.args[0])
738-
self.assertIn("didn't return None", err.exception.args[0])
739-
740722
with module.register_backward_hook(bw_fail1):
741723
with self.assertRaises(RuntimeError) as err:
742724
module(input).sum().backward()
@@ -765,6 +747,28 @@ def bw_hook(module, grad_input, grad_output):
765747
expected_grad = torch.ones(5, 5).mm(module.weight.data) * 2
766748
self.assertEqual(input.grad.data, expected_grad)
767749

750+
def test_hook_mutations(self):
751+
module = nn.Linear(5, 5)
752+
input = torch.randn(5, 5, requires_grad=True)
753+
754+
def forward_pre_hook(m, input):
755+
return torch.nn.functional.relu(input[0])
756+
757+
def forward_hook(m, input, output):
758+
return -output
759+
760+
module.register_forward_pre_hook(forward_pre_hook)
761+
module.register_forward_hook(forward_hook)
762+
output = module(input)
763+
expected_res = -torch.nn.functional.linear(torch.nn.functional.relu(input), module.weight, module.bias)
764+
self.assertEqual(output, expected_res)
765+
output.backward(torch.ones(5, 5) * 2, retain_graph=True)
766+
mask = (input > 0).double()
767+
expected_grad = -torch.ones(5, 5).mm(module.weight.data) * 2 * mask
768+
self.assertEqual(input.grad, expected_grad)
769+
770+
771+
768772
def test_to(self):
769773
m = nn.Linear(3, 5)
770774
self.assertIs(m, m.to('cpu'))

torch/nn/modules/module.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -462,9 +462,11 @@ def register_forward_pre_hook(self, hook):
462462
The hook will be called every time before :func:`forward` is invoked.
463463
It should have the following signature::
464464
465-
hook(module, input) -> None
465+
hook(module, input) -> None or modified input
466466
467-
The hook should not modify the input.
467+
The hook can modify the input. User can either return a tuple or a
468+
single modified value in the hook. We will wrap the value into a tuple
469+
if a single value is returned(unless that value is already a tuple).
468470
469471
Returns:
470472
:class:`torch.utils.hooks.RemovableHandle`:
@@ -481,9 +483,11 @@ def register_forward_hook(self, hook):
481483
The hook will be called every time after :func:`forward` has computed an output.
482484
It should have the following signature::
483485
484-
hook(module, input, output) -> None
486+
hook(module, input, output) -> None or modified output
485487
486-
The hook should not modify the input or output.
488+
The hook can modify the output. It can modify the input inplace but
489+
it will not have effect on forward since this is called after
490+
:func:`forward` is called.
487491
488492
Returns:
489493
:class:`torch.utils.hooks.RemovableHandle`:
@@ -524,17 +528,19 @@ def _slow_forward(self, *input, **kwargs):
524528

525529
def __call__(self, *input, **kwargs):
526530
for hook in self._forward_pre_hooks.values():
527-
hook(self, input)
531+
result = hook(self, input)
532+
if result is not None:
533+
if not isinstance(result, tuple):
534+
result = (result,)
535+
input = result
528536
if torch._C._get_tracing_state():
529537
result = self._slow_forward(*input, **kwargs)
530538
else:
531539
result = self.forward(*input, **kwargs)
532540
for hook in self._forward_hooks.values():
533541
hook_result = hook(self, input, result)
534542
if hook_result is not None:
535-
raise RuntimeError(
536-
"forward hooks should never return any values, but '{}'"
537-
"didn't return None".format(hook))
543+
result = hook_result
538544
if len(self._backward_hooks) > 0:
539545
var = result
540546
while not isinstance(var, torch.Tensor):

0 commit comments

Comments (0)