@@ -1,6 +1,7 @@
 #include <torch/csrc/jit/passes/quantization.h>
-#include <torch/csrc/jit/passes/subgraph_rewrite.h>
 #include <torch/csrc/jit/passes/constant_propagation.h>
+#include <torch/csrc/jit/passes/fuse_linear.h>
+#include <torch/csrc/jit/passes/subgraph_rewrite.h>
 
 #include <torch/csrc/jit/ir.h>
 #include <torch/csrc/jit/irparser.h>
@@ -128,7 +129,7 @@ Node* insertObserver(
   std::string observer_name = "observer_for_" + v->debugName();
   // Temporary workaround to skip inserting duplicate modules,
   // full support will come in next PR
-  for (script::Slot s: module.get_module_slots()) {
+  for (script::Slot s : module.get_module_slots()) {
     if (s.name() == observer_name) {
       return nullptr;
     }
@@ -257,14 +258,20 @@ void InsertObserversImpl(
       TORCH_INTERNAL_ASSERT(
           child_module,
           "Child module " + child_module_name + " does not exist");
-      // Recursively insert observer for the forward function of child module
-      InsertObserversImpl(child_module.value(), module_method_name, module_qconfig_map, values_to_skip);
+      // Recursively insert observer for the forward function of child
+      // module
+      InsertObserversImpl(
+          child_module.value(),
+          module_method_name,
+          module_qconfig_map,
+          values_to_skip);
     } else {
       TORCH_INTERNAL_ASSERT(
           module_instance == graph->inputs()[0],
           "We only support call method either on %self"
           " or child instance in insert_observers_pass right now");
-      InsertObserversImpl(module, module_method_name, module_qconfig_map, values_to_skip);
+      InsertObserversImpl(
+          module, module_method_name, module_qconfig_map, values_to_skip);
     }
   }
 }
@@ -366,7 +373,8 @@ class QuantizeHelper {
  public:
   QuantizeHelper(const script::Module& m) : module_(m) {}
   IValue getQParams(Value* v);
-  c10::optional<script::Module> findChildModuleToQuantize(Value* child_instance);
+  c10::optional<script::Module> findChildModuleToQuantize(
+      Value* child_instance);
   void quantizeBias(Value* v);
   void quantizeTensor(Value* v, bool insert_after = true);
   void removeObserver(Value* v, const std::string& observer_name);
@@ -489,7 +497,7 @@ c10::optional<script::Module> QuantizeHelper::findChildModuleToQuantize(
     TORCH_INTERNAL_ASSERT(
         child_module,
         "InsertQuantDeQuant - Child module " + child_module_name +
-        " does not exist");
+            " does not exist");
     return child_module;
   }
   return c10::nullopt;
@@ -558,22 +566,16 @@ void InsertQuantDeQuantImpl(
 
   qh.destroyNodes();
 }
-
 } // namespace
 
 TORCH_API void InsertObservers(
     script::Module& module,
     const std::string& method_name,
     const QConfigDict& qconfig_dict) {
   ModuleQConfigMap module_qconfig_map;
-  fillQConfigMap(module,
-                 qconfig_dict,
-                 module_qconfig_map);
+  fillQConfigMap(module, qconfig_dict, module_qconfig_map);
   std::unordered_set<Value*> values_to_skip;
-  InsertObserversImpl(module,
-                      method_name,
-                      module_qconfig_map,
-                      values_to_skip);
+  InsertObserversImpl(module, method_name, module_qconfig_map, values_to_skip);
 }
 
 script::Module InsertQuantDeQuant(
@@ -602,7 +604,11 @@ void FoldQuantNodesIntoInputsOutputs(std::shared_ptr<Graph>& graph) {
 }
 
 void QuantFusion(std::shared_ptr<Graph>& graph) {
-  std::string pattern = R"(
+  // First fuse aten::linear op
+  FuseLinear(graph);
+  const std::unordered_map<std::string, std::string> pattern_and_replacements =
+      {// quantized::conv2d
+       {R"(
 graph(%a_quant, %w_quant, %b_quant, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %b_scale, %b_zero_point, %b_dtype, %r_scale, %r_zero_point, %r_dtype, %c, %d, %e, %f):
     %a_intrepr = aten::int_repr(%a_quant)
     %a_dequant = aten::_dequantize_linear(%a_intrepr, %a_scale, %a_zero_point, %a_dtype)
@@ -612,9 +618,8 @@ graph(%a_quant, %w_quant, %b_quant, %a_scale, %a_zero_point, %a_dtype, %w_scale,
     %b_dequant = aten::_dequantize_linear(%b_intrepr, %b_scale, %b_zero_point, %b_dtype)
     %r = aten::conv2d(%a_dequant, %w_dequant, %b_dequant, %c, %d, %e, %f)
     %r_quant = aten::quantize_linear(%r, %r_scale, %r_zero_point, %r_dtype)
-    return (%r_quant))";
-
-  std::string replacement = R"(
+    return (%r_quant))",
+        R"(
 graph(%a_quant, %w_quant, %b_quant, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %b_scale, %b_zero_point, %b_dtype, %r_scale, %r_zero_point, %r_dtype, %stride, %padding, %dilation, %groups):
     %0 : int = prim::Constant[value=0]()
     %1 : int = prim::Constant[value=1]()
@@ -627,10 +632,38 @@ graph(%a_quant, %w_quant, %b_quant, %a_scale, %a_zero_point, %a_dtype, %w_scale,
     %r = quantized::conv2d(%a_perm, %w_packed, %b_quant, %stride, %padding, %dilation, %groups, %r_scale, %r_zero_point)
     %out_param : int[] = prim::ListConstruct(%0, %3, %1, %2)
     %r_perm = aten::permute(%r, %out_param)
-    return (%r_perm))";
-  SubgraphRewriter rewriter;
-  rewriter.RegisterRewritePattern(pattern, replacement);
-  rewriter.runOnGraph(graph);
+    return (%r_perm))"},
+       // quantized::linear
+       {R"(
+graph(%a_quant, %w_quant, %b_quant, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %b_scale, %b_zero_point, %b_dtype, %r_scale, %r_zero_point, %r_dtype):
+    %a_intrepr = aten::int_repr(%a_quant)
+    %a_dequant = aten::_dequantize_linear(%a_intrepr, %a_scale, %a_zero_point, %a_dtype)
+    %w_intrepr = aten::int_repr(%w_quant)
+    %w_dequant = aten::_dequantize_linear(%w_intrepr, %w_scale, %w_zero_point, %w_dtype)
+    %b_intrepr = aten::int_repr(%b_quant)
+    %b_dequant = aten::_dequantize_linear(%b_intrepr, %b_scale, %b_zero_point, %b_dtype)
+    %r = aten::linear(%a_dequant, %w_dequant, %b_dequant)
+    %r_quant = aten::quantize_linear(%r, %r_scale, %r_zero_point, %r_dtype)
+    return (%r_quant))",
+        R"(
+graph(%a_quant, %w_quant, %b_quant, %a_scale, %a_zero_point, %a_dtype, %w_scale, %w_zero_point, %w_dtype, %b_scale, %b_zero_point, %b_dtype, %r_scale, %r_zero_point, %r_dtype):
+    %0 : int = prim::Constant[value=0]()
+    %1 : int = prim::Constant[value=1]()
+    %2 : int = prim::Constant[value=2]()
+    %3 : int = prim::Constant[value=3]()
+    %in_param : int[] = prim::ListConstruct(%0, %2, %3, %1)
+    %a_perm : Tensor = aten::permute(%a_quant, %in_param)
+    %w_perm : Tensor = aten::permute(%w_quant, %in_param)
+    %w_packed = quantized::fbgemm_linear_prepack(%w_perm)
+    %r = quantized::fbgemm_linear(%a_perm, %w_packed, %b_quant, %r_scale, %r_zero_point)
+    %out_param : int[] = prim::ListConstruct(%0, %3, %1, %2)
+    %r_perm = aten::permute(%r, %out_param)
+    return (%r_perm))"}};
+  for (const auto& item : pattern_and_replacements) {
+    SubgraphRewriter rewriter;
+    rewriter.RegisterRewritePattern(item.first, item.second);
+    rewriter.runOnGraph(graph);
+  }
 }
 
 struct ConvBNParameters {
@@ -792,6 +825,5 @@ graph(%self, %x):
     }
   }
 }
-
 } // namespace jit
 } // namespace torch
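
Note on the rewrite loop added in QuantFusion above: each pattern/replacement pair is handed to torch::jit::SubgraphRewriter, which matches the pattern sub-graph anywhere in the input graph and splices in the replacement, mapping graph inputs and outputs by position. A minimal self-contained sketch of the same API on a toy rewrite (the relu-of-relu pattern below is illustrative only, not part of this PR):

#include <torch/csrc/jit/passes/subgraph_rewrite.h>

namespace torch {
namespace jit {

// Illustrative sketch: collapse relu(relu(x)) into relu(x) using the
// same SubgraphRewriter calls that QuantFusion uses above.
void FuseDoubleReluSketch(std::shared_ptr<Graph>& graph) {
  std::string pattern = R"(
graph(%x):
    %y = aten::relu(%x)
    %z = aten::relu(%y)
    return (%z))";
  std::string replacement = R"(
graph(%x):
    %y = aten::relu(%x)
    return (%y))";
  SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(pattern, replacement);
  rewriter.runOnGraph(graph);
}

} // namespace jit
} // namespace torch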