5 changes: 0 additions & 5 deletions test/test_jit.py
@@ -876,7 +876,6 @@ def test_recursive_cse(self):
FileCheck().run(input_str, graph)

@_tmp_donotuse_dont_inline_everything
@unittest.skip("temoprarily disable the test case, it will pass in later PR")
def test_insert_observers(self):
class Observer(torch.nn.Module):
def __init__(self):
@@ -933,7 +932,6 @@ def get_forward_graph(m):
.run(str(m._c._get_module("conv")._get_method('conv2d_forward').graph))

@_tmp_donotuse_dont_inline_everything
@unittest.skip("temoprarily disable the test case, it will pass in later PR")
def test_insert_observers_child_qconfig(self):
class Observer(torch.nn.Module):
def __init__(self):
@@ -1004,8 +1002,6 @@ def get_forward(c):
check_observed(get_forward(m._c._get_module('sub')._get_module('linear')).graph)

@_tmp_donotuse_dont_inline_everything
@unittest.skip("temoprarily disable the test since \
I want to put the insert_quant_dequant changes in a separate PR")
def test_insert_observers_skip_values(self):
import torch.nn.functional as F

@@ -1067,7 +1063,6 @@ def test_module(module, relu_call, num_observers):
test_module(M2, 'prim::CallMethod[name="forward"]', 0)

@_tmp_donotuse_dont_inline_everything
@unittest.skip("temoprarily disable the test")
def test_insert_quant_dequant(self):
class Observer(torch.nn.Module):
def __init__(self):
70 changes: 58 additions & 12 deletions torch/csrc/jit/passes/quantization.cpp
@@ -10,6 +10,7 @@
#include <torch/csrc/jit/operator.h>
#include <torch/csrc/jit/subgraph_matcher.h>

#include <algorithm>
#include <stack>

namespace torch {
@@ -64,8 +65,59 @@ graph(%self, %input):
}
}

static bool outputsNeedToBeObserved(Node* n) {
return n->kind() != prim::Constant;
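// Extract the unqualified function name from a qualified name, e.g. a
// (hypothetical) qualified name "__torch__.torch.nn.functional.conv2d"
// yields "conv2d".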
std::string getFuncName(const c10::QualifiedName& qname) {
const auto& name = qname.qualifiedName();
auto rdot_idx = name.rfind('.');
if (rdot_idx != std::string::npos) {
return name.substr(rdot_idx + 1);
} else {
return name;
}
}

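// Decide whether a node participates in quantization: either a direct
// aten call to conv2d/linear/relu, or a prim::CallFunction whose callee
// carries one of those names.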
bool nodeQuantizable(Node* n) {
static std::vector<std::string> call_funcs = {
"conv2d",
"linear",
"relu",
};
std::vector<Symbol> aten_funcs;
std::transform(
call_funcs.begin(),
call_funcs.end(),
std::back_inserter(aten_funcs),
[](const std::string& s) { return Symbol::aten(s); });
bool is_quantizable =
std::find(aten_funcs.begin(), aten_funcs.end(), n->kind()) !=
aten_funcs.end();
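// For function calls, input 0 is the callee; compare its unqualified
// name against the same list of quantizable functions.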
if (n->kind() == prim::CallFunction) {
auto func_node = n->inputs()[0]->node();
auto func = func_node->output()->type()->expect<FunctionType>()->function();
auto func_name = getFuncName(func->qualname());
if (func_node->kind() == prim::Constant) {
is_quantizable |=
std::find(call_funcs.begin(), call_funcs.end(), func_name) !=
call_funcs.end();
}
}
return is_quantizable;
}

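// A value needs an observer only if it is a Tensor and is either produced
// or consumed by a quantizable node.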
bool valueNeedsToBeQuantized(Value* v) {
if (!v->type()->isSubtypeOf(TensorType::get())) {
return false;
}
// Check whether producer is quantizable
if (nodeQuantizable(v->node())) {
return true;
}
// Check whether user is quantizable
for (const auto& use : v->uses()) {
if (nodeQuantizable(use.user)) {
return true;
}
}
return false;
}

Node* traverseToQuantNode(Node* dq) {
@@ -216,8 +268,7 @@ void InsertObserversImpl(
Value* self = graph->inputs()[0];
for (size_t idx = 1; idx < method.num_inputs(); ++idx) {
auto& v = graph->inputs()[idx];
if (v->type()->isSubtypeOf(TensorType::get()) &&
values_to_skip.count(v) == 0) {
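// Only observe graph inputs that actually feed quantizable nodes.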
if (!values_to_skip.count(v) && valueNeedsToBeQuantized(v)) {
auto qconfig = module_qconfig_map.at(module.module_object());
if (qconfig) {
auto observer_node =
@@ -234,16 +285,15 @@ void InsertObserversImpl(
Block* b = blocks_to_visit.top();
blocks_to_visit.pop();
for (Node* n : b->nodes()) {
// Skip nodes that we don't need to observe, e.g. 'prim::Constant' or
// observer nodes
if (!outputsNeedToBeObserved(n) || observer_for_input.count(n) != 0) {
// Skip observer nodes
if (observer_for_input.count(n) != 0) {
continue;
}

// Record outputs that need to be quantized in values_to_observe - we'll
// later add observers for each value in it.
for (Value* v : n->outputs()) {
if (values_to_skip.count(v) == 0) {
if (!values_to_skip.count(v) && valueNeedsToBeQuantized(v)) {
values_to_observe.push_back(v);
}
if (v->node()->kind() == prim::CallMethod) {
@@ -284,9 +334,6 @@ void InsertObserversImpl(

// Actually add observer nodes.
for (Value* v : values_to_observe) {
if (!v->type()->isSubtypeOf(TensorType::get())) {
continue;
}
// Skip inserting observer for bias
if (v->node()->kind() == prim::GetAttr &&
v->node()->s(c10::attr::name) == "bias") {
@@ -843,7 +890,6 @@ graph(%self, %scale, %zero_point, %dtype):
auto matches = findPatternMatches(pattern_graph, *graph);
for (const auto& match : matches) {
auto match_vmap = match.values_map;
auto* weight = match_vmap.at(vmap.at("weight"));
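// Read the float weight from the module and the observed quantization
// parameters (scale / zero_point) from the matched pattern.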
auto float_weight = module.get_parameter("weight").variable_data();
auto scale = toIValue(match_vmap.at(vmap.at("scale"))).value().toDouble();
auto zero_point =