@@ -19,14 +19,16 @@ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeight<
     c10::optional<at::Tensor> bias,
     torch::List<int64_t> stride,
     torch::List<int64_t> padding,
+    torch::List<int64_t> output_padding,
     torch::List<int64_t> dilation,
-    int64_t groups) {
+    int64_t groups,
+    bool transpose) {
+  TORCH_CHECK(!transpose, "FBGEMM doesn't support conv_transpose yet.");
   TORCH_CHECK(
       weight.ndimension() == kSpatialDim + 2,
       "Weights are expected to have ",
       kSpatialDim + 2,
       " dimensions");
-
   TORCH_CHECK(
       stride.size() == kSpatialDim,
       "stride should contain ",
@@ -45,7 +47,8 @@ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeight<
       " elements for ",
       kSpatialDim,
       "D convolution.");
-  const int output_channels = weight.size(0);
+  const int output_channels_idx = transpose ? 1 : 0;
+  const int output_channels = weight.size(output_channels_idx);
   const int input_channels_per_group = weight.size(1);
   const int kernel_d = kSpatialDim == 2 ? 1 : weight.size(2);
   const int kernel_h = weight.size(kSpatialDim);
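
Note on the index flip above: a Conv weight is laid out as {out_c, in_c/groups, kH, kW}, while a ConvTranspose weight is {in_c, out_c/groups, kH, kW}, so the output-channel count moves from dim 0 to dim 1. A minimal illustration (shapes hypothetical):

    // Illustrative shapes only: out_c sits at dim 0 for conv,
    // at dim 1 (per group) for transposed conv.
    at::Tensor w_conv = at::randn({8, 4, 3, 3});    // Conv2d: {out_c, in_c/groups, kH, kW}
    at::Tensor w_deconv = at::randn({4, 8, 3, 3});  // ConvTranspose2d: {in_c, out_c/groups, kH, kW}
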
@@ -143,8 +146,10 @@ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeight<
       bias_contig,
       stride,
       padding,
+      output_padding,
       dilation,
       groups,
+      transpose,
       col_offsets,
       kSpatialDim == 2 ? std::vector<int64_t>{kernel_h, kernel_w}
                        : std::vector<int64_t>{kernel_d, kernel_h, kernel_w},
@@ -166,28 +171,42 @@ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeightsQnnp<
     c10::optional<at::Tensor> bias_in,
     torch::List<int64_t> stride,
     torch::List<int64_t> padding,
+    torch::List<int64_t> output_padding,
     torch::List<int64_t> dilation,
-    int64_t groups) {
+    int64_t groups,
+    bool transpose) {
+  TORCH_CHECK(kSpatialDim == 2, "QNNPACK packing only supports 2D ",
+      "convolution.");
+  TORCH_CHECK(
+      weight.ndimension() == kSpatialDim + 2,
+      "quantized::conv_prepack (qnnpack): Weights are expected to have ",
+      kSpatialDim + 2, " dimensions");
   TORCH_CHECK(
-      weight.ndimension() == 4,
-      "quantized::conv2d_prepack (qnnpack): Weights are expected to have 4 "
-      "dimensions");
+      stride.size() == kSpatialDim,
+      "quantized::conv_prepack (qnnpack): ",
+      kSpatialDim, "D convolution expects stride to have ",
+      kSpatialDim, " elements.");
   TORCH_CHECK(
-      stride.size() == 2,
-      "quantized::conv2d_prepack (qnnpack): 2D convolution only");
+      padding.size() == kSpatialDim,
+      "quantized::conv_prepack (qnnpack): Specify top/left input padding "
+      "only. bottom/right padding assumed to be equal to top/left");
   TORCH_CHECK(
-      padding.size() == 2,
-      "quantized::conv2d_prepack (qnnpack): Specify top/left padding only. "
-      "bottom/right padding assumed to be equal to top/left");
+      output_padding.size() == kSpatialDim,
+      "quantized::conv_prepack (qnnpack): Specify top/left output padding "
+      "only. bottom/right padding assumed to be equal to top/left");
   TORCH_CHECK(
-      dilation.size() == 2,
-      "quantized::conv2d_prepack (qnnpack): 2D convolution only");
+      dilation.size() == kSpatialDim,
+      "quantized::conv_prepack (qnnpack): ",
+      kSpatialDim, "D convolution expects dilation to have ",
+      kSpatialDim, " elements.");
 
   at::native::initQNNPACK();
 
   // QNNPACK expects weights to be of the format {out_c, kH, kW, in_c/groups},
   // but PyTorch lays them out as {out_c, in_c/groups, kH, kW}
-  const size_t out_ch = weight.size(0);
+  // (or for ConvTranspose {in_c, out_c/groups, kH, kW})
+  const size_t out_ch_idx = transpose ? 1 : 0;
+  const size_t out_ch = weight.size(out_ch_idx);
   const uint32_t kernel_h = weight.size(2);
   const uint32_t kernel_w = weight.size(3);
 
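
The layout comment above is the reason QNNPACK packing re-orders weights. A hedged sketch of what that re-layout amounts to, assuming a plain channels-last-style permute (the actual code path packs through QNNPACK's own routines):

    // Sketch: {out_c, in_c/groups, kH, kW} -> {out_c, kH, kW, in_c/groups}
    at::Tensor to_qnnpack_layout(const at::Tensor& weight) {
      return weight.permute({0, 2, 3, 1}).contiguous();
    }
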
@@ -228,8 +247,10 @@ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeightsQnnp<
       bias_fp32.contiguous(), /* fp32 bias */
       stride,
       padding,
+      output_padding,
       dilation,
       groups,
+      transpose,
       c10::nullopt, /* input_scale */
       {kernel_h, kernel_w},
       w_scales,
@@ -248,18 +269,38 @@ namespace {
 template <int kSpatialDim = 2>
 class QConvPackWeightInt8 final {
  public:
-  static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> run(
+  static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> run_conv(
       Tensor weight,
       c10::optional<Tensor> bias,
       torch::List<int64_t> stride,
       torch::List<int64_t> padding,
       torch::List<int64_t> dilation,
       int64_t groups) {
+    torch::List<int64_t> output_padding;
+    output_padding.reserve(kSpatialDim);
+    for (int idx = 0; idx < kSpatialDim; ++idx) {
+      output_padding.push_back((int64_t)0);
+    }
+    return _run(weight, bias, stride, padding, output_padding, dilation, groups,
+        false);
+  }
+
+ private:
+  static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> _run(
+      Tensor weight,
+      c10::optional<Tensor> bias,
+      torch::List<int64_t> stride,
+      torch::List<int64_t> padding,
+      torch::List<int64_t> output_padding,
+      torch::List<int64_t> dilation,
+      int64_t groups,
+      bool transpose) {
     auto& ctx = at::globalContext();
 #ifdef USE_FBGEMM
     if (ctx.qEngine() == at::QEngine::FBGEMM) {
       return PackedConvWeight<kSpatialDim>::prepack(
-          weight, bias, stride, padding, dilation, groups);
+          weight, bias, stride, padding, output_padding, dilation, groups,
+          transpose);
     }
 #endif
 
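
The run_conv/_run split leaves room for a transposed-conv entry point that shares the same dispatch body. A hypothetical sketch of such a companion (run_deconv is not shown in this section; its name and shape are assumptions):

    // Hypothetical companion to run_conv: forwards caller-supplied
    // output_padding and sets transpose = true.
    static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> run_deconv(
        Tensor weight,
        c10::optional<Tensor> bias,
        torch::List<int64_t> stride,
        torch::List<int64_t> padding,
        torch::List<int64_t> output_padding,
        torch::List<int64_t> dilation,
        int64_t groups) {
      return _run(weight, bias, stride, padding, output_padding, dilation,
          groups, true);
    }
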
@@ -270,7 +311,8 @@ class QConvPackWeightInt8 final {
           "quantized::conv_prepack (qnnpack): QNNPACK only supports Conv1d "
           "and Conv2d now.");
       return PackedConvWeightsQnnp<kSpatialDim>::prepack(
-          weight, bias, stride, padding, dilation, groups);
+          weight, bias, stride, padding, output_padding, dilation, groups,
+          transpose);
     }
 #endif
 
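
Which branch runs here is decided at call time by the global quantization engine, not at registration time. A small usage sketch (the chosen engine must be available in the build):

    // Sketch: pick the QNNPACK engine before invoking a prepack op.
    #include <ATen/Context.h>

    void select_qnnpack() {
      // Only valid when PyTorch was built with USE_PYTORCH_QNNPACK.
      at::globalContext().setQEngine(at::QEngine::QNNPACK);
    }
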
@@ -283,31 +325,49 @@ class QConvPackWeightInt8 final {
 
 class QConv1dPackWeightInt8 final {
  public:
-  static c10::intrusive_ptr<ConvPackedParamsBase<2>> run(
+  static c10::intrusive_ptr<ConvPackedParamsBase<2>> run_conv(
       Tensor weight,
       c10::optional<Tensor> bias,
       torch::List<int64_t> stride,
       torch::List<int64_t> padding,
       torch::List<int64_t> dilation,
       int64_t groups) {
+    const torch::List<int64_t> output_padding({0});
+    return _run(weight, bias, stride, padding, output_padding, dilation, groups,
+        false);
+  }
+
+ private:
+  static c10::intrusive_ptr<ConvPackedParamsBase<2>> _run(
+      Tensor weight,
+      c10::optional<Tensor> bias,
+      torch::List<int64_t> stride,
+      torch::List<int64_t> padding,
+      torch::List<int64_t> output_padding,
+      torch::List<int64_t> dilation,
+      int64_t groups,
+      bool transpose) {
     auto& ctx = at::globalContext();
     if (weight.dim() == 3) {
       weight = weight.unsqueeze(quant_utils::kConv1dSqueezeDim + 2);
     }
     stride = quant_utils::MakeArgForConv1d(stride, 1);
     padding = quant_utils::MakeArgForConv1d(padding, 0);
+    output_padding = quant_utils::MakeArgForConv1d(output_padding, 0);
     dilation = quant_utils::MakeArgForConv1d(dilation, 1);
 #ifdef USE_FBGEMM
     if (ctx.qEngine() == at::QEngine::FBGEMM) {
       return PackedConvWeight<2>::prepack(
-          weight, bias, stride, padding, dilation, groups);
+          weight, bias, stride, padding, output_padding, dilation, groups,
+          transpose);
     }
 #endif
 
 #ifdef USE_PYTORCH_QNNPACK
     if (ctx.qEngine() == at::QEngine::QNNPACK) {
       return PackedConvWeightsQnnp<2>::prepack(
-          weight, bias, stride, padding, dilation, groups);
+          weight, bias, stride, padding, output_padding, dilation, groups,
+          transpose);
     }
 #endif
     TORCH_CHECK(
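
The Conv1d wrapper reuses the 2D kernels: the weight gains a dummy spatial dim via unsqueeze, and each 1-element argument list is widened to two elements with a neutral value (1 for stride/dilation, 0 for padding) in the squeezed slot. A minimal sketch of that widening, with the behavior of quant_utils::MakeArgForConv1d assumed from its call sites above:

    // Assumed behavior of MakeArgForConv1d: widen a 1-element list to two,
    // placing base_value in the squeezed slot (kConv1dSqueezeDim).
    torch::List<int64_t> make_arg_for_conv1d_sketch(
        const torch::List<int64_t>& arg, int64_t base_value, int64_t squeeze_dim) {
      torch::List<int64_t> result({base_value, base_value});
      result[1 - squeeze_dim] = arg[0];  // keep the caller's value in the other slot
      return result;
    }
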
@@ -319,14 +379,14 @@ class QConv1dPackWeightInt8 final {
 
 TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
   // conv_prepack is deprecated, please use conv2d_prepack for 2D conv.
-  m.impl("conv_prepack", TORCH_FN(QConvPackWeightInt8<2>::run));
-  m.impl("conv1d_prepack", TORCH_FN(QConv1dPackWeightInt8::run));
-  m.impl("conv2d_prepack", TORCH_FN(QConvPackWeightInt8<2>::run));
-  m.impl("conv3d_prepack", TORCH_FN(QConvPackWeightInt8<3>::run));
+  m.impl("conv_prepack", TORCH_FN(QConvPackWeightInt8<2>::run_conv));
+  m.impl("conv1d_prepack", TORCH_FN(QConv1dPackWeightInt8::run_conv));
+  m.impl("conv2d_prepack", TORCH_FN(QConvPackWeightInt8<2>::run_conv));
+  m.impl("conv3d_prepack", TORCH_FN(QConvPackWeightInt8<3>::run_conv));
 }
 
 TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) {
-  m.impl("conv2d_prepack", TORCH_FN(QConvPackWeightInt8<2>::run));
+  m.impl("conv2d_prepack", TORCH_FN(QConvPackWeightInt8<2>::run_conv));
 }
 
 } // namespace