 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/subgraph_matcher.h>

+#include <algorithm>
 #include <stack>

 namespace torch {
@@ -64,8 +65,59 @@ graph(%self, %input):
   }
 }

-static bool outputsNeedToBeObserved(Node* n) {
-  return n->kind() != prim::Constant;
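+// Extracts the unqualified function name from a qualified name,
+// i.e. the substring after the last '.', e.g. "foo.bar.relu" -> "relu".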
+std::string getFuncName(const c10::QualifiedName& qname) {
+  const auto& name = qname.qualifiedName();
+  auto rdot_idx = name.rfind('.');
+  if (rdot_idx != std::string::npos) {
+    return name.substr(rdot_idx + 1, name.length());
+  } else {
+    return name;
+  }
+}
+
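+// A node is quantizable if it is one of the supported aten ops
+// (conv2d/linear/relu), or a prim::CallFunction whose constant callee
+// resolves to one of those names.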
+bool nodeQuantizable(Node* n) {
+  static std::vector<std::string> call_funcs = {
+      "conv2d",
+      "linear",
+      "relu",
+  };
+  std::vector<Symbol> aten_funcs;
+  std::transform(
+      call_funcs.begin(),
+      call_funcs.end(),
+      std::back_inserter(aten_funcs),
+      [](const std::string& s) { return Symbol::aten(s); });
+  bool is_quantizable =
+      std::find(aten_funcs.begin(), aten_funcs.end(), n->kind()) !=
+      aten_funcs.end();
+  if (n->kind() == prim::CallFunction) {
+    auto func_node = n->inputs()[0]->node();
+    auto func = func_node->output()->type()->expect<FunctionType>()->function();
+    auto func_name = getFuncName(func->qualname());
+    if (func_node->kind() == prim::Constant) {
+      is_quantizable |=
+          std::find(call_funcs.begin(), call_funcs.end(), func_name) !=
+          call_funcs.end();
+    }
+  }
+  return is_quantizable;
+}
+
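+// A Tensor value needs an observer if either its producing node or
+// one of its users is quantizable.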
+bool valueNeedsToBeQuantized(Value* v) {
+  if (!v->type()->isSubtypeOf(TensorType::get())) {
+    return false;
+  }
+  // Check whether producer is quantizable
+  if (nodeQuantizable(v->node())) {
+    return true;
+  }
+  // Check whether user is quantizable
+  for (const auto& use : v->uses()) {
+    if (nodeQuantizable(use.user)) {
+      return true;
+    }
+  }
+  return false;
 }

 Node* traverseToQuantNode(Node* dq) {
@@ -216,8 +268,7 @@ void InsertObserversImpl(
   Value* self = graph->inputs()[0];
   for (size_t idx = 1; idx < method.num_inputs(); ++idx) {
     auto& v = graph->inputs()[idx];
-    if (v->type()->isSubtypeOf(TensorType::get()) &&
-        values_to_skip.count(v) == 0) {
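+    // Only observe Tensor inputs that participate in quantizable ops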
+    if (!values_to_skip.count(v) && valueNeedsToBeQuantized(v)) {
       auto qconfig = module_qconfig_map.at(module.module_object());
       if (qconfig) {
         auto observer_node =
@@ -234,16 +285,15 @@ void InsertObserversImpl(
     Block* b = blocks_to_visit.top();
     blocks_to_visit.pop();
     for (Node* n : b->nodes()) {
-      // Skip nodes that we don't need to observe, e.g. 'prim::Constant' or
-      // observer nodes
-      if (!outputsNeedToBeObserved(n) || observer_for_input.count(n) != 0) {
+      // Skip observer nodes
+      if (observer_for_input.count(n) != 0) {
         continue;
       }

       // Record all outputs in the values_to_observe - we'll later add observers
       // for all values from it.
       for (Value* v : n->outputs()) {
-        if (values_to_skip.count(v) == 0) {
+        if (!values_to_skip.count(v) && valueNeedsToBeQuantized(v)) {
           values_to_observe.push_back(v);
         }
         if (v->node()->kind() == prim::CallMethod) {
@@ -284,9 +334,6 @@ void InsertObserversImpl(

   // Actually add observer nodes.
   for (Value* v : values_to_observe) {
-    if (!v->type()->isSubtypeOf(TensorType::get())) {
-      continue;
-    }
     // Skip inserting observer for bias
     if (v->node()->kind() == prim::GetAttr &&
         v->node()->s(c10::attr::name) == "bias") {
@@ -843,7 +890,6 @@ graph(%self, %scale, %zero_point, %dtype):
   auto matches = findPatternMatches(pattern_graph, *graph);
   for (const auto& match : matches) {
     auto match_vmap = match.values_map;
-    auto* weight = match_vmap.at(vmap.at("weight"));
     auto float_weight = module.get_parameter("weight").variable_data();
     auto scale = toIValue(match_vmap.at(vmap.at("scale"))).value().toDouble();
     auto zero_point =