@@ -2223,107 +2223,6 @@ def forward(self, x):
         FileCheck().check("conv").check_not("aten::batch_norm").run(traced_model.graph)
         FileCheck().check("conv").check_not("aten::add").run(traced_model.graph)
 
-    def test_linear_bn_folding(self):
-        module_pairs = [(nn.Linear, nn.BatchNorm1d), (nn.Linear, nn.BatchNorm2d), (nn.Linear, nn.BatchNorm3d)]
-        use_tracing = [True, False]
-        bn_running_stats = [True, False]
-
-        for modules, tracing, track_stats in product(module_pairs, use_tracing, bn_running_stats):
-            class LinearBN(torch.nn.Module):
-                def __init__(self, in_features, out_features):
-                    super(LinearBN, self).__init__()
-                    self.linear = modules[0](in_features, out_features)
-                    self.bn = modules[1](out_features, eps=0.001, track_running_stats=track_stats)
-
-                def forward(self, x):
-                    x = self.linear(x)
-                    return self.bn(x)
-
-            mod_eager = LinearBN(32, 32).eval()
-
-            inps = [3, 32]
-            if modules[1] == nn.BatchNorm2d:
-                inps.append(inps[-1])
-                inps.append(inps[-1])
-            if modules[1] == nn.BatchNorm3d:
-                inps.append(inps[-1])
-                inps.append(inps[-1])
-                inps.append(inps[-1])
-
-            inp = torch.rand(inps)
-
-            if tracing:
-                scripted_mod = torch.jit.trace(mod_eager, (inp))
-            else:
-                scripted_mod = torch.jit.script(mod_eager)
-
-            self.run_pass("inline", scripted_mod.graph)
-            self.run_pass("peephole", scripted_mod.graph)
-            self.run_pass("constant_propagation", scripted_mod.graph)
-
-            FileCheck().check("linear").check("batch").run(scripted_mod.graph)
-            # successfully no-ops with non-const inputs
-            self.run_pass("fold_frozen_linear_bn", scripted_mod.graph)
-            FileCheck().check("linear").check("aten::batch_norm").run(scripted_mod.graph)
-
-            scripted_mod = torch.jit.freeze(scripted_mod)
-            self.run_pass("fold_frozen_linear_bn", scripted_mod.graph)
-            if track_stats:
-                FileCheck().check("linear").check_not("aten::batch_norm").run(scripted_mod.graph)
-            else:
-                FileCheck().check("linear").check("aten::batch_norm").run(scripted_mod.graph)
-
-            self.assertEqual(mod_eager(inp), scripted_mod(inp))
-            self.assertEqual(mod_eager(inp), scripted_mod(inp))
-
-    @skipCUDAMemoryLeakCheckIf(True)
-    @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU")
-    def test_linear_bn_folding_autocast_scenario_cuda(self):
-        module_pairs = [(nn.Linear, nn.BatchNorm1d), (nn.Linear, nn.BatchNorm2d), (nn.Linear, nn.BatchNorm3d)]
-        use_tracing = [True, False]
-        bn_running_stats = [True, False]
-
-        for modules, tracing, track_stats in product(module_pairs, use_tracing, bn_running_stats):
-            class LinearBN(torch.nn.Module):
-                def __init__(self, in_features, out_features):
-                    super(LinearBN, self).__init__()
-                    self.linear = modules[0](in_features, out_features, bias=False, dtype=torch.half)
-                    self.bn = modules[1](out_features, eps=0.001, dtype=torch.float)
-
-                def forward(self, x):
-                    x = self.linear(x)
-                    return self.bn(x)
-
-            mod_eager = LinearBN(32, 32).cuda().eval()
-
-            inps = [3, 32]
-            if modules[1] == nn.BatchNorm2d:
-                inps.append(inps[-1])
-                inps.append(inps[-1])
-            if modules[1] == nn.BatchNorm3d:
-                inps.append(inps[-1])
-                inps.append(inps[-1])
-                inps.append(inps[-1])
-
-            x = torch.rand(inps, dtype=torch.half).cuda()
-
-            if tracing:
-                scripted_mod = torch.jit.trace(mod_eager, (x))
-            else:
-                scripted_mod = torch.jit.script(mod_eager)
-            scripted_mod = torch.jit.freeze(scripted_mod)
-            FileCheck().check("linear").check_not("aten::batch_norm").run(scripted_mod.graph)
-            lin_node = scripted_mod.graph.findNode("aten::linear", True)
-            self.assertTrue(lin_node is not None)
-            weight_input = lin_node.namedInput("weight")
-            bias_input = lin_node.namedInput("bias")
-            self.assertTrue(bias_input is not None)
-            self.assertTrue(weight_input.type().dtype() == torch.half)
-            self.assertTrue(bias_input.type().dtype() == torch.half)
-
-            self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2)
-            self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2)
-
     @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU")
     def test_linear_concat(self):
         out_dimms = [[5, 10], [1, 5]]
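For context, the two tests removed above exercise TorchScript's "fold_frozen_linear_bn" freezing pass, which folds an inference-mode BatchNorm into the preceding nn.Linear once the BatchNorm's running statistics are frozen constants. The sketch below is illustrative only and not part of this diff: it reproduces that folding algebra by hand (module sizes, seed, and tolerances are arbitrary choices) and checks the folded Linear against the eager Linear + BatchNorm pair.

# Illustrative sketch of linear/BatchNorm folding (assumed setup: Linear(32, 32)
# followed by BatchNorm1d(32), both in eval mode with populated running stats).
import torch
import torch.nn as nn

torch.manual_seed(0)

linear = nn.Linear(32, 32)
bn = nn.BatchNorm1d(32, eps=0.001)

# Populate running_mean / running_var so the fold has constants to consume.
bn.train()
bn(torch.rand(64, 32))
linear.eval()
bn.eval()

# Fold:  scale = gamma / sqrt(running_var + eps)
#        W' = diag(scale) @ W,   b' = (b - running_mean) * scale + beta
scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
folded = nn.Linear(32, 32)
with torch.no_grad():
    folded.weight.copy_(linear.weight * scale.unsqueeze(1))
    folded.bias.copy_((linear.bias - bn.running_mean) * scale + bn.bias)
folded.eval()

x = torch.rand(3, 32)
with torch.no_grad():
    # The folded Linear matches Linear followed by BatchNorm in eval mode.
    assert torch.allclose(bn(linear(x)), folded(x), atol=1e-5)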