Commit 43335cd

jerryzh168 authored and facebook-github-bot committed
Fold quantize op into module (#25625)
Summary:
Pull Request resolved: #25625

We want to fold the quantize op for weights/bias into the module, to avoid quantizing the weights on the fly.

Test Plan:
python test/test_jit.py

Imported from OSS

Differential Revision: D17208889

fbshipit-source-id: 1854b8953b065855d210bc1166533c08ca264354
1 parent 27b5a6c commit 43335cd
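In effect, the pass turns a module that re-quantizes its weight on every forward call into one that reads a precomputed quantized buffer. Below is a minimal eager-mode sketch of the before/after behavior, based on the toy module from the new test; the Before/After class names are illustrative only, and torch.quantize_linear is the per-tensor quantize API this commit uses.

import torch

class Before(torch.nn.Module):
    # Quantizes the weight on the fly, once per forward call.
    def __init__(self):
        super(Before, self).__init__()
        self.weight = torch.nn.Parameter(torch.tensor([2], dtype=torch.float))

    def forward(self, x):
        return torch.quantize_linear(self.weight, 2.0, 0, torch.quint8)

class After(torch.nn.Module):
    # What the folded module computes: the weight is quantized once,
    # stored as the '_quantized_weight' buffer, and only read in forward.
    def __init__(self):
        super(After, self).__init__()
        weight = torch.tensor([2], dtype=torch.float)
        self.register_buffer(
            '_quantized_weight',
            torch.quantize_linear(weight, 2.0, 0, torch.quint8))

    def forward(self, x):
        return self._quantized_weight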

File tree

4 files changed (+70, -0 lines)


test/test_jit.py

Lines changed: 17 additions & 0 deletions
@@ -1331,6 +1331,23 @@ def test_fuse_linear(self):
         torch._C._jit_pass_fuse_linear(graph)
         FileCheck().run(input_str, graph)
 
+    @_tmp_donotuse_dont_inline_everything
+    def test_fold_quantize(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+                self.weight = torch.nn.Parameter(torch.tensor([2], dtype=torch.float))
+
+            def forward(self, x):
+                return torch.quantize_linear(self.weight, 2.0, 0, torch.quint8)
+
+        m = torch.jit.script(M())
+        torch._C._jit_pass_fold_quantize(m._c, 'forward')
+        self.assertTrue(m._c._has_attribute('_quantized_weight'))
+        FileCheck().check_not('GetAttr[name="weight"]') \
+                   .check('GetAttr[name="_quantized_weight"]') \
+                   .run(m._c._get_method('forward').graph)
+
     def test_pattern_based_rewrite(self):
         # mul(mul(mul(mul(x,y),z),x),y) --> mul(mul(mulmul(x,y,z), x), y) -->
         # --> mulmul(mulmul(x,y,z), x, y)

torch/csrc/jit/init.cpp

Lines changed: 4 additions & 0 deletions
@@ -168,6 +168,10 @@ void initJITBindings(PyObject* module) {
           [](std::shared_ptr<Graph>& g) { return QuantFusion(g); })
       .def("_jit_pass_fold_convbn", &FoldConvBatchNorm2d)
       .def("_jit_pass_fuse_linear", &FuseLinear)
+      .def("_jit_pass_fold_quantize",
+           [](script::Module& module, const std::string& method_name) {
+             FoldQuantizeCallIntoBuffer(module, method_name);
+           })
       .def(
           "_jit_pass_quantlint",
           [](std::shared_ptr<Graph>& g) { return QuantLinting(g); })
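With this binding registered, the pass can be driven from Python on a scripted module, mirroring the new test above (M here is the toy module defined in test_fold_quantize):

import torch

m = torch.jit.script(M())  # M as defined in test_fold_quantize
torch._C._jit_pass_fold_quantize(m._c, 'forward')

# The 'forward' graph now reads the precomputed buffer via
# prim::GetAttr[name="_quantized_weight"] instead of quantizing 'weight'.
print(m._c._get_method('forward').graph)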

torch/csrc/jit/passes/quantization.cpp

Lines changed: 38 additions & 0 deletions
@@ -825,5 +825,43 @@ graph(%self, %x):
     }
   }
 }
+
+void FoldQuantizeCallIntoBuffer(
+    script::Module& module,
+    const std::string& method_name) {
+  // TODO: extra filter on scale/zero_point/dtype to make sure they are Constant
+  const std::string pattern = R"(
+graph(%self, %scale, %zero_point, %dtype):
+   %weight = prim::GetAttr[name="weight"](%self)
+   %weight_quant = aten::quantize_linear(%weight, %scale, %zero_point, %dtype)
+   return (%weight_quant))";
+  Graph pattern_graph;
+  std::unordered_map<std::string, Value*> vmap;
+  script::parseIR(pattern, &pattern_graph, vmap);
+  auto method = module.get_method(method_name);
+  auto graph = method.graph();
+  auto matches = findPatternMatches(pattern_graph, *graph);
+  for (const auto& match : matches) {
+    auto match_vmap = match.values_map;
+    auto* weight = match_vmap.at(vmap.at("weight"));
+    auto float_weight = module.get_parameter("weight").variable_data();
+    auto scale = toIValue(match_vmap.at(vmap.at("scale"))).value().toDouble();
+    auto zero_point =
+        toIValue(match_vmap.at(vmap.at("zero_point"))).value().toInt();
+    auto dtype =
+        toIValue(match_vmap.at(vmap.at("dtype"))).value().toScalarType();
+    module.register_buffer(
+        "_quantized_weight",
+        at::quantize_linear(float_weight, scale, zero_point, dtype));
+  }
+
+  std::string replacement = R"(
+graph(%self, %scale, %zero_point, %dtype):
+   %weight_quant = prim::GetAttr[name="_quantized_weight"](%self)
+   return (%weight_quant))";
+  SubgraphRewriter rewriter;
+  rewriter.RegisterRewritePattern(pattern, replacement);
+  rewriter.runOnGraph(graph);
+}
 } // namespace jit
 } // namespace torch
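Note the TODO in the pass: the fold is only sound when scale, zero_point, and dtype are constants in the graph, because the pass evaluates them once (via toIValue) and bakes the result into the buffer. A small sketch of the equivalence being relied on, using the values from the test:

import torch

w = torch.tensor([2], dtype=torch.float)

# What the pass stores in the '_quantized_weight' buffer at fold time ...
folded = torch.quantize_linear(w, 2.0, 0, torch.quint8)

# ... is exactly what the original forward would compute on every call,
# since scale (2.0), zero_point (0), and dtype (quint8) never change.
on_the_fly = torch.quantize_linear(w, 2.0, 0, torch.quint8)
assert torch.equal(folded.int_repr(), on_the_fly.int_repr())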

torch/csrc/jit/passes/quantization.h

Lines changed: 11 additions & 0 deletions
@@ -105,5 +105,16 @@ TORCH_API void QuantFusion(std::shared_ptr<Graph>& graph);
  */
 TORCH_API void FoldConvBatchNorm2d(const script::Module& module);
 
+/** \brief Fold quantize function call into module.
+ *
+ * For the graph in the specified method of the module, if we find a
+ * quantize_linear call on an attribute ("weight") of the module, we quantize
+ * the attribute directly, register a new buffer "_quantized_weight" on the
+ * module, remove the quantize_linear call, and replace uses of the weight
+ * with "_quantized_weight".
+ */
+TORCH_API void FoldQuantizeCallIntoBuffer(script::Module& module, const std::string& method_name);
+
+
 } // namespace jit
 } // namespace torch
