Commit 34a1c47
[quant][graphmode] use whitelist for selecting observed values
Summary: Previously we observed all Tensor values; what we actually want is to observe only the values that can be quantized.

Test Plan:
python test/test_jit.py
python test/test_quantizer.py

Reviewers: pt1quant

ghstack-source-id: 68e19e6
Pull Request resolved: #25974
1 parent 563de9c commit 34a1c47
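For illustration, here is a minimal Python sketch of the new whitelist check. The names mirror the C++ pass but are not a public API, the graph-walking calls are the standard TorchScript Python bindings, and the example module M is hypothetical:

import torch

# Whitelists mirroring the ones added in quantization.cpp.
QUANTIZABLE_ATEN = {"aten::conv2d", "aten::linear"}
QUANTIZABLE_CALL_FUNCS = {"linear", "relu"}

def node_quantizable(node):
    # aten ops are matched by node kind; prim::CallFunction nodes are
    # matched by the name of the constant function they call.
    if node.kind() in QUANTIZABLE_ATEN:
        return True
    if node.kind() == "prim::CallFunction":
        fn = list(node.inputs())[0].node()
        if fn.kind() == "prim::Constant":
            return fn.s("name") in QUANTIZABLE_CALL_FUNCS
    return False

def value_needs_observation(value):
    # Only Tensor values whose producer, or at least one of whose users,
    # is quantizable should get an observer.
    if not value.type().isSubtypeOf(torch._C.TensorType.get()):
        return False
    if node_quantizable(value.node()):
        return True
    return any(node_quantizable(use.user) for use in value.uses())

class M(torch.nn.Module):
    def forward(self, x, w):
        y = torch.conv2d(x, w)  # aten::conv2d -> whitelisted
        return y.view(-1)       # aten::view   -> not whitelisted

graph = torch.jit.script(M()).graph
for node in graph.nodes():
    for out in node.outputs():
        print(node.kind(), value_needs_observation(out))

With the old behavior every Tensor value (including the aten::view output) would have been observed; with the whitelist, only the conv2d input/output tensors qualify.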

File tree

2 files changed (+44, -14 lines)


test/test_jit.py

Lines changed: 0 additions & 5 deletions
@@ -876,7 +876,6 @@ def test_recursive_cse(self):
         FileCheck().run(input_str, graph)
 
     @_tmp_donotuse_dont_inline_everything
-    @unittest.skip("temoprarily disable the test case, it will pass in later PR")
     def test_insert_observers(self):
         class Observer(torch.nn.Module):
             def __init__(self):
@@ -933,7 +932,6 @@ def get_forward_graph(m):
             .run(str(m._c._get_module("conv")._get_method('conv2d_forward').graph))
 
     @_tmp_donotuse_dont_inline_everything
-    @unittest.skip("temoprarily disable the test case, it will pass in later PR")
     def test_insert_observers_child_qconfig(self):
         class Observer(torch.nn.Module):
             def __init__(self):
@@ -1004,8 +1002,6 @@ def get_forward(c):
         check_observed(get_forward(m._c._get_module('sub')._get_module('linear')).graph)
 
     @_tmp_donotuse_dont_inline_everything
-    @unittest.skip("temoprarily disable the test since \
-            I want to put the insert_quant_dequant changes in a separate PR")
     def test_insert_observers_skip_values(self):
         import torch.nn.functional as F
 
@@ -1043,7 +1039,6 @@ def get_forward(m):
         def test_module(module, relu_call, num_observers):
             m = torch.jit.script(module())
             observer = torch.jit.script(Observer())
-
             torch._C._jit_pass_constant_propagation(get_forward(m).graph)
             torch._C._jit_pass_constant_propagation(m._c._get_module('conv')._get_method('conv2d_forward').graph)
             qconfig_dict = {

torch/csrc/jit/passes/quantization.cpp

Lines changed: 44 additions & 9 deletions
@@ -64,8 +64,40 @@ graph(%self, %input):
   }
 }
 
-static bool outputsNeedToBeObserved(Node* n) {
-  return n->kind() != prim::Constant;
+bool nodeQuantizable(Node* n) {
+  static std::vector<Symbol> aten_funcs = {
+      Symbol::aten("conv2d"),
+      Symbol::aten("linear")
+  };
+  bool is_quantizable = std::find(aten_funcs.begin(), aten_funcs.end(), n->kind()) != aten_funcs.end();
+  static std::vector<std::string> call_funcs = {
+      "linear",
+      "relu"
+  };
+  if (n->kind() == prim::CallFunction) {
+    auto func_node = n->inputs()[0]->node();
+    if (func_node->kind() == prim::Constant) {
+      is_quantizable |= std::find(call_funcs.begin(), call_funcs.end(), func_node->s(attr::name)) != call_funcs.end();
+    }
+  }
+  return is_quantizable;
+}
+
+bool valueNeedsToBeObserved(Value* v) {
+  if (!v->type()->isSubtypeOf(TensorType::get())) {
+    return false;
+  }
+  // Check whether producer is quantizable
+  if (nodeQuantizable(v->node())) {
+    return true;
+  }
+  // Check whether user is quantizable
+  for (const auto& use : v->uses()) {
+    if (nodeQuantizable(use.user)) {
+      return true;
+    }
+  }
+  return false;
 }
 
 Node* traverseToQuantNode(Node* dq) {
@@ -209,14 +241,17 @@ void InsertObserversImpl(
   // point is the beginning of graph node. This also safe guards against
   // observing a potentially mutated value due to some in-place operation
   Value* self = graph->inputs()[0];
+  std::unordered_set<Value*> values_observed;
   for (size_t idx = 1; idx < method.num_inputs(); ++idx) {
     auto& v = graph->inputs()[idx];
-    if (v->type()->isSubtypeOf(TensorType::get()) &&
-        values_to_skip.count(v) == 0) {
+    if (valueNeedsToBeObserved(v) &&
+        values_to_skip.count(v) == 0 &&
+        values_observed.count(v) == 0) {
       if (module_qconfig_map.count(module.module_object()) == 0) {
         // the module is added by us, it's an observer module
         continue;
       }
+      values_observed.emplace(v);
       auto qconfig = module_qconfig_map.at(module.module_object());
       if (qconfig) {
         auto observer_node =
@@ -235,15 +270,18 @@ void InsertObserversImpl(
   for (Node* n : b->nodes()) {
     // Skip nodes that we don't need to observe, e.g. 'prim::Constant' or
     // observer nodes
-    if (!outputsNeedToBeObserved(n) || observer_for_input.count(n) != 0) {
+    if (observer_for_input.count(n) != 0) {
       continue;
     }
 
     // Record all outputs in the values_to_observe - we'll later add observers
     // for all values from it.
     for (Value* v : n->outputs()) {
-      if (values_to_skip.count(v) == 0) {
+      if (valueNeedsToBeObserved(v) &&
+          values_to_skip.count(v) == 0 &&
+          values_observed.count(v) == 0) {
         values_to_observe.push_back(v);
+        values_observed.emplace(v);
       }
       if (v->node()->kind() == prim::CallMethod) {
         // If we find a call to a method of a child module,
@@ -277,9 +315,6 @@ void InsertObserversImpl(
 
   // Actually add observer nodes.
   for (Value* v : values_to_observe) {
-    if (!v->type()->isSubtypeOf(TensorType::get())) {
-      continue;
-    }
     // Skip inserting observer for bias
     if (v->node()->kind() == prim::GetAttr &&
         v->node()->s(c10::attr::name) == "bias") {
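The other change worth noting in this diff is the values_observed set, which guarantees each value is scheduled for at most one observer even if it qualifies both as a graph input and later as a node output. A self-contained Python sketch of that bookkeeping, with illustrative names rather than the actual implementation:

def collect_values_to_observe(candidates, values_to_skip, needs_observation):
    # Mirrors the pass: keep insertion order, but record each value in a
    # set so it is only ever scheduled for observation once.
    values_to_observe = []
    values_observed = set()
    for v in candidates:
        if (needs_observation(v) and v not in values_to_skip
                and v not in values_observed):
            values_to_observe.append(v)
            values_observed.add(v)
    return values_to_observe

# Example: duplicates and explicitly skipped values are filtered out.
print(collect_values_to_observe(
    ["conv_out", "conv_out", "bias"], {"bias"}, lambda v: True))
# -> ['conv_out']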
