@@ -2223,6 +2223,107 @@ def forward(self, x):
         FileCheck().check("conv").check_not("aten::batch_norm").run(traced_model.graph)
         FileCheck().check("conv").check_not("aten::add").run(traced_model.graph)

+    def test_linear_bn_folding(self):
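+        # Fold eval-mode BatchNorm{1,2,3}d into a preceding Linear during freezing,
+        # for both traced and scripted modules, with and without running stats.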
+        module_pairs = [(nn.Linear, nn.BatchNorm1d), (nn.Linear, nn.BatchNorm2d), (nn.Linear, nn.BatchNorm3d)]
+        use_tracing = [True, False]
+        bn_running_stats = [True, False]
+
+        for modules, tracing, track_stats in product(module_pairs, use_tracing, bn_running_stats):
+            class LinearBN(torch.nn.Module):
+                def __init__(self, in_features, out_features):
+                    super(LinearBN, self).__init__()
+                    self.linear = modules[0](in_features, out_features)
+                    self.bn = modules[1](out_features, eps=0.001, track_running_stats=track_stats)
+
+                def forward(self, x):
+                    x = self.linear(x)
+                    return self.bn(x)
+
+            mod_eager = LinearBN(32, 32).eval()
+
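+            # Linear acts on the last dimension; BatchNorm2d/3d expect 4-D/5-D
+            # inputs, so extra trailing dimensions are appended for those cases.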
+            inps = [3, 32]
+            if modules[1] == nn.BatchNorm2d:
+                inps.append(inps[-1])
+                inps.append(inps[-1])
+            if modules[1] == nn.BatchNorm3d:
+                inps.append(inps[-1])
+                inps.append(inps[-1])
+                inps.append(inps[-1])
+
+            inp = torch.rand(inps)
+
+            if tracing:
+                scripted_mod = torch.jit.trace(mod_eager, (inp))
+            else:
+                scripted_mod = torch.jit.script(mod_eager)
+
+            self.run_pass("inline", scripted_mod.graph)
+            self.run_pass("peephole", scripted_mod.graph)
+            self.run_pass("constant_propagation", scripted_mod.graph)
+
+            FileCheck().check("linear").check("batch").run(scripted_mod.graph)
+            # successfully no-ops with non-const inputs
+            self.run_pass("fold_frozen_linear_bn", scripted_mod.graph)
+            FileCheck().check("linear").check("aten::batch_norm").run(scripted_mod.graph)
+
+            scripted_mod = torch.jit.freeze(scripted_mod)
+            self.run_pass("fold_frozen_linear_bn", scripted_mod.graph)
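+            # Folding requires constant (tracked) running statistics; without them
+            # the aten::batch_norm node must remain in the frozen graph.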
+            if track_stats:
+                FileCheck().check("linear").check_not("aten::batch_norm").run(scripted_mod.graph)
+            else:
+                FileCheck().check("linear").check("aten::batch_norm").run(scripted_mod.graph)
+
+            self.assertEqual(mod_eager(inp), scripted_mod(inp))
+            self.assertEqual(mod_eager(inp), scripted_mod(inp))
+
+    @skipCUDAMemoryLeakCheckIf(True)
+    @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU")
+    def test_linear_bn_folding_autocast_scenario_cuda(self):
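+        # Mixed-precision variant: a half-precision Linear without bias feeds a
+        # float BatchNorm; the fold should produce a half weight and a new half bias.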
+        module_pairs = [(nn.Linear, nn.BatchNorm1d), (nn.Linear, nn.BatchNorm2d), (nn.Linear, nn.BatchNorm3d)]
+        use_tracing = [True, False]
+        bn_running_stats = [True, False]
+
+        for modules, tracing, track_stats in product(module_pairs, use_tracing, bn_running_stats):
+            class LinearBN(torch.nn.Module):
+                def __init__(self, in_features, out_features):
+                    super(LinearBN, self).__init__()
+                    self.linear = modules[0](in_features, out_features, bias=False, dtype=torch.half)
+                    self.bn = modules[1](out_features, eps=0.001, dtype=torch.float)
+
+                def forward(self, x):
+                    x = self.linear(x)
+                    return self.bn(x)
+
+            mod_eager = LinearBN(32, 32).cuda().eval()
+
+            inps = [3, 32]
+            if modules[1] == nn.BatchNorm2d:
+                inps.append(inps[-1])
+                inps.append(inps[-1])
+            if modules[1] == nn.BatchNorm3d:
+                inps.append(inps[-1])
+                inps.append(inps[-1])
+                inps.append(inps[-1])
+
+            x = torch.rand(inps, dtype=torch.half).cuda()
+
+            if tracing:
+                scripted_mod = torch.jit.trace(mod_eager, (x))
+            else:
+                scripted_mod = torch.jit.script(mod_eager)
+            scripted_mod = torch.jit.freeze(scripted_mod)
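+            # After freezing, the BN should be folded away even across dtypes; the
+            # folded weight and the created bias are checked to be half precision.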
+            FileCheck().check("linear").check_not("aten::batch_norm").run(scripted_mod.graph)
+            lin_node = scripted_mod.graph.findNode("aten::linear", True)
+            self.assertTrue(lin_node is not None)
+            weight_input = lin_node.namedInput("weight")
+            bias_input = lin_node.namedInput("bias")
+            self.assertTrue(bias_input is not None)
+            self.assertTrue(weight_input.type().dtype() == torch.half)
+            self.assertTrue(bias_input.type().dtype() == torch.half)
+
+            self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2)
+            self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2)
+
     @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU")
     def test_linear_concat(self):
         out_dimms = [[5, 10], [1, 5]]
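For readers unfamiliar with the pass under test: with tracked running statistics, an eval-mode BatchNorm after a Linear is a per-feature affine map, so it can be folded into the Linear's weight and bias, which is what the checks above verify. Below is a minimal sketch of that algebra, assuming eval mode and a 2-D input; fold_linear_bn is an illustrative helper, not part of the PyTorch or pass API.

import torch
import torch.nn.functional as F

def fold_linear_bn(W, b, gamma, beta, running_mean, running_var, eps):
    # bn(linear(x)) = gamma * (W x + b - mean) / sqrt(var + eps) + beta
    #               = (scale * W) x + scale * (b - mean) + beta,
    #   where scale = gamma / sqrt(var + eps)
    scale = gamma / torch.sqrt(running_var + eps)
    return W * scale.unsqueeze(1), (b - running_mean) * scale + beta

linear = torch.nn.Linear(32, 32).eval()
bn = torch.nn.BatchNorm1d(32, eps=0.001).eval()
bn.running_mean.uniform_(-1, 1)   # randomize stats so the check is non-trivial
bn.running_var.uniform_(0.5, 1.5)
x = torch.rand(3, 32)
W_f, b_f = fold_linear_bn(linear.weight, linear.bias, bn.weight, bn.bias,
                          bn.running_mean, bn.running_var, bn.eps)
assert torch.allclose(bn(linear(x)), F.linear(x, W_f, b_f), atol=1e-6)

Without tracked running stats there are no constants to fold, which is why the test expects aten::batch_norm to survive in that configuration.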