
Commit 6d2b0cb

min-jean-cho authored and pytorchmergebot committed
[Re-landing 86706] [JIT] Frozen Graph Linear-BatchNormNd Folding (#91020)
Re-landing #86706. This PR adds Linear-BatchNormNd folding for JIT frozen graphs.

**Performance benchmark**

A preliminary benchmark with a simple Linear + BatchNorm1d model, run on the physical cores of the first socket of a Skylake machine.

**FP32, JIT**

Without linear-bn folding:
![Screenshot (1368)](https://user-images.githubusercontent.com/93151422/195168944-cfc5b920-bc82-4be1-a221-d194c8fa6c18.png)

With linear-bn folding:
![Screenshot (1367)](https://user-images.githubusercontent.com/93151422/195168926-267b0515-45a1-4f08-922d-c150845199ae.png)

Pull Request resolved: #91020
Approved by: https://github.com/davidberard98
1 parent e8bf7c2 commit 6d2b0cb
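
A minimal sketch of how the new folding can be exercised from Python, mirroring the added tests. It assumes the `_jit_pass_fold_frozen_linear_bn` binding registered later in this commit; the fold also runs automatically as part of the frozen-graph optimizations invoked by `torch.jit.freeze`:

```python
import torch
import torch.nn as nn

class LinearBN(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.bn = nn.BatchNorm1d(out_features, eps=0.001)

    def forward(self, x):
        return self.bn(self.linear(x))

mod = LinearBN(32, 32).eval()
inp = torch.rand(3, 32)

frozen = torch.jit.freeze(torch.jit.script(mod))
# The fold runs as part of the frozen-graph optimizations; calling the new
# pass explicitly afterwards is a no-op if it has already been applied.
torch._C._jit_pass_fold_frozen_linear_bn(frozen.graph)

print("aten::batch_norm" in str(frozen.graph))  # expected: False (running stats tracked by default)
print((mod(inp) - frozen(inp)).abs().max())     # expected: small numerical difference
```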

File tree

9 files changed: +306 -0 lines changed

build_variables.bzl

Lines changed: 2 additions & 0 deletions
@@ -297,10 +297,12 @@ core_sources_full_mobile_no_backend_interface_xplat = [
     "torch/csrc/jit/passes/remove_mutation.cpp",
     "torch/csrc/jit/passes/prepack_folding.cpp",
     "torch/csrc/jit/passes/fold_conv_bn.cpp",
+    "torch/csrc/jit/passes/fold_linear_bn.cpp",
     "torch/csrc/jit/passes/dbr_quantization/remove_redundant_aliases.cpp",
     "torch/csrc/jit/passes/frozen_concat_linear.cpp",
     "torch/csrc/jit/passes/frozen_conv_add_relu_fusion.cpp",
     "torch/csrc/jit/passes/frozen_conv_folding.cpp",
+    "torch/csrc/jit/passes/frozen_linear_folding.cpp",
     "torch/csrc/jit/passes/frozen_linear_transpose.cpp",
     "torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp",
     "torch/csrc/jit/passes/frozen_graph_optimizations.cpp",

test/jit/test_freezing.py

Lines changed: 101 additions & 0 deletions
@@ -2223,6 +2223,107 @@ def forward(self, x):
         FileCheck().check("conv").check_not("aten::batch_norm").run(traced_model.graph)
         FileCheck().check("conv").check_not("aten::add").run(traced_model.graph)
 
+    def test_linear_bn_folding(self):
+        module_pairs = [(nn.Linear, nn.BatchNorm1d), (nn.Linear, nn.BatchNorm2d), (nn.Linear, nn.BatchNorm3d)]
+        use_tracing = [True, False]
+        bn_running_stats = [True, False]
+
+        for modules, tracing, track_stats in product(module_pairs, use_tracing, bn_running_stats):
+            class LinearBN(torch.nn.Module):
+                def __init__(self, in_features, out_features):
+                    super(LinearBN, self).__init__()
+                    self.linear = modules[0](in_features, out_features)
+                    self.bn = modules[1](out_features, eps=0.001, track_running_stats=track_stats)
+
+                def forward(self, x):
+                    x = self.linear(x)
+                    return self.bn(x)
+
+            mod_eager = LinearBN(32, 32).eval()
+
+            inps = [3, 32]
+            if modules[1] == nn.BatchNorm2d:
+                inps.append(inps[-1])
+                inps.append(inps[-1])
+            if modules[1] == nn.BatchNorm3d:
+                inps.append(inps[-1])
+                inps.append(inps[-1])
+                inps.append(inps[-1])
+
+            inp = torch.rand(inps)
+
+            if tracing:
+                scripted_mod = torch.jit.trace(mod_eager, (inp))
+            else:
+                scripted_mod = torch.jit.script(mod_eager)
+
+            self.run_pass("inline", scripted_mod.graph)
+            self.run_pass("peephole", scripted_mod.graph)
+            self.run_pass("constant_propagation", scripted_mod.graph)
+
+            FileCheck().check("linear").check("batch").run(scripted_mod.graph)
+            # successfully no-ops with non-const inputs
+            self.run_pass("fold_frozen_linear_bn", scripted_mod.graph)
+            FileCheck().check("linear").check("aten::batch_norm").run(scripted_mod.graph)
+
+            scripted_mod = torch.jit.freeze(scripted_mod)
+            self.run_pass("fold_frozen_linear_bn", scripted_mod.graph)
+            if track_stats:
+                FileCheck().check("linear").check_not("aten::batch_norm").run(scripted_mod.graph)
+            else:
+                FileCheck().check("linear").check("aten::batch_norm").run(scripted_mod.graph)
+
+            self.assertEqual(mod_eager(inp), scripted_mod(inp))
+            self.assertEqual(mod_eager(inp), scripted_mod(inp))
+
+    @skipCUDAMemoryLeakCheckIf(True)
+    @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU")
+    def test_linear_bn_folding_autocast_scenario_cuda(self):
+        module_pairs = [(nn.Linear, nn.BatchNorm1d), (nn.Linear, nn.BatchNorm2d), (nn.Linear, nn.BatchNorm3d)]
+        use_tracing = [True, False]
+        bn_running_stats = [True, False]
+
+        for modules, tracing, track_stats in product(module_pairs, use_tracing, bn_running_stats):
+            class LinearBN(torch.nn.Module):
+                def __init__(self, in_features, out_features):
+                    super(LinearBN, self).__init__()
+                    self.linear = modules[0](in_features, out_features, bias=False, dtype=torch.half)
+                    self.bn = modules[1](out_features, eps=0.001, dtype=torch.float)
+
+                def forward(self, x):
+                    x = self.linear(x)
+                    return self.bn(x)
+
+            mod_eager = LinearBN(32, 32).cuda().eval()
+
+            inps = [3, 32]
+            if modules[1] == nn.BatchNorm2d:
+                inps.append(inps[-1])
+                inps.append(inps[-1])
+            if modules[1] == nn.BatchNorm3d:
+                inps.append(inps[-1])
+                inps.append(inps[-1])
+                inps.append(inps[-1])
+
+            x = torch.rand(inps, dtype=torch.half).cuda()
+
+            if tracing:
+                scripted_mod = torch.jit.trace(mod_eager, (x))
+            else:
+                scripted_mod = torch.jit.script(mod_eager)
+            scripted_mod = torch.jit.freeze(scripted_mod)
+            FileCheck().check("linear").check_not("aten::batch_norm").run(scripted_mod.graph)
+            lin_node = scripted_mod.graph.findNode("aten::linear", True)
+            self.assertTrue(lin_node is not None)
+            weight_input = lin_node.namedInput("weight")
+            bias_input = lin_node.namedInput("bias")
+            self.assertTrue(bias_input is not None)
+            self.assertTrue(weight_input.type().dtype() == torch.half)
+            self.assertTrue(bias_input.type().dtype() == torch.half)
+
+            self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2)
+            self.assertEqual(mod_eager(x), scripted_mod(x), atol=1e-2, rtol=1e-2)
+
     @unittest.skipIf(not TEST_CUDA, "Optimization currently only run for GPU")
     def test_linear_concat(self):
         out_dimms = [[5, 10], [1, 5]]
torch/csrc/jit/passes/fold_linear_bn.cpp

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
+#include <torch/csrc/jit/passes/fold_linear_bn.h>
+
+#include <ATen/TensorOperators.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#else
+#include <ATen/ops/rsqrt.h>
+#endif
+
+namespace torch {
+namespace jit {
+
+std::tuple<at::Tensor, at::Tensor> computeUpdatedLinearWeightAndBias(
+    const LinearBNParameters& p) {
+  at::Tensor bn_scale = p.bn_w * at::rsqrt(p.bn_rv + p.bn_eps);
+  at::Tensor fused_w = p.linear_w * bn_scale.unsqueeze(-1);
+  at::Tensor fused_b = (p.linear_b - p.bn_rm) * bn_scale + p.bn_b;
+
+  auto linear_w_dtype = p.linear_w.dtype();
+  auto linear_b_dtype = p.linear_b.dtype();
+
+  return std::make_tuple(
+      fused_w.to(linear_w_dtype), fused_b.to(linear_b_dtype));
+}
+
+} // namespace jit
+} // namespace torch
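
The arithmetic above mirrors the eager-mode Linear/BatchNorm fusion in torch/nn/utils/fusion.py. A rough Python sketch of the same math (the helper name `fold_linear_bn_weights` is illustrative, not part of this PR):

```python
import torch
import torch.nn.functional as F

def fold_linear_bn_weights(linear_w, linear_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    # scale = gamma / sqrt(running_var + eps), one factor per output feature
    bn_scale = bn_w * torch.rsqrt(bn_rv + bn_eps)
    # W' = W * scale (scales each row of the linear weight)
    fused_w = linear_w * bn_scale.unsqueeze(-1)
    # b' = (b - running_mean) * scale + beta
    fused_b = (linear_b - bn_rm) * bn_scale + bn_b
    return fused_w.to(linear_w.dtype), fused_b.to(linear_b.dtype)

# Quick numerical check against the unfused linear -> batch_norm (eval) computation.
torch.manual_seed(0)
w, b = torch.randn(4, 8), torch.randn(4)
rm, rv = torch.randn(4), torch.rand(4) + 0.1
gamma, beta, eps = torch.randn(4), torch.randn(4), 1e-3
x = torch.randn(2, 8)

ref = (F.linear(x, w, b) - rm) * torch.rsqrt(rv + eps) * gamma + beta
fw, fb = fold_linear_bn_weights(w, b, rm, rv, eps, gamma, beta)
print(torch.allclose(ref, F.linear(x, fw, fb), atol=1e-5))  # expected: True
```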
torch/csrc/jit/passes/fold_linear_bn.h

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+#pragma once
+
+#include <torch/csrc/jit/api/module.h>
+
+namespace torch {
+namespace jit {
+
+struct TORCH_API LinearBNParameters {
+  at::Tensor linear_w;
+  at::Tensor linear_b;
+  at::Tensor bn_rm;
+  at::Tensor bn_rv;
+  double bn_eps = 0.0;
+  at::Tensor bn_w;
+  at::Tensor bn_b;
+};
+
+/**
+ * Given the current weight and bias tensors of a Linear module and parameters
+ * of the BatchNorm module we're folding with, compute the updated values
+ * for the weight and bias.
+ *
+ * The function is basically copied from torch/nn/utils/fusion.py
+ */
+TORCH_API std::tuple<at::Tensor, at::Tensor> computeUpdatedLinearWeightAndBias(
+    const LinearBNParameters& p);
+
+} // namespace jit
+} // namespace torch

torch/csrc/jit/passes/frozen_graph_optimizations.cpp

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,7 @@
 #include <torch/csrc/jit/passes/frozen_concat_linear.h>
 #include <torch/csrc/jit/passes/frozen_conv_folding.h>
 #include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
+#include <torch/csrc/jit/passes/frozen_linear_folding.h>
 #include <torch/csrc/jit/passes/remove_dropout.h>
 #include <torch/csrc/jit/runtime/graph_executor.h>
 #include <torch/csrc/utils/memory.h>
@@ -24,6 +25,7 @@ void OptimizeFrozenGraph(
       changed |= FoldFrozenConvBatchnorm(graph);
       changed |= FoldFrozenConvAddOrSub(graph);
      changed |= FoldFrozenConvMulOrDiv(graph);
+      changed |= FoldFrozenLinearBatchnorm(graph);
     } while (changed);
   }
 }

torch/csrc/jit/passes/frozen_graph_optimizations.h

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@
  * - FoldFrozenConvBatchnorm
  * - FoldFrozenConvAddOrSub
  * - FoldFrozenConvMulOrDiv
+ * - FoldFrozenLinearBatchnorm
  */
 
 namespace torch {
torch/csrc/jit/passes/frozen_linear_folding.cpp

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
+#include <torch/csrc/jit/ir/constants.h>
+#include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/passes/dead_code_elimination.h>
+#include <torch/csrc/jit/passes/fold_linear_bn.h>
+#include <torch/csrc/jit/passes/frozen_linear_folding.h>
+#include <torch/csrc/jit/passes/utils/optimization_utils.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#else
+#include <ATen/ops/ones_like.h>
+#include <ATen/ops/zeros_like.h>
+#endif
+
+namespace torch {
+namespace jit {
+
+namespace {
+
+using Tensor = at::Tensor;
+
+bool supportedLinearNode(Node* n) {
+  if (n->kind() == aten::linear) {
+    return true;
+  } else {
+    return false;
+  }
+}
+
+bool FoldFrozenLinearBatchnorm(Block* b) {
+  bool graph_modified = false;
+  for (Node* n : b->nodes()) {
+    for (Block* block : n->blocks()) {
+      graph_modified |= FoldFrozenLinearBatchnorm(block);
+    }
+
+    if (n->kind() == aten::batch_norm &&
+        supportedLinearNode(n->inputs().at(0)->node())) {
+      auto linear = n->inputs().at(0)->node();
+      auto bn = n;
+
+      if (nonConstantParameters(linear) || nonConstantParameters(bn)) {
+        continue;
+      }
+
+      auto bn_rm_ivalue = bn->namedInput("running_mean");
+      auto bn_rv_ivalue = bn->namedInput("running_var");
+
+      // Check that running_mean and running_var have a value; if they are
+      // None (track_running_stats=False), skip the folding path.
+      if (bn_rm_ivalue->type() == NoneType::get() &&
+          bn_rv_ivalue->type() == NoneType::get()) {
+        continue;
+      }
+
+      auto bn_rm = constant_as<Tensor>(bn->namedInput("running_mean")).value();
+      auto bn_rv = constant_as<Tensor>(bn->namedInput("running_var")).value();
+      auto bn_eps = constant_as<double>(bn->namedInput("eps")).value();
+      auto linear_w = constant_as<Tensor>(linear->namedInput("weight")).value();
+
+      // implementation taken from torch/nn/utils/fusion.py
+      Tensor linear_b;
+      if (linear->namedInput("bias")->type() == NoneType::get()) {
+        at::ScalarType bias_dtype = bn_rm.scalar_type();
+        at::ScalarType weight_dtype = linear_w.scalar_type();
+        at::DeviceType weight_device = linear_w.device().type();
+        if (weight_device == at::kCUDA &&
+            (weight_dtype == at::kHalf || weight_dtype == at::kBFloat16) &&
+            bias_dtype == at::kFloat) {
+          bias_dtype = weight_dtype;
+        }
+        linear_b = at::zeros_like(bn_rm, at::TensorOptions().dtype(bias_dtype));
+      } else {
+        linear_b = constant_as<Tensor>(linear->namedInput("bias")).value();
+      }
+      Tensor bn_w;
+      if (bn->namedInput("weight")->type() == NoneType::get()) {
+        bn_w = at::ones_like(bn_rm);
+      } else {
+        bn_w = constant_as<Tensor>(bn->namedInput("weight")).value();
+      }
+      Tensor bn_b;
+      if (n->namedInput("bias")->type() == NoneType::get()) {
+        bn_b = at::zeros_like(bn_rm);
+      } else {
+        bn_b = constant_as<Tensor>(bn->namedInput("bias")).value();
+      }
+
+      LinearBNParameters params;
+      params.linear_w = linear_w;
+      params.linear_b = linear_b;
+      params.bn_rm = bn_rm;
+      params.bn_rv = bn_rv;
+      params.bn_eps = bn_eps;
+      params.bn_w = bn_w;
+      params.bn_b = bn_b;
+      std::tuple<Tensor, Tensor> out =
+          computeUpdatedLinearWeightAndBias(params);
+      WithInsertPoint guard(linear);
+      auto fused_linear_w = b->owningGraph()->insertConstant(std::get<0>(out));
+      auto fused_linear_b = b->owningGraph()->insertConstant(std::get<1>(out));
+      auto linear_w_value = linear->namedInput("weight");
+      auto linear_b_value = linear->namedInput("bias");
+
+      fused_linear_w->setDebugName(linear_w_value->debugName() + "_fused_bn");
+      fused_linear_b->setDebugName(linear_b_value->debugName() + "_fused_bn");
+
+      linear->replaceInputWith(linear_w_value, fused_linear_w);
+      linear->replaceInputWith(linear_b_value, fused_linear_b);
+
+      bn->output()->replaceAllUsesWith(linear->output());
+      graph_modified = true;
+    }
+  }
+  return graph_modified;
+}
+
+} // namespace
+
+bool FoldFrozenLinearBatchnorm(std::shared_ptr<Graph>& graph) {
+  bool graph_modified = FoldFrozenLinearBatchnorm(graph->block());
+  EliminateDeadCode(graph);
+  return graph_modified;
+}
+
+} // namespace jit
+} // namespace torch
torch/csrc/jit/passes/frozen_linear_folding.h

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+#pragma once
+
+#include <torch/csrc/jit/ir/ir.h>
+
+namespace torch {
+namespace jit {
+
+// Fuses Linear -> BatchNormNd into a single Linear by
+// folding batchnorm weights into linear weights.
+// This pass only works on Frozen Graphs; otherwise it is a No-Op.
+TORCH_API bool FoldFrozenLinearBatchnorm(std::shared_ptr<Graph>& graph);
+
+} // namespace jit
+} // namespace torch

torch/csrc/jit/python/init.cpp

Lines changed: 2 additions & 0 deletions
@@ -40,6 +40,7 @@
 #include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
 #include <torch/csrc/jit/passes/frozen_conv_folding.h>
 #include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
+#include <torch/csrc/jit/passes/frozen_linear_folding.h>
 #include <torch/csrc/jit/passes/frozen_linear_transpose.h>
 #include <torch/csrc/jit/passes/frozen_ops_to_mkldnn.h>
 #include <torch/csrc/jit/passes/fuse_linear.h>
@@ -399,6 +400,7 @@ void initJITBindings(PyObject* module) {
       .def("_jit_pass_fold_frozen_conv_bn", &FoldFrozenConvBatchnorm)
       .def("_jit_pass_fold_frozen_conv_add_or_sub", &FoldFrozenConvAddOrSub)
       .def("_jit_pass_fold_frozen_conv_mul_or_div", &FoldFrozenConvMulOrDiv)
+      .def("_jit_pass_fold_frozen_linear_bn", &FoldFrozenLinearBatchnorm)
       .def("_jit_pass_convert_frozen_ops_to_mkldnn", &ConvertFrozenOpsToMKLDNN)
       .def("_jit_pass_fuse_frozen_conv_add_relu", &FuseFrozenConvAddRelu)
       .def("_jit_pass_transpose_frozen_linear", &FrozenLinearTranspose)
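
With the binding registered, the pass can also be driven directly from Python. A small sketch of the expected behavior, assuming the existing `_jit_pass_inline` binding: the pass no-ops on a non-frozen graph, where parameters are still module attributes, and folds once the module is frozen:

```python
import torch
import torch.nn as nn

mod = nn.Sequential(nn.Linear(16, 16), nn.BatchNorm1d(16)).eval()
scripted = torch.jit.script(mod)
torch._C._jit_pass_inline(scripted.graph)

# Before freezing, the linear/batch_norm parameters are non-constant GetAttr
# inputs, so the pass should leave aten::batch_norm in place.
torch._C._jit_pass_fold_frozen_linear_bn(scripted.graph)
print("aten::batch_norm" in str(scripted.graph))  # expected: True

# After freezing, parameters become constants and the fold can apply.
frozen = torch.jit.freeze(scripted)
torch._C._jit_pass_fold_frozen_linear_bn(frozen.graph)
print("aten::batch_norm" in str(frozen.graph))    # expected: False
```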
