@@ -154,8 +154,40 @@ bool valueObservedInAnotherMethod(
154154 return false ;
155155}
156156
157- static bool outputsNeedToBeObserved (Node* n) {
158- return n->kind () != prim::Constant;
157+ bool nodeQuantizable (Node* n) {
158+ static std::vector<Symbol> aten_funcs = {
159+ Symbol::aten (" conv2d" ),
160+ Symbol::aten (" linear" )
161+ };
162+ bool is_quantizable = std::find (aten_funcs.begin (), aten_funcs.end (), n->kind ()) != aten_funcs.end ();
163+ static std::vector<std::string> call_funcs = {
164+ " linear" ,
165+ " relu"
166+ };
167+ if (n->kind () == prim::CallFunction) {
168+ auto func_node = n->inputs ()[0 ]->node ();
169+ if (func_node->kind () == prim::Constant) {
170+ is_quantizable |= std::find (call_funcs.begin (), call_funcs.end (), func_node->s (attr::name)) != call_funcs.end ();
171+ }
172+ }
173+ return is_quantizable;
174+ }
175+
176+ bool valueNeedsToBeObserved (Value* v) {
177+ if (!v->type ()->isSubtypeOf (TensorType::get ())) {
178+ return false ;
179+ }
180+ // Check whether producer is quantizable
181+ if (nodeQuantizable (v->node ())) {
182+ return true ;
183+ }
184+ // Check whether user is quantizable
185+ for (const auto & use: v->uses ()) {
186+ if (nodeQuantizable (use.user )) {
187+ return true ;
188+ }
189+ }
190+ return false ;
159191}
160192
161193Node* traverseToQuantNode (Node* dq) {
@@ -301,15 +333,18 @@ void InsertObserversImpl(
301333 // point is the beginning of graph node. This also safe guards against
302334 // observing a potentially mutated value due to some in-place operation
303335 Value* self = graph->inputs ()[0 ];
336+ std::unordered_set<Value*> values_observed;
304337 for (size_t idx = 1 ; idx < method.num_inputs (); ++idx) {
305338 auto & v = graph->inputs ()[idx];
306- if (v-> type ()-> isSubtypeOf ( TensorType::get () ) &&
339+ if (valueNeedsToBeObserved (v ) &&
307340 values_to_skip.count (v) == 0 &&
341+ values_observed.count (v) == 0 &&
308342 !valueObservedInAnotherMethod (v, self, module , child_module_set)) {
309343 if (module_qconfig_map.count (module .module_object ()) == 0 ) {
310344 // the module is added by us, it's an observer module
311345 continue ;
312346 }
347+ values_observed.emplace (v);
313348 auto qconfig = module_qconfig_map.at (module .module_object ());
314349 if (qconfig) {
315350 auto observer_node =
@@ -328,16 +363,19 @@ void InsertObserversImpl(
328363 for (Node* n : b->nodes ()) {
329364 // Skip nodes that we don't need to observe, e.g. 'prim::Constant' or
330365 // observer nodes
331- if (! outputsNeedToBeObserved (n) || observer_for_input.count (n) != 0 ) {
366+ if (observer_for_input.count (n) != 0 ) {
332367 continue ;
333368 }
334369
335370 // Record all outputs in the values_to_observe - we'll later add observers
336371 // for all values from it.
337372 for (Value* v : n->outputs ()) {
338- if (values_to_skip.count (v) == 0 &&
373+ if (valueNeedsToBeObserved (v) &&
374+ values_to_skip.count (v) == 0 &&
375+ values_observed.count (v) == 0 &&
339376 !valueObservedInAnotherMethod (v, self, module , child_module_set)) {
340377 values_to_observe.push_back (v);
378+ values_observed.emplace (v);
341379 }
342380 if (v->node ()->kind () == prim::CallMethod) {
343381 // If we find a call to a method of a child module,
@@ -381,9 +419,6 @@ void InsertObserversImpl(
381419
382420 // Actually add observer nodes.
383421 for (Value* v : values_to_observe) {
384- if (!v->type ()->isSubtypeOf (TensorType::get ())) {
385- continue ;
386- }
387422 // Skip inserting observer for bias
388423 if (v->node ()->kind () == prim::GetAttr &&
389424 v->node ()->s (c10::attr::name) == " bias" ) {
0 commit comments