17 changes: 11 additions & 6 deletions test/quantization/test_quantize_jit.py
@@ -2251,11 +2251,12 @@ def __init__(self):
         self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3)
         self.maxpool3d = torch.nn.MaxPool3d(kernel_size=3)
         self.dropout = torch.nn.Dropout()
-        self.conv = torch.nn.Conv2d(3, 3, 3)
+        self.conv1 = torch.nn.Conv2d(3, 3, 3)
+        self.conv2 = torch.nn.Conv2d(3, 3, 3)
         self.relu = torch.nn.ReLU()

     def forward(self, x):
-        x = self.conv(x)
+        x = self.conv1(x)
         # add_scalar
         x = x + 3
         # mul_scalar
@@ -2316,7 +2317,7 @@ def forward(self, x):
         y = []
         y.append(x)
         x, _ = y
-        x = self.conv(x)
+        x = self.conv2(x)
         return x

 data = torch.rand(1, 3, 10, 10)
@@ -2336,14 +2337,18 @@ def forward(self, x):
 # observers and also successfully fused two quantized::conv2d
 # patterns
 # one quantize_per_tensor for input
 # TODO: the checks are problematic, we need to split all checks
-FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True) \
-           .check_count("quantized::conv2d", 2, exactly=True) \
-           .check("aten::dequantize") \
-           .run(m.graph)
+FileCheck().check_count("quantized::conv2d(", 2, exactly=True) \
+           .run(m.graph)
+
+FileCheck().check_count("aten::dequantize", 1, exactly=True) \
+           .run(m.graph)
+
+FileCheck().check("quantized::add_scalar") \
+           .check("quantized::mul_scalar") \
+           .check("aten::append(") \
+           .run(m.graph)

def test_general_value_ops(self):
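For context, a minimal sketch (not part of the PR's test suite) of how a Python-level list append surfaces in TorchScript IR. torch.jit.script and Graph printing are standard PyTorch APIs; the module itself is invented for illustration.

import torch

class AppendModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 3)

    def forward(self, x):
        y = []                   # scripted as prim::ListConstruct()
        y.append(self.conv(x))   # scripted as aten::append(%y, %conv_out)
        return y[0]

m = torch.jit.script(AppendModule())
print(m.graph)  # the dump contains an aten::append node, which is what
                # the FileCheck assertions above look for
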
50 changes: 50 additions & 0 deletions torch/csrc/jit/passes/quantization/finalize.cpp
@@ -1,6 +1,7 @@
 #include <torch/csrc/jit/passes/quantization/finalize.h>
 #include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/passes/freeze_module.h>
+#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
 #include <torch/csrc/jit/passes/prepack_folding.h>
 #include <torch/csrc/jit/passes/quantization/quantization_patterns.h>

@@ -14,6 +15,9 @@ struct PatternReplaceInfo {
 };

 namespace {
+
+using graph_rewrite_helper::PatternInfo;
+
 void insertPrepackUnpackForLinear(std::shared_ptr<Graph>& graph) {
   std::string linear_with_quant = R"(
 graph(%a_dequant, %w_quant, %b):
@@ -109,6 +113,50 @@ graph(%a_dequant, %w_quant, %b, %stride, %padding, %dilation, %groups):
   }
 }

+void rewriteListAddToAppend(std::shared_ptr<Graph>& graph) {
+  GRAPH_DUMP("Before restore append", graph);
+  std::string list_add = R"IR(
+graph(%list, %x):
+   %x_list : Tensor[] = prim::ListConstruct(%x)
+   %result : Tensor[] = aten::add(%list, %x_list)
+   return (%result) )IR";
+
+  /* Rewrite the above pattern to
+  std::string append = R"IR(
+graph(%list, %x):
+   %ignore : Tensor[] = aten::append(%list, %x)
+   return (%list) )IR";
+  this is not supported by subgraph rewriter, so we'll do
+  this manually.
+  */
+
+  const PatternInfo& list_add_pattern_info =
+      PatternInfo::parse_from_str(list_add);
+  const Graph& list_add_graph = *list_add_pattern_info.pattern_graph;
+  const auto& list_add_vmap = list_add_pattern_info.vmap;
+  const auto& matches = findPatternMatches(list_add_graph, *graph);
+  for (const auto& match : matches) {
+    Value* result = match.values_map.at(list_add_vmap.at("result"));
+    Node* list_add_node = result->node();
+    Value* list = list_add_node->input(0);
+    Value* x_list = list_add_node->input(1);
+
+    Node* x_list_node = x_list->node();
+    Value* x = x_list_node->input(0);
+
+    result->replaceAllUsesWith(list);
+    WithInsertPoint ins(list_add_node);
+    Node* append_node = graph->create(Symbol::aten("append"), {list, x});
+    append_node->output()->setType(ListType::ofTensors());
+    graph->insertNode(append_node);
+    for (Node* n : {list_add_node, x_list_node}) {
+      n->removeAllInputs();
+      n->destroy();
+    }
+  }
+  GRAPH_DUMP("After restore append", graph);
+}
+
 } // namespace

 void QuantFusion(std::shared_ptr<Graph>& graph, QuantType quant_type) {
@@ -153,6 +201,8 @@ void FoldQuantizedPrepackingOps(Module& module) {

 Module Finalize(Module& module, QuantType quant_type) {
   auto graph = module.get_method("forward").graph();
+  GRAPH_DUMP("Before rewrite list add to append:", graph);
+  rewriteListAddToAppend(graph);
   InsertPrepackUnpack(graph);
   GRAPH_DUMP("Before QuantFusion:", graph);
   QuantFusion(graph, quant_type);
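To make the restored pattern concrete, here is a hedged sketch of the before/after IR that rewriteListAddToAppend operates on, built with torch._C.parse_ir (an internal helper used throughout the JIT tests). The pass itself runs in C++; this only reproduces its input and output shapes for illustration.

import torch

# The functional pattern produced earlier by makeAppendNonInplace:
before = torch._C.parse_ir("""
graph(%list : Tensor[], %x : Tensor):
  %x_list : Tensor[] = prim::ListConstruct(%x)
  %result : Tensor[] = aten::add(%list, %x_list)
  return (%result)
""")

# What the pass restores: uses of %result are redirected to %list and the
# two nodes above collapse into a single in-place append.
after = torch._C.parse_ir("""
graph(%list : Tensor[], %x : Tensor):
  %ignore : Tensor[] = aten::append(%list, %x)
  return (%list)
""")

print(before)
print(after)
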
36 changes: 32 additions & 4 deletions torch/csrc/jit/passes/quantization/helper.cpp
@@ -191,7 +191,7 @@ AtenFuncArgs _observe_inputs_aten_func = {};
 CallFuncArgs _observe_inputs_call_func = {{"batch_norm", 1}};

 // Aten functions for getting tensor information
-std::vector<std::string> _tensor_info_funcs = {"size", "len", "dim"};
+std::vector<std::string> _tensor_info_funcs = {"size", "len", "dim", "numel"};

 // Aten functions whose output will be quantized or not quantized depending
 // on input tensor
@@ -309,10 +316,16 @@ std::vector<Value*> getPassThroughInputs(Value* v) {
       inputs.push_back(v);
     }
     return inputs;
-  } else if (n->kind() == Symbol::aten("append")) {
-    // notice that append is an op that changes input inplace
-    return {n->input(0), n->input(1)};
+  } else if (isListAdd(n)) {
Review thread on the isListAdd branch:

Contributor: What is the rule for list add? Do we assume that all the tensors in the list have a dequantize op prior to the add?

Author (jerryzh168): This is the transformed "append", as shown in the description; we check whether the inputs are produced by dequantize to make sure they are quantized. The rule for list add is: given

    %y = aten::add(%list, %x)

we check whether %list is an empty list. If it is, the pass-through list for %y is {%x}; otherwise it is {%list, %x}.

Contributor: How do you check that the existing list has only dequantized tensors? I see how you can do it for the input %x, but I don't follow how the check is done for %list.

Author (jerryzh168, Jul 1, 2020): There are two cases for %list.

1. One case is the same as %x:

    %list = aten::dequantize(%list_quant)
    %result = aten::add(%list, %x_list)

Here we check whether both %list and %x_list are produced by dequantize.

2. The other case is an empty list, which can be considered as containing quantized tensors:

    %list : Tensor[] = prim::ListConstruct()
    %result = aten::add(%list, %x_list)

In this case we only need to check whether %x_list is produced by dequantize.

Contributor: My question is how we check whether %list is produced by dequantize. Do we iterate over all elements in the list?

Author (jerryzh168): No, we don't iterate; we just check whether the node producing %list is a dequantize.
+    // We need to propagate dequantize of n->input(0) if it is
+    // not an empty list
+    if (isEmptyList(n->input(0)->node())) {
+      return {n->input(1)};
+    } else {
+      return {n->input(0), n->input(1)};
+    }
   }

   return {};
 }

@@ -413,6 +419,28 @@ bool isBinaryOpWithScalarInput(Node* n) {
   return isPropagateQuantBinaryOp(n) && isScalar(n->input(1));
 }

+bool isListAdd(Node* n) {
+  return n->kind() == Symbol::aten("add") && n->inputs().size() == 2 &&
+      n->outputs().size() == 1 &&
+      n->output()->type()->isSubtypeOf(ListType::ofTensors()) &&
+      n->input(0)->type()->isSubtypeOf(ListType::ofTensors()) &&
+      n->input(1)->type()->isSubtypeOf(ListType::ofTensors());
+}
+
+bool isEmptyList(Node* n) {
+  if (n->outputs().size() != 1) {
+    return false;
+  }
+  bool is_empty_tensor_list_node = n->kind() == prim::ListConstruct &&
+      n->inputs().size() == 0 &&
+      n->output()->type()->isSubtypeOf(ListType::ofTensors());
+  auto iv = toIValue(n->output());
+  bool is_empty_tensor_list_constant = iv.has_value() && iv->isList() &&
+      iv->toList().size() == 0 &&
+      n->output()->type()->isSubtypeOf(ListType::ofTensors());
+  return is_empty_tensor_list_node || is_empty_tensor_list_constant;
+}
+
 c10::optional<std::tuple<c10::QScheme, QParamVector>> getFixedQParams(Node* n) {
   static std::vector<NodeKind> fixed_qparam_funcs;
   std::transform(
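A hedged TorchScript sketch of the two list-add cases discussed in the review thread above; the function names are invented for this illustration. Both script to an aten::add over Tensor[], which is what isListAdd matches, and the first produces the empty prim::ListConstruct that isEmptyList recognizes.

from typing import List

import torch

@torch.jit.script
def add_to_empty(x: torch.Tensor) -> List[torch.Tensor]:
    lst: List[torch.Tensor] = []  # empty prim::ListConstruct()
    # aten::add(%lst, %x_list): only %x_list can carry a dequantize,
    # so the pass-through inputs are just {%x_list}
    return lst + [x]

@torch.jit.script
def add_to_existing(lst: List[torch.Tensor], x: torch.Tensor) -> List[torch.Tensor]:
    # aten::add(%lst, %x_list): both inputs must be produced by dequantize
    # for the output to be treated as quantized, and the pass-through
    # inputs are {%lst, %x_list}
    return lst + [x]

print(add_to_empty.graph)      # shows prim::ListConstruct() feeding aten::add
print(add_to_existing.graph)
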
6 changes: 6 additions & 0 deletions torch/csrc/jit/passes/quantization/helper.h
@@ -77,6 +77,12 @@ TORCH_API bool isPropagateQuantOp(Node* n);
 // quantized::{op}_scalar
 TORCH_API bool isBinaryOpWithScalarInput(Node* n);

+// Check if the node is an aten::add with list inputs
+bool isListAdd(Node* n);
+
+// Check if the node is an empty list construct node
+bool isEmptyList(Node* n);
+
 TORCH_API c10::optional<std::tuple<c10::QScheme, QParamVector>> getFixedQParams(
     Node* n);
42 changes: 42 additions & 0 deletions torch/csrc/jit/passes/quantization/insert_observers.cpp
@@ -1100,6 +1100,46 @@ void InsertObserversHelper::fillBoundaryValueMap(
   }
 }

+void makeAppendNonInplace(std::shared_ptr<Graph>& graph) {
+  std::string append_pattern = R"IR(
+graph(%list, %x):
+     %ignore : Tensor[] = aten::append(%list, %x)
+     return (%ignore) )IR";
+
+  /* Rewrite the above pattern to
+  std::string append_replacement = R"IR(
+graph(%list, %x):
+   %x_list : Tensor[] = prim::ListConstruct(%x)
+   %result : Tensor[] = aten::add(%list, %x_list)
+   return (%result) )IR";
+  this is not supported by subgraph rewriter, so we'll do
+  this manually.
+  */
+
+  GRAPH_DUMP("Before replace append", graph);
+  const PatternInfo& append_pattern_info =
+      PatternInfo::parse_from_str(append_pattern);
+  const Graph& append_graph = *append_pattern_info.pattern_graph;
+  const auto& append_vmap = append_pattern_info.vmap;
+  const auto& matches = findPatternMatches(append_graph, *graph);
+  for (const auto& match : matches) {
+    auto append_node = match.values_map.at(append_vmap.at("ignore"))->node();
+    Value* list_val = append_node->input(0);
+    Value* x = append_node->input(1);
+    WithInsertPoint ins(append_node);
+    Node* x_list_node = graph->createList(TensorType::get(), {x});
+    graph->insertNode(x_list_node);
+    Node* add_node =
+        graph->create(Symbol::aten("add"), {list_val, x_list_node->output()});
+    graph->insertNode(add_node);
+    add_node->output()->setType(ListType::ofTensors());
+    list_val->replaceAllUsesAfterNodeWith(add_node, add_node->output());
+    append_node->removeAllInputs();
+    append_node->destroy();
+  }
+  GRAPH_DUMP("After replace append", graph);
+}

Review thread on makeAppendNonInplace:

Contributor: Can you elaborate on why directly supporting append is not possible?

Author (jerryzh168): All current ops, including in-place ops, assume that their output will be consumed by the following ops. Breaking that assumption would require substantial changes/hacks, and I don't think it's worth the effort. Also, if there are perf problems, we can easily add a pass at the end that changes the add back to append.

 void InsertObserversHelper::preprocess(
     Module& module,
     const std::string& method_name) {
@@ -1116,6 +1156,7 @@
   // fuse decomposed linear into aten::linear
   FuseLinear(graph);
   replaceConvolutionWithAtenConv(graph);
+  makeAppendNonInplace(graph);
 }

 void InsertObserversHelper::analyze(
@@ -1520,6 +1561,7 @@ void InsertObserversHelper::propagateObservedProperty(
         observed_values_.count(v) || block_observed_values.count(v);
   }
   if (all_observed) {
+    GRAPH_DEBUG("Pass through observed property in node:", *output->node());
     // This is to propagate observed property through
     // all ops that doesn't require observation
     block_observed_values.insert(output);
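A plain-Python sketch of why the append-to-add rewrite is behavior-preserving within a single graph, assuming every later use is redirected to the new value, which is exactly what replaceAllUsesAfterNodeWith does at the IR level. This illustrates the idea only; it is not code from the PR.

def with_append(lst, x):
    # in-place form: later code keeps reading lst (%list in the IR)
    lst.append(x)
    return lst

def with_add(lst, x):
    # functional form: later code must read result (%result) instead
    result = lst + [x]
    return result

assert with_append([1, 2], 3) == with_add([1, 2], 3) == [1, 2, 3]

The two forms differ in aliasing (the functional form does not mutate the caller's list), which is why finalize.cpp restores the in-place append once observation and fusion no longer need the functional form.
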
6 changes: 6 additions & 0 deletions torch/csrc/jit/passes/quantization/insert_quant_dequant.cpp
@@ -1010,6 +1010,12 @@ c10::optional<std::vector<Value*>> getDequantizedInputs(Value* output) {
   // point
   bool is_dequantized = true;
   for (auto* input : inputs) {
+    GRAPH_DEBUG(
+        "checking if input:",
+        input->debugName(),
+        " in node:",
+        *input->node(),
+        " is quantized");
     is_dequantized &= input->node()->kind() == Symbol::aten("dequantize");
   }
   if (is_dequantized) {