[quant][graphmode] Support quantization for aten::append
#40743
@@ -1100,6 +1100,46 @@ void InsertObserversHelper::fillBoundaryValueMap(
  }
}

void makeAppendNonInplace(std::shared_ptr<Graph>& graph) {
  std::string append_pattern = R"IR(
Contributor:
Can you elaborate more on why directly supporting append is not possible?

Contributor (Author):
All current ops, including in-place ops, assume that the output will be consumed by the following ops. Breaking this assumption would require substantial changes/hacks, and I don't think it's worth the effort.

Contributor (Author):
By the way, if there are perf problems we can easily add a pass at the end that changes the add back to append.
graph(%list, %x):
     %ignore : Tensor[] = aten::append(%list, %x)
     return (%ignore) )IR";

  /* Rewrite the above pattern to
   std::string append_replacement = R"IR(
graph(%list, %x):
     %x_list : Tensor[] = prim::ListConstruct(%x)
     %result : Tensor[] = aten::add(%list, %x_list)
     return (%result) )IR";
   this is not supported by subgraph rewriter, so we'll do
   this manually.
  */

  GRAPH_DUMP("Before replace append", graph);
  const PatternInfo& append_pattern_info =
      PatternInfo::parse_from_str(append_pattern);
  const Graph& append_graph = *append_pattern_info.pattern_graph;
  const auto& append_vmap = append_pattern_info.vmap;
  const auto& matches = findPatternMatches(append_graph, *graph);
  for (const auto& match : matches) {
    auto append_node = match.values_map.at(append_vmap.at("ignore"))->node();
    Value* list_val = append_node->input(0);
    Value* x = append_node->input(1);
    WithInsertPoint ins(append_node);
    Node* x_list_node = graph->createList(TensorType::get(), {x});
    graph->insertNode(x_list_node);
    Node* add_node =
        graph->create(Symbol::aten("add"), {list_val, x_list_node->output()});
    graph->insertNode(add_node);
    add_node->output()->setType(ListType::ofTensors());
    list_val->replaceAllUsesAfterNodeWith(add_node, add_node->output());
    append_node->removeAllInputs();
    append_node->destroy();
  }
  GRAPH_DUMP("After replace append", graph);
}

void InsertObserversHelper::preprocess(
    Module& module,
    const std::string& method_name) {
|
|
@@ -1116,6 +1156,7 @@ void InsertObserversHelper::preprocess(
  // fuse decomposed linear into aten::linear
  FuseLinear(graph);
  replaceConvolutionWithAtenConv(graph);
  makeAppendNonInplace(graph);
}

void InsertObserversHelper::analyze(
|
|
@@ -1520,6 +1561,7 @@ void InsertObserversHelper::propagateObservedProperty(
        observed_values_.count(v) || block_observed_values.count(v);
  }
  if (all_observed) {
    GRAPH_DEBUG("Pass through observed property in node:", *output->node());
    // This is to propagate observed property through
    // all ops that doesn't require observation
    block_observed_values.insert(output);
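The author's reply in the first review thread mentions that, if the extra prim::ListConstruct and list copy ever show up as a perf problem, a follow-up pass could fold the functional add back into an in-place append once quantization is done. No such pass exists in this PR; the snippet below is only a rough sketch of what it might look like, reusing the same pattern-matching utilities as makeAppendNonInplace. The function name makeAddAppendAgain and the single-use guard on the ListConstruct are illustrative assumptions, not code from the repository.

// Rough sketch only (assumes the same includes and torch::jit namespace as
// insert_observers.cpp): turn the functional list add produced by
// makeAppendNonInplace back into an in-place aten::append.
void makeAddAppendAgain(std::shared_ptr<Graph>& graph) {
  std::string add_pattern = R"IR(
graph(%list, %x):
     %x_list : Tensor[] = prim::ListConstruct(%x)
     %result : Tensor[] = aten::add(%list, %x_list)
     return (%result) )IR";
  const PatternInfo& pattern_info = PatternInfo::parse_from_str(add_pattern);
  const auto& vmap = pattern_info.vmap;
  const auto& matches =
      findPatternMatches(*pattern_info.pattern_graph, *graph);
  for (const auto& match : matches) {
    Node* add_node = match.values_map.at(vmap.at("result"))->node();
    Node* x_list_node = match.values_map.at(vmap.at("x_list"))->node();
    // Only rewrite when the ListConstruct feeds this add and nothing else,
    // otherwise removing it would break other users.
    if (x_list_node->output()->uses().size() != 1) {
      continue;
    }
    Value* list_val = add_node->input(0);
    Value* x = x_list_node->input(0);
    WithInsertPoint ins(add_node);
    Node* append_node = graph->create(Symbol::aten("append"), {list_val, x});
    graph->insertNode(append_node);
    append_node->output()->setType(ListType::ofTensors());
    // Downstream users of the functional result read from the original
    // (now mutated in place) list again.
    add_node->output()->replaceAllUsesWith(list_val);
    add_node->destroy();
    x_list_node->destroy();
  }
}

Any real version of this rewrite would also have to confirm that reintroducing mutation is safe at that point in the graph, which is exactly the kind of aliasing concern the main pass avoids by staying functional.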
Reviewer:
What is the rule for list add? Do we assume that all the tensors in the list have a dequantize op prior to the add?

Author:
This is the transformed "append", as shown in the description. We will check whether the inputs are produced by dequantize to make sure the inputs are quantized.

Author:
The rule for list add: we check whether %list is an empty list. If it is, the pass-through list for %y is {%x}; otherwise it is {%list, %x}.

Reviewer:
How do you check that the existing list has only dequantized tensors? I see how you can do it for the input %x, but I don't follow how the check is done for %list.

Author:
There are two cases for %list:
1. One case is the same as %x: we check whether both %list and %x_list are produced by dequantize or not.
2. The other case is when the list is empty; it can then be considered as containing quantized tensors, and we only need to check whether %x_list is produced by dequantize.

Reviewer:
My question is about how we check whether %list is produced by dequantize or not. Do we iterate over all elements in the list?

Author:
No, we don't; we just check whether %list itself is produced by dequantize or not.
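To make the rule in this thread concrete, here is a minimal sketch of the check as described above, taken literally: for the rewritten %result = aten::add(%list, %x_list), only the nodes producing %list and %x_list are inspected, an empty list is treated as already quantized, and no iteration over list elements happens. The helper names (isProducedByDequantize, listAddInputsAreDequantized) are hypothetical; this is an illustration of the stated rule, not the code added by the PR.

// Hypothetical helpers illustrating the rule described in the thread
// (assumes the torch::jit namespace and the usual JIT IR includes).
bool isProducedByDequantize(Value* v) {
  return v->node()->kind() == Symbol::aten("dequantize");
}

bool isEmptyListConstruct(Value* v) {
  Node* n = v->node();
  return n->kind() == prim::ListConstruct && n->inputs().empty();
}

// For %result : Tensor[] = aten::add(%list, %x_list):
// %x_list must come from dequantize; %list is either an empty list
// (treated as containing quantized tensors) or must also come from
// dequantize. Neither check iterates over the elements of the list.
bool listAddInputsAreDequantized(Node* add_node) {
  Value* list = add_node->input(0);
  Value* x_list = add_node->input(1);
  if (!isProducedByDequantize(x_list)) {
    return false;
  }
  return isEmptyListConstruct(list) || isProducedByDequantize(list);
}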