Commit 6d3ac7f

jerryzh168 authored and facebook-github-bot committed
use whitelist for selecting observed values (#25974)
Summary: Pull Request resolved: #25974

Previously we observed all Tensor values; what we actually want is to observe only the values that can be quantized.

Test Plan:
python test/test_jit.py
python test/test_quantizer.py

Imported from OSS

Differential Revision: D17348986

fbshipit-source-id: 55be0d73862a0e7eb1e7fd882d16e0d830618b63
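To illustrate the whitelist idea, here is an editorial sketch, not part of the patch: it reuses the name of the getFuncName helper and the call_funcs whitelist introduced in the diff below, while the isWhitelisted wrapper is hypothetical and the JIT IR machinery is omitted. The check reduces to taking the last dotted component of a function's qualified name and matching it against a fixed set of quantizable ops:

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Last component of a dotted qualified name, e.g.
// "__torch__.torch.nn.functional.linear" -> "linear".
std::string getFuncName(const std::string& qualified_name) {
  auto rdot_idx = qualified_name.rfind('.');
  return rdot_idx == std::string::npos ? qualified_name
                                       : qualified_name.substr(rdot_idx + 1);
}

// Whitelist from the patch: only these calls make a value worth observing.
// (isWhitelisted itself is a made-up name for this sketch.)
bool isWhitelisted(const std::string& func_name) {
  static const std::vector<std::string> call_funcs = {
      "conv2d", "linear", "relu"};
  return std::find(call_funcs.begin(), call_funcs.end(), func_name) !=
      call_funcs.end();
}

int main() {
  const char* names[] = {"__torch__.torch.nn.functional.linear",
                         "__torch__.my_module.helper"};
  for (const char* qname : names) {
    std::cout << qname << " -> "
              << (isWhitelisted(getFuncName(qname)) ? "observe" : "skip")
              << "\n";
  }
  return 0;
}

Run standalone, this prints "observe" for the functional linear call and "skip" for the arbitrary helper.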
1 parent: d250f01

2 files changed: +58, −17 lines


test/test_jit.py (0 additions, 5 deletions)

@@ -876,7 +876,6 @@ def test_recursive_cse(self):
         FileCheck().run(input_str, graph)
 
     @_tmp_donotuse_dont_inline_everything
-    @unittest.skip("temoprarily disable the test case, it will pass in later PR")
     def test_insert_observers(self):
         class Observer(torch.nn.Module):
             def __init__(self):
@@ -933,7 +932,6 @@ def get_forward_graph(m):
             .run(str(m._c._get_module("conv")._get_method('conv2d_forward').graph))
 
     @_tmp_donotuse_dont_inline_everything
-    @unittest.skip("temoprarily disable the test case, it will pass in later PR")
     def test_insert_observers_child_qconfig(self):
         class Observer(torch.nn.Module):
             def __init__(self):
@@ -1004,8 +1002,6 @@ def get_forward(c):
         check_observed(get_forward(m._c._get_module('sub')._get_module('linear')).graph)
 
     @_tmp_donotuse_dont_inline_everything
-    @unittest.skip("temoprarily disable the test since \
-            I want to put the insert_quant_dequant changes in a separate PR")
     def test_insert_observers_skip_values(self):
         import torch.nn.functional as F
 
@@ -1067,7 +1063,6 @@ def test_module(module, relu_call, num_observers):
         test_module(M2, 'prim::CallMethod[name="forward"]', 0)
 
     @_tmp_donotuse_dont_inline_everything
-    @unittest.skip("temoprarily disable the test")
     def test_insert_quant_dequant(self):
         class Observer(torch.nn.Module):
             def __init__(self):

torch/csrc/jit/passes/quantization.cpp (58 additions, 12 deletions)
@@ -10,6 +10,7 @@
 #include <torch/csrc/jit/operator.h>
 #include <torch/csrc/jit/subgraph_matcher.h>
 
+#include <algorithm>
 #include <stack>
 
 namespace torch {
@@ -64,8 +65,59 @@ graph(%self, %input):
 }
 }
 
-static bool outputsNeedToBeObserved(Node* n) {
-  return n->kind() != prim::Constant;
+std::string getFuncName(const c10::QualifiedName& qname) {
+  const auto& name = qname.qualifiedName();
+  auto rdot_idx = name.rfind('.');
+  if (rdot_idx != std::string::npos) {
+    return name.substr(rdot_idx + 1, name.length());
+  } else {
+    return name;
+  }
+}
+
+bool nodeQuantizable(Node* n) {
+  static std::vector<std::string> call_funcs = {
+      "conv2d",
+      "linear",
+      "relu",
+  };
+  std::vector<Symbol> aten_funcs;
+  std::transform(
+      call_funcs.begin(),
+      call_funcs.end(),
+      std::back_inserter(aten_funcs),
+      [](const std::string& s) { return Symbol::aten(s); });
+  bool is_quantizable =
+      std::find(aten_funcs.begin(), aten_funcs.end(), n->kind()) !=
+      aten_funcs.end();
+  if (n->kind() == prim::CallFunction) {
+    auto func_node = n->inputs()[0]->node();
+    auto func = func_node->output()->type()->expect<FunctionType>()->function();
+    auto func_name = getFuncName(func->qualname());
+    if (func_node->kind() == prim::Constant) {
+      is_quantizable |=
+          std::find(call_funcs.begin(), call_funcs.end(), func_name) !=
+          call_funcs.end();
+    }
+  }
+  return is_quantizable;
+}
+
+bool valueNeedsToBeQuantized(Value* v) {
+  if (!v->type()->isSubtypeOf(TensorType::get())) {
+    return false;
+  }
+  // Check whether producer is quantizable
+  if (nodeQuantizable(v->node())) {
+    return true;
+  }
+  // Check whether user is quantizable
+  for (const auto& use : v->uses()) {
+    if (nodeQuantizable(use.user)) {
+      return true;
+    }
+  }
+  return false;
 }
 
 Node* traverseToQuantNode(Node* dq) {
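To make the rule above concrete, here is a self-contained editorial sketch (the Node and Value structs are hypothetical stand-ins, not the real torch::jit IR types) of the producer-or-user test that valueNeedsToBeQuantized applies: a Tensor value is observed if either the node that produces it or any node that uses it is whitelisted.

#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-ins for torch::jit::Node and torch::jit::Value.
struct Node {
  std::string kind;  // e.g. "aten::conv2d", "aten::cat"
};

struct Value {
  bool is_tensor;
  Node* producer;
  std::vector<Node*> users;
};

// Same whitelist as the patch: conv2d, linear, relu.
bool nodeQuantizable(const Node* n) {
  static const std::vector<std::string> whitelist = {
      "aten::conv2d", "aten::linear", "aten::relu"};
  for (const auto& k : whitelist) {
    if (n->kind == k) {
      return true;
    }
  }
  return false;
}

bool valueNeedsToBeQuantized(const Value* v) {
  if (!v->is_tensor) {
    return false;  // only Tensor values are candidates
  }
  if (nodeQuantizable(v->producer)) {
    return true;  // quantizable producer
  }
  for (const Node* user : v->users) {
    if (nodeQuantizable(user)) {
      return true;  // quantizable user
    }
  }
  return false;
}

int main() {
  Node conv{"aten::conv2d"};
  Node cat{"aten::cat"};
  Value conv_out{true, &conv, {&cat}};  // produced by conv2d -> observed
  Value cat_out{true, &cat, {}};        // cat, no quantizable use -> skipped
  std::cout << valueNeedsToBeQuantized(&conv_out) << " "   // prints 1
            << valueNeedsToBeQuantized(&cat_out) << "\n";  // prints 0
  return 0;
}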
@@ -216,8 +268,7 @@ void InsertObserversImpl(
   Value* self = graph->inputs()[0];
   for (size_t idx = 1; idx < method.num_inputs(); ++idx) {
     auto& v = graph->inputs()[idx];
-    if (v->type()->isSubtypeOf(TensorType::get()) &&
-        values_to_skip.count(v) == 0) {
+    if (!values_to_skip.count(v) && valueNeedsToBeQuantized(v)) {
       auto qconfig = module_qconfig_map.at(module.module_object());
       if (qconfig) {
         auto observer_node =
@@ -234,16 +285,15 @@ void InsertObserversImpl(
     Block* b = blocks_to_visit.top();
     blocks_to_visit.pop();
    for (Node* n : b->nodes()) {
-      // Skip nodes that we don't need to observe, e.g. 'prim::Constant' or
-      // observer nodes
-      if (!outputsNeedToBeObserved(n) || observer_for_input.count(n) != 0) {
+      // Skip observer nodes
+      if (observer_for_input.count(n) != 0) {
         continue;
       }
 
       // Record all outputs in the values_to_observe - we'll later add observers
       // for all values from it.
       for (Value* v : n->outputs()) {
-        if (values_to_skip.count(v) == 0) {
+        if (!values_to_skip.count(v) && valueNeedsToBeQuantized(v)) {
           values_to_observe.push_back(v);
         }
         if (v->node()->kind() == prim::CallMethod) {
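For orientation, here is a compact sketch of the traversal this hunk modifies, again with hypothetical stub types in place of the JIT's Block/Node/Value: blocks are visited with an explicit stack, previously inserted observer nodes are skipped, and only outputs passing the whitelist check are queued for observation.

#include <iostream>
#include <stack>
#include <string>
#include <unordered_set>
#include <vector>

struct Block;
struct Value;

// Hypothetical stand-ins for the JIT IR types.
struct Node {
  std::string kind;
  std::vector<Value*> outputs;
  std::vector<Block*> blocks;  // nested blocks, e.g. branches of prim::If
};

struct Value {
  Node* producer;
};

struct Block {
  std::vector<Node*> nodes;
};

// Stub whitelist check; see the producer-or-user sketch above for the
// fuller version that also inspects a value's users.
bool valueNeedsToBeQuantized(Value* v) {
  static const std::unordered_set<std::string> whitelist = {
      "aten::conv2d", "aten::linear", "aten::relu"};
  return whitelist.count(v->producer->kind) != 0;
}

std::vector<Value*> collectValuesToObserve(
    Block* top,
    const std::unordered_set<Node*>& observer_for_input,
    const std::unordered_set<Value*>& values_to_skip) {
  std::vector<Value*> values_to_observe;
  std::stack<Block*> blocks_to_visit;
  blocks_to_visit.push(top);
  while (!blocks_to_visit.empty()) {
    Block* b = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (Node* n : b->nodes) {
      // Skip observer nodes already inserted for graph inputs.
      if (observer_for_input.count(n) != 0) {
        continue;
      }
      for (Value* v : n->outputs) {
        // The whitelist test replaces the old blanket Tensor check.
        if (!values_to_skip.count(v) && valueNeedsToBeQuantized(v)) {
          values_to_observe.push_back(v);
        }
      }
      // Recurse into nested blocks.
      for (Block* sub : n->blocks) {
        blocks_to_visit.push(sub);
      }
    }
  }
  return values_to_observe;
}

int main() {
  Node conv{"aten::conv2d", {}, {}};
  Value conv_out{&conv};
  conv.outputs.push_back(&conv_out);
  Block top{{&conv}};
  std::cout << collectValuesToObserve(&top, {}, {}).size() << "\n";  // 1
  return 0;
}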
@@ -284,9 +334,6 @@ void InsertObserversImpl(
 
   // Actually add observer nodes.
   for (Value* v : values_to_observe) {
-    if (!v->type()->isSubtypeOf(TensorType::get())) {
-      continue;
-    }
     // Skip inserting observer for bias
     if (v->node()->kind() == prim::GetAttr &&
         v->node()->s(c10::attr::name) == "bias") {
@@ -843,7 +890,6 @@ graph(%self, %scale, %zero_point, %dtype):
   auto matches = findPatternMatches(pattern_graph, *graph);
   for (const auto& match : matches) {
     auto match_vmap = match.values_map;
-    auto* weight = match_vmap.at(vmap.at("weight"));
     auto float_weight = module.get_parameter("weight").variable_data();
     auto scale = toIValue(match_vmap.at(vmap.at("scale"))).value().toDouble();
     auto zero_point =
