
Commit e440c37

jerryzh168 authored and facebook-github-bot committed
[quant] Fix fuse linear pass (#40549)
Summary:
Pull Request resolved: #40549

Previously we did not check whether %weight_t is produced by `aten::t`, so the pass could fuse `matmul`/`addmm` patterns whose weight is not a transposed 2d tensor into `aten::linear`, which is incorrect.

Test Plan: Imported from OSS

Differential Revision: D22225921

fbshipit-source-id: 9723e82fdbac6d8e1a7ade22f3a9791321ab12b6
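To illustrate the case this commit guards against, here is a minimal sketch adapted from the new test below (assuming a PyTorch build that includes this fix): a 3d `matmul` whose weight is not produced by `aten::t` must stay an `aten::matmul` after the fuse pass runs.

```python
import torch

# Adapted from the test added in this commit: a matmul whose "weight" is a 3d
# tensor and is not produced by aten::t must not be rewritten to aten::linear,
# since aten::linear requires a 2d weight.
class Matmul(torch.nn.Module):
    def __init__(self, weight):
        super(Matmul, self).__init__()
        self.weight = weight

    def forward(self, x):
        return torch.matmul(x, self.weight)

x = torch.rand(5, 6, 5)
w = torch.rand(5, 5, 100)
model = torch.jit.trace(Matmul(w), [x])
torch._C._jit_pass_fuse_linear(model.graph)
print(model.graph)  # with the fix: aten::matmul is still present, no aten::linear
model(x)            # the rewritten graph still runs
```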
1 parent eae1ed9 commit e440c37

File tree

4 files changed: +46, -7 lines

test/quantization/test_quantize_jit.py
torch/csrc/jit/passes/fuse_linear.cpp
torch/csrc/jit/passes/graph_rewrite_helper.cpp
torch/csrc/jit/passes/quantization/insert_observers.cpp

test/quantization/test_quantize_jit.py

Lines changed: 29 additions & 2 deletions
@@ -309,7 +309,11 @@ def forward(self, x):
         x2 = torch.rand(5, 5)
         w2 = torch.rand(5, 5)
         b2 = torch.rand(5)
-        for has_bias, (x, weight, b) in itertools.product([True, False], [(x1, w1, b1), (x2, w2, b2)]):
+
+        x3 = torch.rand(5, 5, 5)
+        w3 = torch.rand(5, 5)
+        b3 = torch.rand(5)
+        for has_bias, (x, weight, b) in itertools.product([True, False], [(x1, w1, b1), (x2, w2, b2), (x3, w3, b3)]):
             bias = b if has_bias else None
             model = torch.jit.trace(FunctionalLinear(weight, bias), [x])
             torch._C._jit_pass_fuse_linear(model.graph)
@@ -319,6 +323,29 @@ def forward(self, x):
             for cn in check_not:
                 FileCheck().check_not(cn) \
                            .run(model.graph)
+            # make sure it runs
+            model(x)
+
+        # check matmuls are not fused
+        class Matmul(torch.nn.Module):
+            def __init__(self, weight):
+                super(Matmul, self).__init__()
+                self.weight = weight
+
+            def forward(self, x):
+                return torch.matmul(x, self.weight)
+
+        x = torch.rand(5, 6, 5)
+        w = torch.rand(5, 5, 100)
+        model = torch.jit.trace(Matmul(w), [x])
+        torch._C._jit_pass_fuse_linear(model.graph)
+        # check 3d matmul is not fused
+        FileCheck().check("aten::matmul") \
+                   .run(model.graph)
+        FileCheck().check_not("aten::linear") \
+                   .run(model.graph)
+        # make sure it runs
+        model(x)

     def test_insert_observers(self):
         class M(torch.nn.Module):
@@ -2672,7 +2699,7 @@ class TestQuantizeDynamicJitOps(QuantizationTestCase):
     for individual ops end to end.
     """
     @override_qengines
-    def test_quantized_linear_dynamic(self):
+    def test_linear(self):
         class FunctionalLinear(torch.nn.Module):
             def __init__(self, weight, bias):
                 super(FunctionalLinear, self).__init__()

torch/csrc/jit/passes/fuse_linear.cpp

Lines changed: 15 additions & 3 deletions
@@ -21,10 +21,21 @@ void FuseLinear(std::shared_ptr<Graph>& graph) {
     return is_int_constant(match, vmap, "beta", 1);
   };

+  // check %weight_t is produced by `aten::t` to make sure
+  // we can transform the pattern to `aten::linear`
+  auto weight_transposed =
+      [](const Match& match,
+         const std::unordered_map<std::string, Value*>& vmap) {
+        const auto& match_vmap = match.values_map;
+        auto v = match_vmap.at(vmap.at("weight_t"));
+        return v->node()->kind() == Symbol::aten("t");
+      };
+
   // replace addmm pattern to linear
   SubgraphRewriter addmm_to_linear;
   addmm_to_linear.RegisterRewritePattern(addmm_pattern, fused_linear_addmm);
-  addmm_to_linear.runOnGraph(graph, {aten_add_alpha_is_one, beta_is_one});
+  addmm_to_linear.runOnGraph(
+      graph, {aten_add_alpha_is_one, beta_is_one, weight_transposed});

   std::string matmul_add_pattern = R"IR(
     graph(%input, %weight_t, %bias, %alpha):
@@ -40,7 +51,8 @@ void FuseLinear(std::shared_ptr<Graph>& graph) {
   SubgraphRewriter matmuladd_to_linear;
   matmuladd_to_linear.RegisterRewritePattern(
       matmul_add_pattern, fused_linear_matmul);
-  matmuladd_to_linear.runOnGraph(graph, aten_add_alpha_is_one);
+  matmuladd_to_linear.runOnGraph(
+      graph, {aten_add_alpha_is_one, weight_transposed});

   std::string matmul_pattern = R"IR(
     graph(%input, %weight_t):
@@ -57,7 +69,7 @@ void FuseLinear(std::shared_ptr<Graph>& graph) {
   SubgraphRewriter matmul_to_linear;
   matmul_to_linear.RegisterRewritePattern(
       matmul_pattern, fused_linear_bias_none);
-  matmul_to_linear.runOnGraph(graph);
+  matmul_to_linear.runOnGraph(graph, weight_transposed);

   // clean up extra transpose for the weight of aten::linear
   std::string linear_weight_extra_transpose = R"IR(
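For contrast, a minimal sketch of the positive case (not part of this commit; `DecomposedLinearNoBias` is a hypothetical module): when the weight does flow through `aten::t`, the new `weight_transposed` filter accepts the match and the bias-free `matmul` pattern above can still be rewritten to `aten::linear`.

```python
import torch

# Hypothetical module, not from this PR: the weight is transposed via aten::t
# in the traced graph, so the weight_transposed filter should accept the match
# and the matmul pattern can still be fused into aten::linear.
class DecomposedLinearNoBias(torch.nn.Module):
    def __init__(self, weight):
        super(DecomposedLinearNoBias, self).__init__()
        self.weight = weight

    def forward(self, x):
        return torch.matmul(x, self.weight.t())

x = torch.rand(5, 5)
model = torch.jit.trace(DecomposedLinearNoBias(torch.rand(5, 5)), [x])
torch._C._jit_pass_fuse_linear(model.graph)
print(model.graph)  # expected: aten::linear (with None bias) instead of aten::t + aten::matmul
```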

torch/csrc/jit/passes/graph_rewrite_helper.cpp

Lines changed: 1 addition & 0 deletions
@@ -61,6 +61,7 @@ std::unordered_map<std::string, c10::IValue> getConvParams(
 }

 void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph) {
+  // TODO: remove constant prop in the pass
   ConstantPropagation(graph);
   std::string convolution = R"(
       graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],

torch/csrc/jit/passes/quantization/insert_observers.cpp

Lines changed: 1 addition & 2 deletions
@@ -1107,10 +1107,9 @@ void InsertObserversHelper::preprocess(

   Method method = module.get_method(method_name);
   auto graph = method.graph();
-  // must do constant propagation first before replacement
-  replaceConvolutionWithAtenConv(graph);
   // fuse decomposed linear into aten::linear
   FuseLinear(graph);
+  replaceConvolutionWithAtenConv(graph);
 }

 void InsertObserversHelper::analyze(
