Skip to content

Commit b15d914

Browse files
jerryzh168facebook-github-bot
authored andcommitted
Remove InsertQuantDeQuantNode (#25000)
Summary: Pull Request resolved: #25000 Remove deprecated insert_quantdequant pass Test Plan: . Imported from OSS Differential Revision: D17001139 fbshipit-source-id: 5ecabdff84598fe21f24ea827b615e697081ee53
1 parent fbb88f5 commit b15d914

File tree

3 files changed

+0
-474
lines changed

3 files changed

+0
-474
lines changed

torch/csrc/jit/init.cpp

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -156,52 +156,6 @@ void initJITBindings(PyObject* module) {
156156
.def(
157157
"_jit_pass_quant_fusion",
158158
[](std::shared_ptr<Graph>& g) { return QuantFusion(g); })
159-
.def(
160-
"_jit_pass_insert_quantdequant",
161-
[](const script::Module& moduleObj,
162-
const std::string& methodName,
163-
py::dict& pyQParamDict) {
164-
if (!pyQParamDict.size()) {
165-
return;
166-
}
167-
168-
auto qparam_dict = py::cast<std::unordered_map<
169-
std::string,
170-
std::tuple<std::string, float, int>>>(pyQParamDict);
171-
return InsertQuantDequantNodes(moduleObj, methodName, qparam_dict);
172-
})
173-
.def(
174-
"_jit_pass_insert_quantdequant_for_weight_bias",
175-
[](const script::Module& moduleObj,
176-
const std::string& method_name,
177-
const std::string& param_name,
178-
py::function pyGetQParamFunc) {
179-
// For different static params we pass different getQParamFunc via
180-
// same interface exposed by the quantizer.
181-
if (param_name == std::string("weight")) {
182-
auto getQParamFunc =
183-
py::cast<std::function<std::tuple<std::string, float, int>(
184-
at::Tensor)>>(pyGetQParamFunc);
185-
InsertQuantDequantNodesForParam(
186-
moduleObj,
187-
method_name,
188-
param_name,
189-
getQParamFunc,
190-
at::ScalarType::QInt8);
191-
} else if (param_name == std::string("bias")) {
192-
auto getQParamFunc =
193-
py::cast<std::function<std::tuple<std::string, float, int>(
194-
float, float)>>(pyGetQParamFunc);
195-
InsertQuantDequantNodesForParam(
196-
moduleObj,
197-
method_name,
198-
param_name,
199-
getQParamFunc,
200-
at::ScalarType::QInt32);
201-
} else {
202-
TORCH_CHECK(false, "Invalid Param Name");
203-
}
204-
})
205159
.def(
206160
"_jit_pass_quantlint",
207161
[](std::shared_ptr<Graph>& g) { return QuantLinting(g); })

0 commit comments

Comments
 (0)