 #include <ATen/native/quantized/cpu/conv_packed_params.h>
 #include <caffe2/utils/threadpool/pthreadpool-cpp.h>
 
+namespace {
+// Sanity check for the maximum matrix size.
+constexpr int64_t kReasonableMaxDim = 1000000;
+}  // namespace
+
 template <int kSpatialDim = 2>
 bool ConvDimChecks(
     int64_t act_dims,
     int64_t stride_dims,
     int64_t padding_dims,
-    int64_t dilation_dims) {
+    int64_t output_padding_dims,
+    int64_t dilation_dims,
+    std::string func_name,
+    bool transpose = false) {
   TORCH_CHECK(
       act_dims == kSpatialDim + 2,
-      "quantized::conv",
+      func_name,
       kSpatialDim,
       "d(): Expected activation tensor to have ",
       kSpatialDim + 2,
-      " dimensions.");
+      " dimensions, got ",
+      act_dims);
   TORCH_CHECK(
       stride_dims == kSpatialDim,
-      "quantized::conv",
+      func_name,
       kSpatialDim,
       "d(): Expected stride tensor to have ",
       kSpatialDim,
-      " dimensions.");
+      " dimensions, got ",
+      stride_dims);
   TORCH_CHECK(
       padding_dims == kSpatialDim,
-      "quantized::conv",
+      func_name,
       kSpatialDim,
       "d(): Expected padding tensor to have ",
       kSpatialDim,
-      " dimensions.");
+      " dimensions, got ",
+      padding_dims);
+  TORCH_CHECK(
+      !transpose || (output_padding_dims == kSpatialDim),
+      func_name,
+      kSpatialDim,
+      "d(): Expected output padding tensor to have ",
+      kSpatialDim,
+      " dimensions, got ",
+      output_padding_dims);
   TORCH_CHECK(
       dilation_dims == kSpatialDim,
-      "quantized::conv",
+      func_name,
       kSpatialDim,
       "d(): Expected dilation tensor to have ",
       kSpatialDim,
-      " dimensions.");
+      " dimensions, got ",
+      dilation_dims);
   return true;
 }
 
+inline int64_t compute_deconv_shape(int64_t input,
+                                    int64_t kernel,
+                                    int64_t stride,
+                                    int64_t input_padding,
+                                    int64_t output_padding,
+                                    int64_t dilation) {
+  int64_t out = (input - 1) * stride - 2 * input_padding
+                + dilation * (kernel - 1) + output_padding + 1;
+  return out;
+}
+
+template <int64_t kSpatialDim>
+at::SmallVector<int64_t, kSpatialDim + 2> MakeDeConvOutputShape(
+    int64_t N, int64_t M,
+    const std::vector<int64_t>& input_shape,
+    const std::vector<int64_t>& kernel,
+    const torch::List<int64_t>& stride,
+    const torch::List<int64_t>& input_padding,
+    const torch::List<int64_t>& output_padding,
+    const torch::List<int64_t>& dilation) {
+  at::SmallVector<int64_t, kSpatialDim + 2> output_shape;
+  output_shape.resize(kSpatialDim + 2);
+  output_shape[0] = N;  // Batch size
+  output_shape[1] = M;  // Output channels
+  for (int64_t idx = 0; idx < kSpatialDim; ++idx) {
+    output_shape[idx + 2] = compute_deconv_shape(input_shape[idx],
+                                                 kernel[idx],
+                                                 stride[idx],
+                                                 input_padding[idx],
+                                                 output_padding[idx],
+                                                 dilation[idx]);
+    TORCH_CHECK(output_shape[idx + 2] > 0,
+                "Output dimension must be positive for ", idx, " axis;"
+                " kernel: ", kernel[idx],
+                ", stride: ", stride[idx],
+                ", input padding: ", input_padding[idx],
+                ", output padding: ", output_padding[idx],
+                ", dilation: ", dilation[idx]);
+    TORCH_CHECK(output_shape[idx + 2] < kReasonableMaxDim,
+                "Output dimension is beyond the reasonable maximum for ", idx,
+                " axis;"
+                " kernel: ", kernel[idx],
+                ", stride: ", stride[idx],
+                ", input padding: ", input_padding[idx],
+                ", output padding: ", output_padding[idx],
+                ", dilation: ", dilation[idx]);
+  }
+  return output_shape;
+}
+
 #ifdef USE_FBGEMM
 
 template <int kSpatialDim = 2>
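
Note (editorial, not part of the patch): `compute_deconv_shape` above is the standard transposed-convolution output-size formula. A minimal standalone sanity check of the arithmetic, with illustrative values that are not taken from the patch:

```cpp
// Sketch: verify the transposed-conv output-size formula used by
// compute_deconv_shape. Values below are examples only.
#include <cassert>
#include <cstdint>

int64_t compute_deconv_shape(int64_t input, int64_t kernel, int64_t stride,
                             int64_t input_padding, int64_t output_padding,
                             int64_t dilation) {
  return (input - 1) * stride - 2 * input_padding +
         dilation * (kernel - 1) + output_padding + 1;
}

int main() {
  // 4-wide input, 3-wide kernel, stride 2, padding 1, no output padding,
  // dilation 1: (4 - 1) * 2 - 2 * 1 + 1 * (3 - 1) + 0 + 1 = 7.
  assert(compute_deconv_shape(4, 3, 2, 1, 0, 1) == 7);
  // The same settings with output_padding = 1 grow the result by exactly 1.
  assert(compute_deconv_shape(4, 3, 2, 1, 1, 1) == 8);
  return 0;
}
```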
@@ -203,10 +273,13 @@ at::Tensor PackedConvWeight<kSpatialDim>::apply_impl(
   //
   // This might change when full memory format support lands
   // See https://github.com/pytorch/pytorch/issues/23403
+  const std::string func_name = transpose() ? "quantized::conv_transpose"
+                                             : "quantized::conv";
   TORCH_CHECK(
       fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
   ConvDimChecks<kSpatialDim>(
-      act.ndimension(), stride_.size(), padding_.size(), dilation_.size());
+      act.ndimension(), stride().size(), padding().size(),
+      output_padding().size(), dilation().size(), func_name, transpose());
 
   const int N = act.size(0);
   const int C = act.size(1);
@@ -466,17 +539,25 @@ at::Tensor PackedConvWeightsQnnp<kSpatialDim>::apply_impl(
     const at::Tensor& act,
     double output_scale,
     int64_t output_zero_point) {
+  const std::string func_name = transpose() ? "quantized::conv_transpose"
+                                             : "quantized::conv";
+  TORCH_CHECK(!(kReluFused && transpose()),
+      func_name, kSpatialDim,
+      "d (qnnpack): ConvTranspose cannot be fused with ReLU.");
   TORCH_CHECK(
       kSpatialDim == 2,
-      "quantized::conv2d (qnnpack): QNNPACK only supports Conv2d "
-      "now.");
+      func_name, kSpatialDim,
+      "d (qnnpack): QNNPACK only supports Conv2d now.");
   ConvDimChecks<kSpatialDim>(
-      act.ndimension(), stride_.size(), padding_.size(), dilation_.size());
+      act.ndimension(), stride().size(), padding().size(),
+      output_padding().size(), dilation().size(), func_name, transpose());
 
   auto* pack_w = w.get();
 
   // TODO Can be replaced with packB->getOutputChannels() when update pre-pack
   // to actually do the packing.
+  const int out_ch_idx = transpose() ? 1 : 0;
   const auto out_ch = bias.size(0);
   // inputs are in semantic NCHW format
   const int N = act.size(0);
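
Note (editorial, not part of the patch): the `out_ch_idx` switch above reflects the two weight layouts. Conv2d stores weights as `[out_channels, in_channels / groups, kH, kW]`, while ConvTranspose2d stores them as `[in_channels, out_channels / groups, kH, kW]`. A minimal sketch of how the channel counts used in the checks in the next hunk follow from those layouts (names are illustrative):

```cpp
// Sketch: derive the output (M) and input (C) channel counts from the first
// two weight dimensions, mirroring the TORCH_CHECKs in the hunk below.
#include <cstdint>

int64_t output_channels(int64_t dim0, int64_t dim1, int64_t groups,
                        bool transpose) {
  // Conv: weight.size(0) == M; ConvTranspose: weight.size(1) == M / groups.
  return transpose ? groups * dim1 : dim0;
}

int64_t input_channels(int64_t dim0, int64_t dim1, int64_t groups,
                       bool transpose) {
  // Conv: weight.size(1) == C / groups; ConvTranspose: weight.size(0) == C.
  return transpose ? dim0 : groups * dim1;
}
```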
@@ -500,10 +581,10 @@ at::Tensor PackedConvWeightsQnnp<kSpatialDim>::apply_impl(
 
   // Re-quantizing the bias based on input scale and weight scale.
   if (!input_scale.has_value() || input_scale.value() != act_input_scale) {
-    TORCH_CHECK(M == orig_weight.size(0),
-        "Output channel size of weight and bias must match.");
-    TORCH_CHECK(C == groups_ * orig_weight.size(1),
+    TORCH_CHECK(M == (transpose() ? groups() : 1) * orig_weight.size(out_ch_idx),
         "Output channel size of weight and bias must match.");
+    TORCH_CHECK(C == (transpose() ? 1 : groups()) * orig_weight.size(1 - out_ch_idx),
+        "Input channel size of weight and bias must match.");
 
     // Get the original weight and adjust it to uint8 from int8
     auto weight_contig =
@@ -574,8 +655,15 @@ at::Tensor PackedConvWeightsQnnp<kSpatialDim>::apply_impl(
   }
 
   TORCH_INTERNAL_ASSERT(pack_w != nullptr, "Packed Weights are NULL");
-  const auto output_shape = MakeConvOutputShape<kSpatialDim>(
-      N, M, {H, W}, kernel_, stride_, padding_, dilation_);
+  at::SmallVector<int64_t, kSpatialDim + 2> output_shape;
+  if (transpose()) {
+    output_shape = MakeDeConvOutputShape<kSpatialDim>(N, M, {H, W},
+        kernel_, stride(), padding(), output_padding(), dilation());
+  } else {
+    output_shape = MakeConvOutputShape<kSpatialDim>(N, M, {H, W},
+        kernel_, stride(), padding(), dilation());
+  }
+
   if (act_nhwc.numel() > 0) {
     TORCH_CHECK(
         std::all_of(
@@ -596,22 +684,42 @@ at::Tensor PackedConvWeightsQnnp<kSpatialDim>::apply_impl(
       output_zero_point,
       c10::nullopt);
 
-  const pytorch_qnnp_status run_status = qnnpack::qnnpackConv(
-      conv_p,
-      convolution_op.get(),
-      pack_w->getPackedWeights(),
-      N,
-      H,
-      W,
-      act_nhwc.q_zero_point(),
-      reinterpret_cast<uint8_t*>(act_nhwc.template data_ptr<c10::quint8>()),
-      w_zero_points.data(),
-      requantization_scales.data(),
-      output.q_zero_point(),
-      output_min,
-      output_max,
-      reinterpret_cast<uint8_t*>(output.template data_ptr<c10::quint8>()),
-      caffe2::pthreadpool_());
+  pytorch_qnnp_status run_status;
+  if (transpose()) {
+    run_status = qnnpack::qnnpackDeConv(
+        conv_p,
+        convolution_op.get(),
+        pack_w->getPackedWeights(),
+        N,
+        H,
+        W,
+        act_nhwc.q_zero_point(),
+        reinterpret_cast<uint8_t*>(act_nhwc.template data_ptr<c10::quint8>()),
+        w_zero_points.data(),
+        requantization_scales.data(),
+        output.q_zero_point(),
+        output_min,
+        output_max,
+        reinterpret_cast<uint8_t*>(output.template data_ptr<c10::quint8>()),
+        caffe2::pthreadpool_());
+  } else {
+    run_status = qnnpack::qnnpackConv(
+        conv_p,
+        convolution_op.get(),
+        pack_w->getPackedWeights(),
+        N,
+        H,
+        W,
+        act_nhwc.q_zero_point(),
+        reinterpret_cast<uint8_t*>(act_nhwc.template data_ptr<c10::quint8>()),
+        w_zero_points.data(),
+        requantization_scales.data(),
+        output.q_zero_point(),
+        output_min,
+        output_max,
+        reinterpret_cast<uint8_t*>(output.template data_ptr<c10::quint8>()),
+        caffe2::pthreadpool_());
+  }
 
   TORCH_INTERNAL_ASSERT(
       run_status == pytorch_qnnp_status_success,
@@ -753,11 +861,19 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
   m.impl("conv2d_relu", QConvInt8ForBC<2, true>::run);
   m.impl("conv3d", QConvInt8ForBC<3, false>::run);
   m.impl("conv3d_relu", QConvInt8ForBC<3, true>::run);
+
+  // transpose
+  m.impl("conv_transpose1d", QConv1dInt8<false>::run);
+  m.impl("conv_transpose2d", QConvInt8<2, false>::run);
 }
 
 TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) {
   m.impl("conv2d", QConvInt8<2, false>::run);
   m.impl("conv2d_relu", QConvInt8<2, true>::run);
+
+  // transpose
+  m.impl("conv_transpose1d", QConv1dInt8<false>::run);
+  m.impl("conv_transpose2d", QConvInt8<2, false>::run);
 }
 
 } // namespace
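
Note (editorial, not part of the patch): the new `conv_transpose1d`/`conv_transpose2d` entries reuse the existing `QConv1dInt8`/`QConvInt8` wrappers; the transpose flag travels with the packed parameters rather than with the op. A hedged sketch of reaching the registered op from C++ follows; the typed signature is an assumption based on the usual (activation, packed weights, output scale, output zero point) wrapper arguments, not a copy of the registered schema:

```cpp
// Sketch: look up and call quantized::conv_transpose2d through the dispatcher.
// The argument types below are assumed, not taken verbatim from this patch.
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/native/quantized/cpu/conv_packed_params.h>

at::Tensor call_qconv_transpose2d(
    const at::Tensor& qx,
    const c10::intrusive_ptr<ConvPackedParamsBase<2>>& packed_weight,
    double output_scale,
    int64_t output_zero_point) {
  auto op = c10::Dispatcher::singleton()
                .findSchemaOrThrow("quantized::conv_transpose2d", "")
                .typed<at::Tensor(
                    at::Tensor,
                    c10::intrusive_ptr<ConvPackedParamsBase<2>>,
                    double,
                    int64_t)>();
  return op.call(qx, packed_weight, output_scale, output_zero_point);
}
```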