@@ -64,8 +64,40 @@ graph(%self, %input):
6464 }
6565}
6666
67- static bool outputsNeedToBeObserved (Node* n) {
68- return n->kind () != prim::Constant;
67+ bool nodeQuantizable (Node* n) {
68+ static std::vector<Symbol> aten_funcs = {
69+ Symbol::aten (" conv2d" ),
70+ Symbol::aten (" linear" )
71+ };
72+ bool is_quantizable = std::find (aten_funcs.begin (), aten_funcs.end (), n->kind ()) != aten_funcs.end ();
73+ static std::vector<std::string> call_funcs = {
74+ " linear" ,
75+ " relu"
76+ };
77+ if (n->kind () == prim::CallFunction) {
78+ auto func_node = n->inputs ()[0 ]->node ();
79+ if (func_node->kind () == prim::Constant) {
80+ is_quantizable |= std::find (call_funcs.begin (), call_funcs.end (), func_node->s (attr::name)) != call_funcs.end ();
81+ }
82+ }
83+ return is_quantizable;
84+ }
85+
86+ bool valueNeedsToBeObserved (Value* v) {
87+ if (!v->type ()->isSubtypeOf (TensorType::get ())) {
88+ return false ;
89+ }
90+ // Check whether producer is quantizable
91+ if (nodeQuantizable (v->node ())) {
92+ return true ;
93+ }
94+ // Check whether user is quantizable
95+ for (const auto & use: v->uses ()) {
96+ if (nodeQuantizable (use.user )) {
97+ return true ;
98+ }
99+ }
100+ return false ;
69101}
70102
71103Node* traverseToQuantNode (Node* dq) {
@@ -209,14 +241,17 @@ void InsertObserversImpl(
209241 // point is the beginning of graph node. This also safeguards against
210242 // observing a potentially mutated value due to some in-place operation
211243 Value* self = graph->inputs ()[0 ];
244+ std::unordered_set<Value*> values_observed;
212245 for (size_t idx = 1 ; idx < method.num_inputs (); ++idx) {
213246 auto & v = graph->inputs ()[idx];
214- if (v->type ()->isSubtypeOf (TensorType::get ()) &&
215- values_to_skip.count (v) == 0 ) {
247+ if (valueNeedsToBeObserved (v) &&
248+ values_to_skip.count (v) == 0 &&
249+ values_observed.count (v) == 0 ) {
216250 if (module_qconfig_map.count (module .module_object ()) == 0 ) {
217251 // the module is added by us, it's an observer module
218252 continue ;
219253 }
254+ values_observed.emplace (v);
220255 auto qconfig = module_qconfig_map.at (module .module_object ());
221256 if (qconfig) {
222257 auto observer_node =
@@ -235,15 +270,18 @@ void InsertObserversImpl(
235270 for (Node* n : b->nodes ()) {
236271 // Skip nodes that we don't need to observe, e.g. 'prim::Constant' or
237272 // observer nodes
238- if (! outputsNeedToBeObserved (n) || observer_for_input.count (n) != 0 ) {
273+ if (observer_for_input.count (n) != 0 ) {
239274 continue ;
240275 }
241276
242277 // Record all outputs in the values_to_observe - we'll later add observers
243278 // for all values from it.
244279 for (Value* v : n->outputs ()) {
245- if (values_to_skip.count (v) == 0 ) {
280+ if (valueNeedsToBeObserved (v) &&
281+ values_to_skip.count (v) == 0 &&
282+ values_observed.count (v) == 0 ) {
246283 values_to_observe.push_back (v);
284+ values_observed.emplace (v);
247285 }
248286 if (v->node ()->kind () == prim::CallMethod) {
249287 // If we find a call to a method of a child module,
@@ -277,9 +315,6 @@ void InsertObserversImpl(
277315
278316 // Actually add observer nodes.
279317 for (Value* v : values_to_observe) {
280- if (!v->type ()->isSubtypeOf (TensorType::get ())) {
281- continue ;
282- }
283318 // Skip inserting observer for bias
284319 if (v->node ()->kind () == prim::GetAttr &&
285320 v->node ()->s (c10::attr::name) == " bias" ) {
0 commit comments