Skip to content

Commit d58d559

Browse files
committed
[quant][graphmode] use whitelist for selecting observed values
Summary: Previously we observed all the Tensor values, but what we actually want is to observe only the ones that can be quantized. Test Plan: python test/test_jit.py python test/test_quantizer.py Reviewers: pt1quant Subscribers: Tasks: Tags: ghstack-source-id: 36b6999 Pull Request resolved: #25974
1 parent adc88b7 commit d58d559

File tree

2 files changed

+43
-9
lines changed

2 files changed

+43
-9
lines changed

test/test_jit.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1039,7 +1039,6 @@ def get_forward(m):
10391039
def test_module(module, relu_call, num_observers):
10401040
m = torch.jit.script(module())
10411041
observer = torch.jit.script(Observer())
1042-
10431042
torch._C._jit_pass_constant_propagation(get_forward(m).graph)
10441043
torch._C._jit_pass_constant_propagation(m._c._get_module('conv')._get_method('conv2d_forward').graph)
10451044
qconfig_dict = {

torch/csrc/jit/passes/quantization.cpp

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,40 @@ bool valueObservedInAnotherMethod(
154154
return false;
155155
}
156156

157-
static bool outputsNeedToBeObserved(Node* n) {
158-
return n->kind() != prim::Constant;
157+
bool nodeQuantizable(Node* n) {
158+
static std::vector<Symbol> aten_funcs = {
159+
Symbol::aten("conv2d"),
160+
Symbol::aten("linear")
161+
};
162+
bool is_quantizable = std::find(aten_funcs.begin(), aten_funcs.end(), n->kind()) != aten_funcs.end();
163+
static std::vector<std::string> call_funcs = {
164+
"linear",
165+
"relu"
166+
};
167+
if (n->kind() == prim::CallFunction) {
168+
auto func_node = n->inputs()[0]->node();
169+
if (func_node->kind() == prim::Constant) {
170+
is_quantizable |= std::find(call_funcs.begin(), call_funcs.end(), func_node->s(attr::name)) != call_funcs.end();
171+
}
172+
}
173+
return is_quantizable;
174+
}
175+
176+
bool valueNeedsToBeObserved(Value* v) {
177+
if (!v->type()->isSubtypeOf(TensorType::get())) {
178+
return false;
179+
}
180+
// Check whether producer is quantizable
181+
if (nodeQuantizable(v->node())) {
182+
return true;
183+
}
184+
// Check whether user is quantizable
185+
for (const auto& use: v->uses()) {
186+
if (nodeQuantizable(use.user)) {
187+
return true;
188+
}
189+
}
190+
return false;
159191
}
160192

161193
Node* traverseToQuantNode(Node* dq) {
@@ -301,15 +333,18 @@ void InsertObserversImpl(
301333
// point is the beginning of graph node. This also safe guards against
302334
// observing a potentially mutated value due to some in-place operation
303335
Value* self = graph->inputs()[0];
336+
std::unordered_set<Value*> values_observed;
304337
for (size_t idx = 1; idx < method.num_inputs(); ++idx) {
305338
auto& v = graph->inputs()[idx];
306-
if (v->type()->isSubtypeOf(TensorType::get()) &&
339+
if (valueNeedsToBeObserved(v) &&
307340
values_to_skip.count(v) == 0 &&
341+
values_observed.count(v) == 0 &&
308342
!valueObservedInAnotherMethod(v, self, module, child_module_set)) {
309343
if (module_qconfig_map.count(module.module_object()) == 0) {
310344
// the module is added by us, it's an observer module
311345
continue;
312346
}
347+
values_observed.emplace(v);
313348
auto qconfig = module_qconfig_map.at(module.module_object());
314349
if (qconfig) {
315350
auto observer_node =
@@ -328,16 +363,19 @@ void InsertObserversImpl(
328363
for (Node* n : b->nodes()) {
329364
// Skip nodes that we don't need to observe, e.g. 'prim::Constant' or
330365
// observer nodes
331-
if (!outputsNeedToBeObserved(n) || observer_for_input.count(n) != 0) {
366+
if (observer_for_input.count(n) != 0) {
332367
continue;
333368
}
334369

335370
// Record all outputs in the values_to_observe - we'll later add observers
336371
// for all values from it.
337372
for (Value* v : n->outputs()) {
338-
if (values_to_skip.count(v) == 0 &&
373+
if (valueNeedsToBeObserved(v) &&
374+
values_to_skip.count(v) == 0 &&
375+
values_observed.count(v) == 0 &&
339376
!valueObservedInAnotherMethod(v, self, module, child_module_set)) {
340377
values_to_observe.push_back(v);
378+
values_observed.emplace(v);
341379
}
342380
if (v->node()->kind() == prim::CallMethod) {
343381
// If we find a call to a method of a child module,
@@ -381,9 +419,6 @@ void InsertObserversImpl(
381419

382420
// Actually add observer nodes.
383421
for (Value* v : values_to_observe) {
384-
if (!v->type()->isSubtypeOf(TensorType::get())) {
385-
continue;
386-
}
387422
// Skip inserting observer for bias
388423
if (v->node()->kind() == prim::GetAttr &&
389424
v->node()->s(c10::attr::name) == "bias") {

0 commit comments

Comments (0)