Skip to content

Commit ba2cb07

Browse files
committed
[Quant] onednn backend switch to ideep new api without affecting performance
ghstack-source-id: 70704c0 Pull Request resolved: #91056
1 parent 9cf8434 commit ba2cb07

File tree

5 files changed

+65
-106
lines changed

5 files changed

+65
-106
lines changed

aten/src/ATen/native/quantized/cpu/OnednnUtils.h

Lines changed: 26 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -91,78 +91,53 @@ struct LinearPrimitiveCache : PrimitiveCache {
9191
struct ConvPrimitiveCache : PrimitiveCache {
9292
ConvPrimitiveCache() {}
9393

94-
ConvPrimitiveCache(const PrimitiveCacheKey& key,
95-
const ConvDesc& conv_desc,
96-
const ideep::tensor& bias,
97-
const ideep::attr_t bias_attr) {
94+
ConvPrimitiveCache(
95+
const PrimitiveCacheKey& key,
96+
const ConvParams& params,
97+
const ideep::tensor& bias) {
9898
this->key = key;
99-
this->primitive_desc = conv_desc;
100-
this->primitive = Conv(this->primitive_desc);
101-
// Construct tensor of input zero point
102-
ideep::tensor::desc input_zp_desc = {{1}, ideep::data_type::s32, {1}};
103-
this->input_zp_tensor.init(input_zp_desc, ideep::engine::cpu_engine());
104-
auto zp_data_ptr = reinterpret_cast<int32_t *>(this->input_zp_tensor.get_data_handle());
105-
zp_data_ptr[0] = std::get<InputZeroPoint>(key);
106-
// Construct expected bias
107-
this->expected_bias = bias.reorder_if_differ_in(conv_desc.bias_desc(), bias_attr);
99+
this->params = params;
100+
if (!bias.is_empty()) {
101+
this->expected_bias =
102+
bias.reorder_if_differ_in(params.pd.bias_desc(), params.bias_attr);
103+
}
108104
}
109105

110-
ConvDesc primitive_desc;
111-
Conv primitive;
112-
ideep::tensor input_zp_tensor;
113106
ideep::tensor expected_bias;
107+
ConvParams params;
114108

115-
inline ConvDesc& get_primitive_desc() {
116-
return primitive_desc;
117-
}
118-
119-
inline Conv& get_primitive() {
120-
return primitive;
121-
}
122-
123-
inline ideep::tensor& get_src_zp_tensor() {
124-
return input_zp_tensor;
109+
ConvParams& get_params() {
110+
return params;
125111
}
126112

127-
inline ideep::tensor& get_bias() {
113+
ideep::tensor& get_bias() {
128114
return expected_bias;
129115
}
130116
};
131117

132118
struct DeconvPrimitiveCache : PrimitiveCache {
133119
DeconvPrimitiveCache() {}
134120

135-
DeconvPrimitiveCache(const PrimitiveCacheKey& key,
136-
const DeconvDesc& deconv_desc,
137-
const ideep::tensor& bias,
138-
const ideep::attr_t bias_attr,
139-
const ideep::tensor& input_zero_point) {
121+
DeconvPrimitiveCache(
122+
const PrimitiveCacheKey& key,
123+
const DeconvParams& params,
124+
const ideep::tensor& bias) {
140125
this->key = key;
141-
this->primitive_desc = deconv_desc;
142-
this->primitive = Deconv(this->primitive_desc);
143-
this->input_zp_tensor = std::move(input_zero_point);
144-
// Construct expected bias
145-
this->expected_bias = bias.reorder_if_differ_in(deconv_desc.bias_desc(), bias_attr);
126+
this->params = params;
127+
if (!bias.is_empty()) {
128+
this->expected_bias =
129+
bias.reorder_if_differ_in(params.pd.bias_desc(), params.bias_attr);
130+
}
146131
}
147132

148-
DeconvDesc primitive_desc;
149-
Deconv primitive;
150-
ideep::tensor input_zp_tensor;
133+
DeconvParams params;
151134
ideep::tensor expected_bias;
152135

153-
inline DeconvDesc& get_primitive_desc() {
154-
return primitive_desc;
155-
}
156-
157-
inline Deconv& get_primitive() {
158-
return primitive;
159-
}
160-
161-
inline ideep::tensor& get_src_zp_tensor() {
162-
return input_zp_tensor;
136+
DeconvParams& get_params() {
137+
return params;
163138
}
164139

165-
inline ideep::tensor& get_bias() {
140+
ideep::tensor& get_bias() {
166141
return expected_bias;
167142
}
168143
};

aten/src/ATen/native/quantized/cpu/qconv.cpp

Lines changed: 19 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1249,7 +1249,6 @@ at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply_impl(
12491249
// Scales of ONEDNN and PyTorch are reciprocal
12501250
const ideep::scale_t& src_scales = ideep::scale_t(1, 1.0/input_scale);
12511251
const ideep::scale_t& weights_scales = weights.get_scale();
1252-
int64_t scale_size = weights_scales.size();
12531252
double inv_output_scale = 1.0/output_scale;
12541253
const ideep::zero_point_t src_zero_points = ideep::zero_point_t(1, input_zp);
12551254
const ideep::zero_point_t dst_zero_points = ideep::zero_point_t(1, output_zero_point);
@@ -1274,29 +1273,25 @@ at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply_impl(
12741273
ideep::convolution_transpose_forward::prepare(
12751274
params, src, weights, b, dst_dims, dst,
12761275
strides, padding_l, padding_r, dilates, groups(),
1277-
src_scales, weights_scales, ideep::scale_t(scale_size, inv_output_scale),
1276+
src_scales, weights_scales, ideep::scale_t(1, inv_output_scale),
12781277
src_zero_points, dst_zero_points, op_attr,
12791278
dnnl::algorithm::deconvolution_direct,
12801279
dnnl::prop_kind::forward_inference,
12811280
ideep::u8s8, ideep::engine::cpu_engine());
1282-
get_deconv_cache() = DeconvPrimitiveCache(
1283-
cache_key, params.pd, b, params.bias_attr, params.input_zero_point);
1284-
onednn_utils::try_reorder(
1285-
weights, (ideep::tensor::desc)params.pd.weights_desc(), weights_scales);
1281+
get_deconv_cache() = DeconvPrimitiveCache(cache_key, params, b);
1282+
weights = weights.reorder_if_differ_in(params.pd.weights_desc());
12861283
});
12871284
if (get_deconv_cache().hit(cache_key)) {
1288-
Deconv& primitive = get_deconv_cache().get_primitive();
1289-
DeconvDesc& pd = get_deconv_cache().get_primitive_desc();
1290-
auto& src_zp_tensor = get_deconv_cache().get_src_zp_tensor();
1285+
DeconvParams& params = get_deconv_cache().get_params();
12911286
auto& expected_bias = get_deconv_cache().get_bias();
1292-
ideep::convolution_transpose_forward::compute(
1293-
pd, primitive, src, weights, expected_bias, dst, src_zp_tensor, groups());
1287+
ideep::convolution_transpose_forward::compute<false, false>(
1288+
params, src, weights, expected_bias, dst);
12941289
} else {
1295-
ideep::convolution_transpose_forward::compute_v2(
1290+
ideep::convolution_transpose_forward::compute(
12961291
src, weights, b, dst_dims, dst,
12971292
strides, padding_l, padding_r, dilates,
12981293
groups(), src_scales, weights_scales,
1299-
ideep::scale_t(scale_size, inv_output_scale),
1294+
ideep::scale_t(1, inv_output_scale),
13001295
src_zero_points, dst_zero_points, op_attr,
13011296
dnnl::algorithm::deconvolution_direct,
13021297
dnnl::prop_kind::forward_inference,
@@ -1306,42 +1301,32 @@ at::Tensor PackedConvWeightsOnednn<kSpatialDim>::apply_impl(
13061301
PrimitiveCacheKey cache_key = std::make_tuple(
13071302
input_scale, input_zp, src_dims, output_scale, output_zero_point, num_threads);
13081303
c10::call_once(*cache_initialized_flag, [&](){
1309-
src.set_zero_point(src_zero_points);
1310-
dst.set_zero_point(dst_zero_points);
13111304
ConvParams params;
13121305
ideep::convolution_forward::prepare(
13131306
params, src, weights, b, dst_dims, dst,
13141307
strides, dilates, padding_l, padding_r, groups(),
1315-
src_scales, weights_scales, ideep::scale_t(scale_size, inv_output_scale),
1308+
src_scales, weights_scales, ideep::scale_t(1, inv_output_scale),
1309+
src_zero_points, dst_zero_points,
13161310
op_attr, dnnl::algorithm::convolution_direct,
13171311
dnnl::prop_kind::forward_inference,
13181312
ideep::u8s8, ideep::engine::cpu_engine());
1319-
get_conv_cache() = ConvPrimitiveCache(cache_key, params.pd, b, params.bias_attr);
1320-
onednn_utils::try_reorder(
1321-
weights, (ideep::tensor::desc)params.pd.weights_desc(), weights_scales);
1313+
get_conv_cache() = ConvPrimitiveCache(cache_key, params, b);
1314+
weights = weights.reorder_if_differ_in(params.pd.weights_desc());
13221315
});
13231316
// If hit, use cached data. If miss, fall back to normal path.
13241317
if (get_conv_cache().hit(cache_key)) {
1325-
ConvDesc& pd = get_conv_cache().get_primitive_desc();
1326-
Conv& primitive = get_conv_cache().get_primitive();
1327-
auto& src_zp_tensor = get_conv_cache().get_src_zp_tensor();
1318+
auto& params = get_conv_cache().get_params();
13281319
auto& expected_bias = get_conv_cache().get_bias();
1329-
ideep::convolution_forward::compute(
1330-
pd, primitive, src, weights, expected_bias, dst, src_zp_tensor, groups());
1320+
ideep::convolution_forward::compute<false, false>(params, src, weights, expected_bias, dst);
13311321
} else {
1332-
src.set_zero_point(src_zero_points);
1333-
dst.set_zero_point(dst_zero_points);
1334-
ConvParams params;
1335-
ideep::convolution_forward::prepare(
1336-
params, src, weights, b, dst_dims, dst,
1322+
ideep::convolution_forward::compute(
1323+
src, weights, b, dst_dims, dst,
13371324
strides, dilates, padding_l, padding_r, groups(),
1338-
src_scales, weights_scales, ideep::scale_t(scale_size, inv_output_scale),
1339-
op_attr, dnnl::algorithm::convolution_direct,
1325+
src_scales, weights_scales, ideep::scale_t(1, inv_output_scale),
1326+
src_zero_points, dst_zero_points, op_attr,
1327+
dnnl::algorithm::convolution_direct,
13401328
dnnl::prop_kind::forward_inference,
13411329
ideep::u8s8, ideep::engine::cpu_engine());
1342-
onednn_utils::try_reorder(
1343-
weights, (ideep::tensor::desc)params.pd.weights_desc(), weights_scales);
1344-
ideep::convolution_forward::compute(params, src, weights, b, dst);
13451330
}
13461331
}
13471332
return output;

aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,8 @@ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeightsOnednn<
408408
ideep::tag w_tag = ideep::tag::any;
409409
const bool with_groups = groups > 1;
410410
if (transpose) {
411-
w_desc = ideep::convolution_transpose_forward::expected_weights_desc(
411+
// template args: <(src/dst) is_channels_last, transposed>
412+
w_desc = ideep::convolution_transpose_forward::expected_weights_desc<true, false>(
412413
dims, dnnl::memory::data_type::s8,
413414
strides, padding_l, padding_r, dilates, groups,
414415
dnnl::algorithm::deconvolution_direct, dnnl::prop_kind::forward_inference,
@@ -419,15 +420,14 @@ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeightsOnednn<
419420
dims_giohw = with_groups ? ideep::utils::group_dims(dims_iohw, groups) : dims_iohw;
420421
std::vector<int64_t> perms(dims_giohw.size(), 0); // for permutation of weight
421422
std::iota(perms.begin(), perms.end(), 0);
422-
w_desc = w_desc.transpose(with_groups, with_groups + 1);
423423
std::swap(perms[with_groups], perms[with_groups + 1]);
424424
weight_copy = weight.reshape(dims_giohw).permute(c10::IntArrayRef(perms)).clone();
425425
} else {
426426
w_desc = ideep::convolution_forward::expected_weights_desc(
427427
dims, dnnl::memory::data_type::s8,
428428
strides, padding_l, padding_r, dilates, groups,
429429
dnnl::algorithm::convolution_direct, dnnl::prop_kind::forward_inference,
430-
dnnl::memory::data_type::u8, ideep::dims(), op_attr);
430+
dnnl::memory::data_type::u8, ideep::dims(), op_attr, /*is_channels_last=*/true);
431431
weight_copy = weight.clone();
432432
}
433433
if (with_groups) {

aten/src/ATen/native/quantized/cpu/qlinear.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -844,19 +844,20 @@ at::Tensor PackedLinearWeightsOnednn::apply_impl(
844844
c10::call_once(*cache_initialized_flag, [&](){
845845
LinearParams params;
846846
ideep::matmul_forward::prepare</*is_dynamic=*/false>(
847-
params, x, w, b, y, 1.0f, 1.0f,
847+
params, x, w, b, y,
848848
src_scales, weights_scales, dst_scales,
849-
src_zero_point, dst_zero_point, op_attr);
850-
get_cache() = LinearPrimitiveCache(cache_key, params);
851-
onednn_utils::try_reorder(
852-
w, (ideep::tensor::desc)params.pd.weights_desc(), weights_scales);
849+
src_zero_point, dst_zero_point, 1.0f, 1.0f, op_attr);
850+
get_cache() = LinearPrimitiveCache(cache_key, params, b);
851+
w = w.reorder_if_differ_in(params.pd.weights_desc());
853852
});
854853
if (get_cache().hit(cache_key)) {
855854
LinearParams& params = get_cache().get_param();
856-
ideep::matmul_forward::compute(params, x, w, b, y);
855+
auto& expected_bias = get_cache().get_expected_bias();
856+
ideep::matmul_forward::compute<false, false>(params, x, w, expected_bias, y);
857857
} else {
858-
ideep::matmul_forward::compute_v2(x, w, b, y, 1.0f, 1.0f, src_scales, weights_scales,
859-
dst_scales, src_zero_point, dst_zero_point, op_attr);
858+
ideep::matmul_forward::compute(x, w, b, y, src_scales, weights_scales,
859+
dst_scales, src_zero_point, dst_zero_point,
860+
1.0f, 1.0f, op_attr);
860861
}
861862
auto out_sizes = input.sizes().vec();
862863
out_sizes.back() = N;

aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -567,22 +567,20 @@ at::Tensor PackedLinearWeightsOnednn::apply_dynamic_impl(
567567
c10::call_once(*cache_initialized_flag, [&](){
568568
LinearParams params;
569569
ideep::matmul_forward::prepare</*is_dynamic=*/true>(
570-
params, x, w, b, y, 1.0f, 1.0f,
570+
params, x, w, b, y,
571571
src_scales, weights_scales, ideep::scale_t(),
572-
src_zero_point, ideep::zero_point_t(), op_attr);
572+
src_zero_point, ideep::zero_point_t(), 1.0f, 1.0f, op_attr);
573573
get_cache() = LinearPrimitiveCache(cache_key, params);
574-
onednn_utils::try_reorder(
575-
w, (ideep::tensor::desc)params.pd.weights_desc(), weights_scales);
574+
w = w.reorder_if_differ_in(params.pd.weights_desc());
576575
});
577576
if (get_cache().hit_dynamic(cache_key)) {
578577
LinearParams& params = get_cache().get_param();
579-
ideep::matmul_forward::compute_dynamic(
580-
params, x, w, b, y, 1.0f, 1.0f, src_scales, weights_scales,
581-
ideep::scale_t(), src_zero_point, ideep::zero_point_t());
578+
ideep::matmul_forward::compute(params, x, w, b, y, src_scales, src_zero_point);
582579
} else {
583-
ideep::matmul_forward::compute_v2(x, w, b, y, 1.0f, 1.0f,
584-
src_scales, weights_scales, ideep::scale_t(),
585-
src_zero_point, ideep::zero_point_t(), op_attr);
580+
ideep::matmul_forward::compute(x, w, b, y,
581+
src_scales, weights_scales, ideep::scale_t(),
582+
src_zero_point, ideep::zero_point_t(),
583+
1.0f, 1.0f, op_attr);
586584
}
587585
auto out_sizes = input.sizes().vec();
588586
out_sizes.back() = w.get_dim(1);

0 commit comments

Comments
 (0)