
Commit 94964a9

jerryzh168 authored and facebook-github-bot committed
Add fusion for quantized linear (#25624)
Summary:
Pull Request resolved: #25624

First fuse the split op into aten::linear, then fuse `dequant - aten::linear - quant` into the quantized linear op.

Test Plan:
python test/test_jit.py 'TestJit.quant_fusion'

Imported from OSS

Differential Revision: D17208891

fbshipit-source-id: 864b19fabab2e8e6f8f8ad35eb3dbbf2d5fdb8c4
1 parent e9e7e9d commit 94964a9
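As a quick illustration of the flow described in the summary, here is a minimal sketch of how the new fusion can be exercised from Python, mirroring the updated test below. It relies on torch._C.parse_ir, torch._C._jit_pass_quant_fusion, and torch.testing.FileCheck, the same entry points the test uses; the IR literal is an abbreviated, hypothetical stand-in for the full pattern in test_quant_fusion.

import torch
from torch.testing import FileCheck

# Abbreviated dequant -> aten::linear -> quant pattern; the FileCheck
# directives assert that the pass rewrites it into the fbgemm quantized ops.
input_str = """
graph(%a_quant, %w_quant, %b_quant,
      %a_scale, %a_zero_point, %a_dtype,
      %w_scale, %w_zero_point, %w_dtype,
      %b_scale, %b_zero_point, %b_dtype,
      %r_scale, %r_zero_point, %r_dtype):
  %a_intrepr = aten::int_repr(%a_quant)
  %a_dequant = aten::_dequantize_linear(%a_intrepr, %a_scale, %a_zero_point, %a_dtype)
  %w_intrepr = aten::int_repr(%w_quant)
  %w_dequant = aten::_dequantize_linear(%w_intrepr, %w_scale, %w_zero_point, %w_dtype)
  %b_intrepr = aten::int_repr(%b_quant)
  %b_dequant = aten::_dequantize_linear(%b_intrepr, %b_scale, %b_zero_point, %b_dtype)
  # CHECK: quantized::fbgemm_linear_prepack
  # CHECK: quantized::fbgemm_linear
  # CHECK-NOT: aten::linear
  %r = aten::linear(%a_dequant, %w_dequant, %b_dequant)
  %r_quant = aten::quantize_linear(%r, %r_scale, %r_zero_point, %r_dtype)
  return (%r_quant)
"""

graph = torch._C.parse_ir(input_str)
# Runs FuseLinear first, then rewrites dequant - linear - quant into the
# fbgemm quantized linear ops.
torch._C._jit_pass_quant_fusion(graph)
# Passes if the fused ops appear and aten::linear is gone.
FileCheck().run(input_str, graph)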

File tree: 4 files changed (+101 / -34 lines)


test/test_jit.py

Lines changed: 35 additions & 7 deletions
@@ -1123,7 +1123,7 @@ def get_forward(m):
             .run(str(m._c._get_module('conv')._get_method('conv2d_forward').graph))
 
     def test_quant_fusion(self):
-        input_str = """
+        input_strs = ["""
 graph(%a, %w, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype,
 %b_scale, %b_zero_point, %b_dtype, %r_scale, %r_zero_point, %r_dtype, %c, %d, %e, %f):
         %a_quant = aten::quantize_linear(%a, %a_scale, %a_zero_point, %a_dtype)
@@ -1151,12 +1151,40 @@ def test_quant_fusion(self):
         %r_intrepr = aten::int_repr(%r_quant)
         # CHECK: aten::_dequantize_linear
         %r_dequant = aten::_dequantize_linear(%r_intrepr, %r_scale, %r_zero_point, %r_dtype)
-        return (%r_dequant)
-        )
-        """
-        graph = parse_ir(input_str)
-        torch._C._jit_pass_quant_fusion(graph)
-        FileCheck().run(input_str, graph)
+        return (%r_dequant)""",
+        """
+graph(%a, %w, %b, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype,
+%b_scale, %b_zero_point, %b_dtype, %r_scale, %r_zero_point, %r_dtype):
+        %a_quant = aten::quantize_linear(%a, %a_scale, %a_zero_point, %a_dtype)
+        # CHECK-NOT: aten::int_repr
+        %a_intrepr = aten::int_repr(%a_quant)
+        # CHECK-NOT: aten::_dequantize_linear
+        %a_dequant = aten::_dequantize_linear(%a_intrepr, %a_scale, %a_zero_point, %a_dtype)
+        %w_quant = aten::quantize_linear(%w, %w_scale, %w_zero_point, %w_dtype)
+        # CHECK-NOT: aten::int_repr
+        %w_intrepr = aten::int_repr(%w_quant)
+        # CHECK-NOT: aten::_dequantize_linear
+        %w_dequant = aten::_dequantize_linear(%w_intrepr, %w_scale, %w_zero_point, %w_dtype)
+        # CHECK-NOT: aten::int_repr
+        %b_quant = aten::quantize_linear(%b, %b_scale, %b_zero_point, %b_dtype)
+        %b_intrepr = aten::int_repr(%b_quant)
+        # CHECK-NOT: aten::_dequantize_linear
+        %b_dequant = aten::_dequantize_linear(%b_intrepr, %b_scale, %b_zero_point, %b_dtype)
+        # CHECK: quantized::fbgemm_linear_prepack
+        # CHECK: quantized::fbgemm_linear
+        # CHECK-NOT: aten::linear
+        %r = aten::linear(%a_dequant, %w_dequant, %b_dequant)
+        # CHECK-NOT: aten::quantize_linear
+        %r_quant = aten::quantize_linear(%r, %r_scale, %r_zero_point, %r_dtype)
+        # CHECK: aten::int_repr
+        %r_intrepr = aten::int_repr(%r_quant)
+        # CHECK: aten::_dequantize_linear
+        %r_dequant = aten::_dequantize_linear(%r_intrepr, %r_scale, %r_zero_point, %r_dtype)
+        return (%r_dequant)"""]
+        for input_str in input_strs:
+            graph = parse_ir(input_str)
+            torch._C._jit_pass_quant_fusion(graph)
+            FileCheck().run(input_str, graph)
 
     @_tmp_donotuse_dont_inline_everything
     def test_foldbn_trivial(self):

torch/csrc/jit/init.cpp

Lines changed: 1 addition & 3 deletions
@@ -167,9 +167,7 @@ void initJITBindings(PyObject* module) {
           "_jit_pass_quant_fusion",
           [](std::shared_ptr<Graph>& g) { return QuantFusion(g); })
       .def("_jit_pass_fold_convbn", &FoldConvBatchNorm2d)
-      .def(
-          "_jit_pass_fuse_linear",
-          [](std::shared_ptr<Graph>& g) { return FuseLinear(g); })
+      .def("_jit_pass_fuse_linear", &FuseLinear)
       .def(
           "_jit_pass_quantlint",
           [](std::shared_ptr<Graph>& g) { return QuantLinting(g); })
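Since FuseLinear already has the signature pybind11 expects (it takes a std::shared_ptr<Graph>&), the lambda wrapper is redundant and the pass can be bound directly. For context, a small illustrative sketch of calling the binding from Python; the IR literal is hypothetical:

import torch

# Any torch._C.Graph works, e.g. one parsed from IR or taken from a
# scripted module's forward method.
graph = torch._C.parse_ir("""
graph(%x, %w, %b):
  %r = aten::linear(%x, %w, %b)
  return (%r)
""")
# Per the commit summary, this pass fuses the split (decomposed) form of
# linear back into aten::linear when present; on this already-fused graph
# it is a no-op.
torch._C._jit_pass_fuse_linear(graph)
print(graph)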

torch/csrc/jit/passes/quantization.cpp

Lines changed: 56 additions & 24 deletions
@@ -1,6 +1,7 @@
 #include <torch/csrc/jit/passes/quantization.h>
-#include <torch/csrc/jit/passes/subgraph_rewrite.h>
 #include <torch/csrc/jit/passes/constant_propagation.h>
+#include <torch/csrc/jit/passes/fuse_linear.h>
+#include <torch/csrc/jit/passes/subgraph_rewrite.h>
 
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/irparser.h>
@@ -128,7 +129,7 @@ Node* insertObserver(
   std::string observer_name = "observer_for_" + v->debugName();
   // Temporary workaround to skip inserting duplicate modules,
   // full support will come in next PR
-  for (script::Slot s: module.get_module_slots()) {
+  for (script::Slot s : module.get_module_slots()) {
     if (s.name() == observer_name) {
       return nullptr;
     }
@@ -257,14 +258,20 @@ void InsertObserversImpl(
       TORCH_INTERNAL_ASSERT(
           child_module,
           "Child module " + child_module_name + " does not exist");
-      // Recursively insert observer for the forward function of child module
-      InsertObserversImpl(child_module.value(), module_method_name, module_qconfig_map, values_to_skip);
+      // Recursively insert observer for the forward function of child
+      // module
+      InsertObserversImpl(
+          child_module.value(),
+          module_method_name,
+          module_qconfig_map,
+          values_to_skip);
     } else {
       TORCH_INTERNAL_ASSERT(
           module_instance == graph->inputs()[0],
           "We only support call method either on %self"
           "or child instance in insert_observers_pass right now");
-      InsertObserversImpl(module, module_method_name, module_qconfig_map, values_to_skip);
+      InsertObserversImpl(
+          module, module_method_name, module_qconfig_map, values_to_skip);
     }
   }
 }
@@ -366,7 +373,8 @@ class QuantizeHelper {
  public:
   QuantizeHelper(const script::Module& m) : module_(m) {}
   IValue getQParams(Value* v);
-  c10::optional<script::Module> findChildModuleToQuantize(Value* child_instance);
+  c10::optional<script::Module> findChildModuleToQuantize(
+      Value* child_instance);
   void quantizeBias(Value* v);
   void quantizeTensor(Value* v, bool insert_after = true);
   void removeObserver(Value* v, const std::string& observer_name);
@@ -489,7 +497,7 @@ c10::optional<script::Module> QuantizeHelper::findChildModuleToQuantize(
     TORCH_INTERNAL_ASSERT(
         child_module,
         "InsertQuantDeQuant - Child module " + child_module_name +
-        " does not exist");
+            " does not exist");
     return child_module;
   }
   return c10::nullopt;
@@ -558,22 +566,16 @@ void InsertQuantDeQuantImpl(
 
   qh.destroyNodes();
 }
-
 } // namespace
 
 TORCH_API void InsertObservers(
     script::Module& module,
     const std::string& method_name,
     const QConfigDict& qconfig_dict) {
   ModuleQConfigMap module_qconfig_map;
-  fillQConfigMap(module,
-                 qconfig_dict,
-                 module_qconfig_map);
+  fillQConfigMap(module, qconfig_dict, module_qconfig_map);
   std::unordered_set<Value*> values_to_skip;
-  InsertObserversImpl(module,
-                      method_name,
-                      module_qconfig_map,
-                      values_to_skip);
+  InsertObserversImpl(module, method_name, module_qconfig_map, values_to_skip);
 }
 
 script::Module InsertQuantDeQuant(
@@ -602,7 +604,11 @@ void FoldQuantNodesIntoInputsOutputs(std::shared_ptr<Graph>& graph) {
 }
 
 void QuantFusion(std::shared_ptr<Graph>& graph) {
-  std::string pattern = R"(
+  // First fuse aten::linear op
+  FuseLinear(graph);
+  const std::unordered_map<std::string, std::string> pattern_and_replacements =
+      {// quantized::conv2d
+       {R"(
 graph(%a_quant, %w_quant, %b_quant, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %b_scale, %b_zero_point, %b_dtype, %r_scale, %r_zero_point, %r_dtype, %c, %d, %e, %f):
 %a_intrepr = aten::int_repr(%a_quant)
 %a_dequant = aten::_dequantize_linear(%a_intrepr, %a_scale, %a_zero_point, %a_dtype)
@@ -612,9 +618,8 @@ graph(%a_quant, %w_quant, %b_quant, %a_scale, %a_zero_point, %a_dtype, %w_scale,
 %b_dequant = aten::_dequantize_linear(%b_intrepr, %b_scale, %b_zero_point, %b_dtype)
 %r = aten::conv2d(%a_dequant, %w_dequant, %b_dequant, %c, %d, %e, %f)
 %r_quant = aten::quantize_linear(%r, %r_scale, %r_zero_point, %r_dtype)
-return (%r_quant))";
-
-  std::string replacement = R"(
+return (%r_quant))",
+       R"(
 graph(%a_quant, %w_quant, %b_quant, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %b_scale, %b_zero_point, %b_dtype, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
 %0 : int = prim::Constant[value=0]()
 %1 : int = prim::Constant[value=1]()
@@ -627,10 +632,38 @@ graph(%a_quant, %w_quant, %b_quant, %a_scale, %a_zero_point, %a_dtype, %w_scale,
 %r = quantized::conv2d(%a_perm, %w_packed, %b_quant, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point)
 %out_param : int[] = prim::ListConstruct(%0, %3, %1, %2)
 %r_perm = aten::permute(%r, %out_param)
-return (%r_perm))";
-  SubgraphRewriter rewriter;
-  rewriter.RegisterRewritePattern(pattern, replacement);
-  rewriter.runOnGraph(graph);
+return (%r_perm))"},
+       // quantized::linear
+       {R"(
+graph(%a_quant, %w_quant, %b_quant, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %b_scale, %b_zero_point, %b_dtype, %r_scale, %r_zero_point, %r_dtype):
+%a_intrepr = aten::int_repr(%a_quant)
+%a_dequant = aten::_dequantize_linear(%a_intrepr, %a_scale, %a_zero_point, %a_dtype)
+%w_intrepr = aten::int_repr(%w_quant)
+%w_dequant = aten::_dequantize_linear(%w_intrepr, %w_scale, %w_zero_point, %w_dtype)
+%b_intrepr = aten::int_repr(%b_quant)
+%b_dequant = aten::_dequantize_linear(%b_intrepr, %b_scale, %b_zero_point, %b_dtype)
+%r = aten::linear(%a_dequant, %w_dequant, %b_dequant)
+%r_quant = aten::quantize_linear(%r, %r_scale, %r_zero_point, %r_dtype)
+return (%r_quant))",
+       R"(
+graph(%a_quant, %w_quant, %b_quant, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %b_scale, %b_zero_point, %b_dtype, %r_scale, %r_zero_point, %r_dtype):
+%0 : int = prim::Constant[value=0]()
+%1 : int = prim::Constant[value=1]()
+%2 : int = prim::Constant[value=2]()
+%3 : int = prim::Constant[value=3]()
+%in_param : int[] = prim::ListConstruct(%0, %2, %3, %1)
+%a_perm : Tensor = aten::permute(%a_quant, %in_param)
+%w_perm : Tensor = aten::permute(%w_quant, %in_param)
+%w_packed = quantized::fbgemm_linear_prepack(%w_perm)
+%r = quantized::fbgemm_linear(%a_perm, %w_packed, %b_quant, %r_scale, %r_zero_point)
+%out_param : int[] = prim::ListConstruct(%0, %3, %1, %2)
+%r_perm = aten::permute(%r, %out_param)
+return (%r_perm))"}};
+  for (const auto& item : pattern_and_replacements) {
+    SubgraphRewriter rewriter;
+    rewriter.RegisterRewritePattern(item.first, item.second);
+    rewriter.runOnGraph(graph);
+  }
 }
 
 struct ConvBNParameters {
@@ -792,6 +825,5 @@ graph(%self, %x):
 }
 }
 }
-
 } // namespace jit
 } // namespace torch

torch/csrc/jit/passes/quantization.h

Lines changed: 9 additions & 0 deletions
@@ -84,6 +84,15 @@ TORCH_API script::Module InsertQuantDeQuant(
  * Right now this is a fusion for fbgemm backend and only works for quantized
  * conv op, we'll extend to more ops and more backends in the future.
  *
+ * Currently supported fusion:
+ * q(conv2d(dq(a), dq(w), dq(b))) --> to_nchw(fbgemm_conv2d(prepack(to_nhwc(a)),
+ *                                                          prepack(to_nhwc(w)),
+ *                                                          prepack(to_nhwc(b))))
+ *
+ * q(linear(dq(a), dq(w), dq(b))) --> to_nchw(fbgemm_linear(prepack(to_nhwc(a)),
+ *                                                          prepack(to_nhwc(w)),
+ *                                                          prepack(to_nhwc(b))))
+ *
  * \param graph the graph we want to apply fusion
  */
 TORCH_API void QuantFusion(std::shared_ptr<Graph>& graph);
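The to_nhwc / to_nchw steps in this comment correspond to the permute index lists built in the rewrite patterns above, prim::ListConstruct(%0, %2, %3, %1) and prim::ListConstruct(%0, %3, %1, %2). A tiny illustrative sketch of that layout round trip on an ordinary tensor (not part of the pass itself):

import torch

x_nchw = torch.randn(1, 3, 8, 8)     # N, C, H, W
x_nhwc = x_nchw.permute(0, 2, 3, 1)  # to_nhwc: the (%0, %2, %3, %1) index list
x_back = x_nhwc.permute(0, 3, 1, 2)  # to_nchw: the (%0, %3, %1, %2) index list
assert torch.equal(x_nchw, x_back)   # the two permutes are inverses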
