@@ -156,52 +156,6 @@ void initJITBindings(PyObject* module) {
156156 .def (
157157 " _jit_pass_quant_fusion" ,
158158 [](std::shared_ptr<Graph>& g) { return QuantFusion (g); })
159- .def (
160- " _jit_pass_insert_quantdequant" ,
161- [](const script::Module& moduleObj,
162- const std::string& methodName,
163- py::dict& pyQParamDict) {
164- if (!pyQParamDict.size ()) {
165- return ;
166- }
167-
168- auto qparam_dict = py::cast<std::unordered_map<
169- std::string,
170- std::tuple<std::string, float , int >>>(pyQParamDict);
171- return InsertQuantDequantNodes (moduleObj, methodName, qparam_dict);
172- })
173- .def (
174- " _jit_pass_insert_quantdequant_for_weight_bias" ,
175- [](const script::Module& moduleObj,
176- const std::string& method_name,
177- const std::string& param_name,
178- py::function pyGetQParamFunc) {
179- // For different static params we pass different getQParamFunc via
180- // same interface exposed by the quantizer.
181- if (param_name == std::string (" weight" )) {
182- auto getQParamFunc =
183- py::cast<std::function<std::tuple<std::string, float , int >(
184- at::Tensor)>>(pyGetQParamFunc);
185- InsertQuantDequantNodesForParam (
186- moduleObj,
187- method_name,
188- param_name,
189- getQParamFunc,
190- at::ScalarType::QInt8);
191- } else if (param_name == std::string (" bias" )) {
192- auto getQParamFunc =
193- py::cast<std::function<std::tuple<std::string, float , int >(
194- float , float )>>(pyGetQParamFunc);
195- InsertQuantDequantNodesForParam (
196- moduleObj,
197- method_name,
198- param_name,
199- getQParamFunc,
200- at::ScalarType::QInt32);
201- } else {
202- TORCH_CHECK (false , " Invalid Param Name" );
203- }
204- })
205159 .def (
206160 " _jit_pass_quantlint" ,
207161 [](std::shared_ptr<Graph>& g) { return QuantLinting (g); })
0 commit comments