99import torch .nn as nn
1010
1111from functools import partial
12- from typing import List , Tuple
12+ from typing import Any , Dict , List , Tuple
1313
1414
1515class Net (nn .Module ):
@@ -87,8 +87,64 @@ def full_backward_pre_hook(
8787 self .assertEqual (len (grad_input ), 1 )
8888
8989
90- class TestModuleHooks (TestCase ):
90+ class KwargModel (nn .Module ):
91+ def __init__ (self ) -> None :
92+ super ().__init__ ()
93+ self .net1 = Net ()
94+ self .net2 = Net ()
95+
96+ def forward (
97+ self , x : torch .Tensor , bias : torch .Tensor = None
98+ ) -> torch .Tensor :
99+ if bias is not None :
100+ x = x + bias
101+ return x
102+
103+ def internal_forward_hook (
104+ self ,
105+ module : nn .Module ,
106+ args : Tuple [torch .Tensor ],
107+ kwargs : Dict [str , Any ],
108+ out : torch .Tensor ,
109+ ):
110+ return out + kwargs ["bias" ]
111+
112+
113+ def kwarg_forward_pre_hook (
114+ self : TestCase ,
115+ fired_hooks : List [int ],
116+ expected_module : nn .Module ,
117+ hook_id : int ,
118+ module : nn .Module ,
119+ args : Tuple [torch .Tensor ],
120+ kwargs : Dict [str , Any ],
121+ ) -> Tuple [Any , Any ]:
122+ fired_hooks .append (hook_id )
123+ self .assertEqual (id (module ), id (expected_module ))
124+ self .assertEqual (len (args ), 1 )
125+ kwargs ["bias" ] = 2 * kwargs ["bias" ]
126+ return args , kwargs
127+
128+
129+ def kwarg_forward_hook (
130+ self : TestCase ,
131+ fired_hooks : List [int ],
132+ expected_module : nn .Module ,
133+ hook_id : int ,
134+ module : nn .Module ,
135+ args : Tuple [torch .Tensor ],
136+ kwargs : Dict [str , Any ],
137+ out : torch .Tensor ,
138+ ) -> Any :
139+ fired_hooks .append (hook_id )
140+ self .assertEqual (id (module ), id (expected_module ))
141+ self .assertEqual (len (args ), 1 )
142+
143+ out = out + kwargs ["bias" ]
144+ return out
91145
146+
147+ class TestModuleHooks (TestCase ):
92148 @skipIfTorchDynamo ("Dynamo does not yet capture hooks" )
93149 def test_forward_hooks (self ):
94150 fired_hooks : List [int ] = []
@@ -116,11 +172,15 @@ def test_forward_pre_hooks(self):
116172 model = ToyModel ()
117173 x = torch .randn (10 , 10 )
118174 hook = partial (forward_pre_hook , self , fired_hooks , model .net2 .seq1 )
119- model .net2 .seq1 .register_forward_pre_hook (partial (hook , 0 ), prepend = True )
175+ model .net2 .seq1 .register_forward_pre_hook (
176+ partial (hook , 0 ), prepend = True
177+ )
120178 model .net2 .seq1 .register_forward_pre_hook (partial (hook , 1 ))
121179 model .net2 .seq1 .register_forward_pre_hook (partial (hook , 2 ))
122180 model .net2 .seq1 .register_forward_pre_hook (partial (hook , 3 ))
123- model .net2 .seq1 .register_forward_pre_hook (partial (hook , 4 ), prepend = True )
181+ model .net2 .seq1 .register_forward_pre_hook (
182+ partial (hook , 4 ), prepend = True
183+ )
124184 expected = [4 , 0 , 1 , 2 , 3 ]
125185
126186 self .assertEqual (fired_hooks , [])
@@ -158,8 +218,12 @@ def test_full_backward_pre_hooks(self):
158218 model = ToyModel ()
159219 x = torch .randn (10 , 10 )
160220 hook = partial (full_backward_pre_hook , self , fired_hooks , model .net1 )
161- model .net1 .register_full_backward_pre_hook (partial (hook , 0 ), prepend = True )
162- model .net1 .register_full_backward_pre_hook (partial (hook , 1 ), prepend = True )
221+ model .net1 .register_full_backward_pre_hook (
222+ partial (hook , 0 ), prepend = True
223+ )
224+ model .net1 .register_full_backward_pre_hook (
225+ partial (hook , 1 ), prepend = True
226+ )
163227 model .net1 .register_full_backward_pre_hook (partial (hook , 2 ))
164228 model .net1 .register_full_backward_pre_hook (partial (hook , 3 ))
165229 model .net1 .register_full_backward_pre_hook (partial (hook , 4 ))
@@ -178,10 +242,18 @@ def test_mixed_hooks(self):
178242 fired_hooks : List [int ] = []
179243 model = ToyModel ()
180244 x = torch .randn (10 , 10 )
181- model .register_forward_pre_hook (partial (forward_pre_hook , self , fired_hooks , model , 0 ))
182- model .register_forward_hook (partial (forward_hook , self , fired_hooks , model , 1 ))
183- model .register_full_backward_pre_hook (partial (full_backward_pre_hook , self , fired_hooks , model , 2 ))
184- model .register_full_backward_hook (partial (full_backward_hook , self , fired_hooks , model , 3 ))
245+ model .register_forward_pre_hook (
246+ partial (forward_pre_hook , self , fired_hooks , model , 0 )
247+ )
248+ model .register_forward_hook (
249+ partial (forward_hook , self , fired_hooks , model , 1 )
250+ )
251+ model .register_full_backward_pre_hook (
252+ partial (full_backward_pre_hook , self , fired_hooks , model , 2 )
253+ )
254+ model .register_full_backward_hook (
255+ partial (full_backward_hook , self , fired_hooks , model , 3 )
256+ )
185257
186258 self .assertEqual (fired_hooks , [])
187259 out = model (x )
@@ -191,6 +263,109 @@ def test_mixed_hooks(self):
191263 model (x ).sum ().backward ()
192264 self .assertEqual (fired_hooks , [0 , 1 , 2 , 3 , 0 , 1 , 2 , 3 ])
193265
266+ @skipIfTorchDynamo ("Dynamo does not yet capture hooks" )
267+ def test_kwarg_hooks (self ):
268+ # 1. test forward pre hook
269+ fired_hooks : List [int ] = []
270+ x : torch .Tensor = torch .ones (10 , 10 )
271+ bias : torch .Tensor = torch .ones (10 , 10 )
272+ model = KwargModel ()
273+ model .register_forward_pre_hook (
274+ partial (kwarg_forward_pre_hook , self , fired_hooks , model , 0 ),
275+ with_kwargs = True ,
276+ )
277+
278+ # forward-pre: bias' = bias * 2
279+ # So, out = x + bias * 2
280+ self .assertEqual (fired_hooks , [])
281+ out = model (x , bias = bias )
282+ self .assertEqual (fired_hooks , [0 ])
283+ self .assertEqual (out , x + 2 * bias , rtol = 0 , atol = 1e-5 )
284+
285+ # 2. test forward pre and forward hooks
286+ fired_hooks : List [int ] = []
287+ x : torch .Tensor = torch .ones (10 , 10 )
288+ bias : torch .Tensor = torch .ones (10 , 10 )
289+ model = KwargModel ()
290+ model .register_forward_hook (
291+ partial (kwarg_forward_hook , self , fired_hooks , model , 1 ),
292+ with_kwargs = True ,
293+ )
294+ model .register_forward_pre_hook (
295+ partial (kwarg_forward_pre_hook , self , fired_hooks , model , 0 ),
296+ with_kwargs = True ,
297+ )
298+
299+ # forward-pre: bias' = bias * 2
300+ # forward: out = x + bias'
301+ # forward-post: out = out + bias'
302+ # So, out = x + bias * 4
303+ self .assertEqual (fired_hooks , [])
304+ out = model (x , bias = bias )
305+ self .assertEqual (fired_hooks , [0 , 1 ])
306+ self .assertEqual (out , x + 4 * bias , rtol = 0 , atol = 1e-5 )
307+
308+ # 3. test nn.Module member method as forward-post hook
309+ x : torch .Tensor = torch .ones (10 , 10 )
310+ bias : torch .Tensor = torch .ones (10 , 10 )
311+ model = KwargModel ()
312+ model .register_forward_hook (
313+ model .internal_forward_hook , with_kwargs = True
314+ )
315+
316+ # forward: out = x + bias
317+ # forward-post: out = out + bias
318+ # So, out = x + bias * 2
319+ out = model (x , bias = bias )
320+ self .assertEqual (out , x + 2 * bias , rtol = 0 , atol = 1e-5 )
321+
322+
323+ @skipIfTorchDynamo ("Dynamo does not yet capture hooks" )
324+ def test_remove_kwarg_hooks (self ):
325+ # test forward pre and forward hooks
326+ fired_hooks : List [int ] = []
327+ x : torch .Tensor = torch .ones (10 , 10 )
328+ bias : torch .Tensor = torch .ones (10 , 10 )
329+ model = KwargModel ()
330+ forward_hook_handle = model .register_forward_hook (
331+ partial (kwarg_forward_hook , self , fired_hooks , model , 1 ),
332+ with_kwargs = True ,
333+ )
334+ forward_pre_hook_handle = model .register_forward_pre_hook (
335+ partial (kwarg_forward_pre_hook , self , fired_hooks , model , 0 ),
336+ with_kwargs = True ,
337+ )
338+
339+ # forward-pre: bias' = bias * 2
340+ # forward: out = x + bias'
341+ # forward-post: out = out + bias'
342+ # So, out = x + bias * 4
343+ self .assertEqual (fired_hooks , [])
344+ out = model (x , bias = bias )
345+ self .assertEqual (fired_hooks , [0 , 1 ])
346+ self .assertEqual (out , x + 4 * bias , rtol = 0 , atol = 1e-5 )
347+
348+ # forward-pre: bias' = bias * 2
349+ # forward: out = x + bias'
350+ # So, out = x + bias * 2
351+ forward_hook_handle .remove ()
352+ out = model (x , bias = bias )
353+ self .assertEqual (fired_hooks , [0 , 1 , 0 ])
354+ self .assertEqual (out , x + 2 * bias , rtol = 0 , atol = 1e-5 )
355+ self .assertFalse (
356+ forward_hook_handle .id in model ._forward_hooks_with_kwargs
357+ )
358+
359+ # forward: out = x + bias
360+ # So, out = x + bias
361+ forward_pre_hook_handle .remove ()
362+ out = model (x , bias = bias )
363+ self .assertEqual (fired_hooks , [0 , 1 , 0 ])
364+ self .assertEqual (out , x + bias , rtol = 0 , atol = 1e-5 )
365+ self .assertFalse (
366+ forward_pre_hook_handle .id in model ._forward_pre_hooks_with_kwargs
367+ )
368+
194369
195370if __name__ == "__main__" :
196371 run_tests ()
0 commit comments