@@ -1249,7 +1249,6 @@ at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply_impl(
12491249 // Scales of ONEDNN and PyTorch are reciprocal
12501250 const ideep::scale_t & src_scales = ideep::scale_t (1 , 1.0 /input_scale);
12511251 const ideep::scale_t & weights_scales = weights.get_scale ();
1252- int64_t scale_size = weights_scales.size ();
12531252 double inv_output_scale = 1.0 /output_scale;
12541253 const ideep::zero_point_t src_zero_points = ideep::zero_point_t (1 , input_zp);
12551254 const ideep::zero_point_t dst_zero_points = ideep::zero_point_t (1 , output_zero_point);
@@ -1274,29 +1273,25 @@ at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply_impl(
12741273 ideep::convolution_transpose_forward::prepare (
12751274 params, src, weights, b, dst_dims, dst,
12761275 strides, padding_l, padding_r, dilates, groups (),
1277- src_scales, weights_scales, ideep::scale_t (scale_size , inv_output_scale),
1276+ src_scales, weights_scales, ideep::scale_t (1 , inv_output_scale),
12781277 src_zero_points, dst_zero_points, op_attr,
12791278 dnnl::algorithm::deconvolution_direct,
12801279 dnnl::prop_kind::forward_inference,
12811280 ideep::u8s8, ideep::engine::cpu_engine ());
1282- get_deconv_cache () = DeconvPrimitiveCache (
1283- cache_key, params.pd , b, params.bias_attr , params.input_zero_point );
1284- onednn_utils::try_reorder (
1285- weights, (ideep::tensor::desc)params.pd .weights_desc (), weights_scales);
1281+ get_deconv_cache () = DeconvPrimitiveCache (cache_key, params, b);
1282+ weights = weights.reorder_if_differ_in (params.pd .weights_desc ());
12861283 });
12871284 if (get_deconv_cache ().hit (cache_key)) {
1288- Deconv& primitive = get_deconv_cache ().get_primitive ();
1289- DeconvDesc& pd = get_deconv_cache ().get_primitive_desc ();
1290- auto & src_zp_tensor = get_deconv_cache ().get_src_zp_tensor ();
1285+ DeconvParams& params = get_deconv_cache ().get_params ();
12911286 auto & expected_bias = get_deconv_cache ().get_bias ();
1292- ideep::convolution_transpose_forward::compute (
1293- pd, primitive, src, weights, expected_bias, dst, src_zp_tensor, groups () );
1287+ ideep::convolution_transpose_forward::compute< false , false > (
1288+ params, src, weights, expected_bias, dst);
12941289 } else {
1295- ideep::convolution_transpose_forward::compute_v2 (
1290+ ideep::convolution_transpose_forward::compute (
12961291 src, weights, b, dst_dims, dst,
12971292 strides, padding_l, padding_r, dilates,
12981293 groups (), src_scales, weights_scales,
1299- ideep::scale_t (scale_size , inv_output_scale),
1294+ ideep::scale_t (1 , inv_output_scale),
13001295 src_zero_points, dst_zero_points, op_attr,
13011296 dnnl::algorithm::deconvolution_direct,
13021297 dnnl::prop_kind::forward_inference,
@@ -1306,42 +1301,32 @@ at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply_impl(
13061301 PrimitiveCacheKey cache_key = std::make_tuple (
13071302 input_scale, input_zp, src_dims, output_scale, output_zero_point, num_threads);
13081303 c10::call_once (*cache_initialized_flag, [&](){
1309- src.set_zero_point (src_zero_points);
1310- dst.set_zero_point (dst_zero_points);
13111304 ConvParams params;
13121305 ideep::convolution_forward::prepare (
13131306 params, src, weights, b, dst_dims, dst,
13141307 strides, dilates, padding_l, padding_r, groups (),
1315- src_scales, weights_scales, ideep::scale_t (scale_size, inv_output_scale),
1308+ src_scales, weights_scales, ideep::scale_t (1 , inv_output_scale),
1309+ src_zero_points, dst_zero_points,
13161310 op_attr, dnnl::algorithm::convolution_direct,
13171311 dnnl::prop_kind::forward_inference,
13181312 ideep::u8s8, ideep::engine::cpu_engine ());
1319- get_conv_cache () = ConvPrimitiveCache (cache_key, params.pd , b, params.bias_attr );
1320- onednn_utils::try_reorder (
1321- weights, (ideep::tensor::desc)params.pd .weights_desc (), weights_scales);
1313+ get_conv_cache () = ConvPrimitiveCache (cache_key, params, b);
1314+ weights = weights.reorder_if_differ_in (params.pd .weights_desc ());
13221315 });
13231316 // If hit, use cached data. If miss, fall back to normal path.
13241317 if (get_conv_cache ().hit (cache_key)) {
1325- ConvDesc& pd = get_conv_cache ().get_primitive_desc ();
1326- Conv& primitive = get_conv_cache ().get_primitive ();
1327- auto & src_zp_tensor = get_conv_cache ().get_src_zp_tensor ();
1318+ auto & params = get_conv_cache ().get_params ();
13281319 auto & expected_bias = get_conv_cache ().get_bias ();
1329- ideep::convolution_forward::compute (
1330- pd, primitive, src, weights, expected_bias, dst, src_zp_tensor, groups ());
1320+ ideep::convolution_forward::compute<false , false >(params, src, weights, expected_bias, dst);
13311321 } else {
1332- src.set_zero_point (src_zero_points);
1333- dst.set_zero_point (dst_zero_points);
1334- ConvParams params;
1335- ideep::convolution_forward::prepare (
1336- params, src, weights, b, dst_dims, dst,
1322+ ideep::convolution_forward::compute (
1323+ src, weights, b, dst_dims, dst,
13371324 strides, dilates, padding_l, padding_r, groups (),
1338- src_scales, weights_scales, ideep::scale_t (scale_size, inv_output_scale),
1339- op_attr, dnnl::algorithm::convolution_direct,
1325+ src_scales, weights_scales, ideep::scale_t (1 , inv_output_scale),
1326+ src_zero_points, dst_zero_points, op_attr,
1327+ dnnl::algorithm::convolution_direct,
13401328 dnnl::prop_kind::forward_inference,
13411329 ideep::u8s8, ideep::engine::cpu_engine ());
1342- onednn_utils::try_reorder (
1343- weights, (ideep::tensor::desc)params.pd .weights_desc (), weights_scales);
1344- ideep::convolution_forward::compute (params, src, weights, b, dst);
13451330 }
13461331 }
13471332 return output;
0 commit comments