
Commit 65b3f2c

Zafar authored and xuzhao9 committed
[quant] conv_transpose1d / conv_transpose2d (#40370)
Summary: Pull Request resolved: #40370

Test Plan: Imported from OSS

Reviewed By: vkuzo

Differential Revision: D22158979

Pulled By: z-a-f

fbshipit-source-id: f5cb812c9953efa7608f06cf0188de447f73f358
1 parent ae4880c commit 65b3f2c

File tree

4 files changed: +316, -34 lines changed

aten/src/ATen/native/quantized/cpu/qconv.cpp

Lines changed: 150 additions & 34 deletions
@@ -12,43 +12,113 @@
 #include <ATen/native/quantized/cpu/conv_packed_params.h>
 #include <caffe2/utils/threadpool/pthreadpool-cpp.h>
 
+namespace {
+// To have a sanity check for maximum matrix size.
+constexpr int64_t kReasonableMaxDim = 1000000;
+}
+
 template <int kSpatialDim = 2>
 bool ConvDimChecks(
     int64_t act_dims,
     int64_t stride_dims,
     int64_t padding_dims,
-    int64_t dilation_dims) {
+    int64_t output_padding_dims,
+    int64_t dilation_dims,
+    std::string func_name,
+    bool transpose = false) {
   TORCH_CHECK(
       act_dims == kSpatialDim + 2,
-      "quantized::conv",
+      func_name,
       kSpatialDim,
       "d(): Expected activation tensor to have ",
       kSpatialDim + 2,
-      " dimensions.");
+      " dimensions, got ",
+      act_dims);
   TORCH_CHECK(
       stride_dims == kSpatialDim,
-      "quantized::conv",
+      func_name,
       kSpatialDim,
       "d(): Expected stride tensor to have ",
       kSpatialDim,
-      " dimensions.");
+      " dimensions, got ",
+      stride_dims);
   TORCH_CHECK(
       padding_dims == kSpatialDim,
-      "quantized::conv",
+      func_name,
       kSpatialDim,
       "d(): Expected padding tensor to have ",
       kSpatialDim,
-      " dimensions.");
+      " dimensions, got ",
+      padding_dims);
+  TORCH_CHECK(
+      !transpose || (output_padding_dims == kSpatialDim),
+      func_name,
+      kSpatialDim,
+      "d(): Expected output padding tensor to have ",
+      kSpatialDim,
+      " dimensions, got ",
+      output_padding_dims);
   TORCH_CHECK(
       dilation_dims == kSpatialDim,
-      "quantized::conv",
+      func_name,
       kSpatialDim,
       "d(): Expected dilation tensor to have ",
       kSpatialDim,
-      " dimensions.");
+      " dimensions, got ",
+      dilation_dims);
   return true;
 }
 
+inline int64_t compute_deconv_shape(int64_t input,
+                                    int64_t kernel,
+                                    int64_t stride,
+                                    int64_t input_padding,
+                                    int64_t output_padding,
+                                    int64_t dilation) {
+  int64_t out = (input - 1) * stride - 2 * input_padding
+                + dilation * (kernel - 1) + output_padding + 1;
+  return out;
+}
+
+template <int64_t kSpatialDim>
+at::SmallVector<int64_t, kSpatialDim + 2> MakeDeConvOutputShape(
+    int64_t N, int64_t M,
+    const std::vector<int64_t>& input_shape,
+    const std::vector<int64_t>& kernel,
+    const torch::List<int64_t>& stride,
+    const torch::List<int64_t>& input_padding,
+    const torch::List<int64_t>& output_padding,
+    const torch::List<int64_t>& dilation) {
+  at::SmallVector<int64_t, kSpatialDim + 2> output_shape;
+  output_shape.resize(kSpatialDim + 2);
+  output_shape[0] = N;  // Batch size
+  output_shape[1] = M;  // Output channels
+  for (int64_t idx = 0; idx < kSpatialDim; ++idx) {
+    output_shape[idx + 2] = compute_deconv_shape(input_shape[idx],
+                                                 kernel[idx],
+                                                 stride[idx],
+                                                 input_padding[idx],
+                                                 output_padding[idx],
+                                                 dilation[idx]);
+    TORCH_CHECK(output_shape[idx + 2] > 0,
+                "Output dimension is zero for ", idx, " axis;"
+                " kernel: ", kernel[idx],
+                ", stride: ", stride[idx],
+                ", input padding: ", input_padding[idx],
+                ", output padding: ", output_padding[idx],
+                ", dilation: ", dilation[idx]);
+    TORCH_CHECK(output_shape[idx + 2] < kReasonableMaxDim,
+                "Output dimension is beyond reasonable maximum for ", idx,
+                " axis;"
+                " kernel: ", kernel[idx],
+                ", stride: ", stride[idx],
+                ", input padding: ", input_padding[idx],
+                ", output padding: ", output_padding[idx],
+                ", dilation: ", dilation[idx]);
+  }
+  return output_shape;
+}
+
 #ifdef USE_FBGEMM
 
 template <int kSpatialDim = 2>
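Review note: MakeDeConvOutputShape applies the standard transposed-convolution size formula per spatial axis. A minimal standalone sketch of the arithmetic (the helper name deconv_out and the sample numbers are illustrative, not part of the patch):

```cpp
#include <cstdint>
#include <cstdio>

// Same arithmetic as compute_deconv_shape in the hunk above:
// out = (in - 1) * stride - 2 * input_padding
//       + dilation * (kernel - 1) + output_padding + 1
int64_t deconv_out(int64_t in, int64_t kernel, int64_t stride,
                   int64_t input_padding, int64_t output_padding,
                   int64_t dilation) {
  return (in - 1) * stride - 2 * input_padding +
         dilation * (kernel - 1) + output_padding + 1;
}

int main() {
  // 4x4 input, 3x3 kernel, stride 2, padding 1, output_padding 1, dilation 1:
  // (4 - 1) * 2 - 2 * 1 + 1 * (3 - 1) + 1 + 1 = 8, i.e. the map upsamples to 8x8.
  std::printf("%lld\n", static_cast<long long>(deconv_out(4, 3, 2, 1, 1, 1)));
  return 0;
}
```

The two TORCH_CHECKs then reject non-positive sizes (e.g. padding larger than the effective kernel extent) and sizes at or above kReasonableMaxDim, so corrupt parameters cannot trigger an oversized allocation.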
@@ -203,10 +273,13 @@ at::Tensor PackedConvWeight<kSpatialDim>::apply_impl(
   //
   // This might change when full memory format support lands
   // See https://github.com/pytorch/pytorch/issues/23403
+  const std::string func_name = transpose() ? "quantized::conv_transpose"
+                                            : "quantized::conv";
   TORCH_CHECK(
       fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
   ConvDimChecks<kSpatialDim>(
-      act.ndimension(), stride_.size(), padding_.size(), dilation_.size());
+      act.ndimension(), stride().size(), padding().size(),
+      output_padding().size(), dilation().size(), func_name, transpose());
 
   const int N = act.size(0);
   const int C = act.size(1);
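With func_name threaded through, ConvDimChecks assembles its messages from a prefix, the spatial dimension, and a suffix; TORCH_CHECK concatenates its trailing arguments into the error text. A small sketch of the resulting string (the literal values are made up for illustration):

```cpp
#include <iostream>
#include <string>

int main() {
  // Mirrors the argument order in ConvDimChecks: func_name, kSpatialDim, text.
  const std::string func_name = "quantized::conv_transpose";  // transpose() == true
  const int kSpatialDim = 2;
  const int act_dims = 3;  // hypothetical bad input
  std::cout << func_name << kSpatialDim
            << "d(): Expected activation tensor to have "
            << kSpatialDim + 2 << " dimensions, got " << act_dims << "\n";
  // quantized::conv_transpose2d(): Expected activation tensor to have 4 dimensions, got 3
  return 0;
}
```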
@@ -466,17 +539,25 @@ at::Tensor PackedConvWeightsQnnp<kSpatialDim>::apply_impl(
     const at::Tensor& act,
     double output_scale,
     int64_t output_zero_point) {
+  const std::string func_name = transpose() ? "quantized::conv_transpose"
+                                            : "quantized::conv";
+  TORCH_CHECK(!(kReluFused && transpose()),
+      func_name, kSpatialDim,
+      "d (qnnpack): ConvTranspose cannot be fused with ReLU.");
   TORCH_CHECK(
       kSpatialDim == 2,
-      "quantized::conv2d (qnnpack): QNNPACK only supports Conv2d "
-      "now.");
+      func_name, kSpatialDim,
+      "d (qnnpack): QNNPACK only supports Conv2d now.");
   ConvDimChecks<kSpatialDim>(
-      act.ndimension(), stride_.size(), padding_.size(), dilation_.size());
+      act.ndimension(), stride().size(), padding().size(),
+      output_padding().size(), dilation().size(), func_name, transpose());
 
   auto* pack_w = w.get();
 
   // TODO Can be replaced with packB->getOutputChannels() when update pre-pack
   // to actually do the packing.
+  const int out_ch_idx = transpose() ? 1 : 0;
   const auto out_ch = bias.size(0);
   // inputs are in semantic NCHW format
   const int N = act.size(0);
@@ -500,10 +581,10 @@ at::Tensor PackedConvWeightsQnnp<kSpatialDim>::apply_impl(
 
   // Re-quantizing the bias based on input scale and weight scale.
   if (!input_scale.has_value() || input_scale.value() != act_input_scale) {
-    TORCH_CHECK( M == orig_weight.size(0),
-        "Output channel size of weight and bias must match.");
-    TORCH_CHECK( C == groups_ * orig_weight.size(1),
+    TORCH_CHECK(M == (transpose() ? groups() : 1) * orig_weight.size(out_ch_idx),
         "Output channel size of weight and bias must match.");
+    TORCH_CHECK(C == (transpose() ? 1 : groups()) * orig_weight.size(1 - out_ch_idx),
+        "Input channel size of weight and bias must match.");
 
     // Get the original weight and adjust it to uint8 from int8
     auto weight_contig =
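The flipped weight index reflects the two layouts: Conv2d packs weights as (out_channels, in_channels / groups, kH, kW), while ConvTranspose2d packs them as (in_channels, out_channels / groups, kH, kW); hence out_ch_idx = transpose() ? 1 : 0 and the groups() factors in the checks above. A sketch of the bookkeeping with made-up sizes (not taken from the patch):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // Hypothetical ConvTranspose2d: groups = 2, weight shape
  // (in_channels = 4, out_channels / groups = 3, kH = 3, kW = 3).
  const bool transpose = true;
  const int64_t groups = 2;
  const int64_t weight_size[2] = {4, 3};  // first two weight dims
  const int out_ch_idx = transpose ? 1 : 0;

  const int64_t M = 6;  // output channels, bias.size(0) in apply_impl
  const int64_t C = 4;  // input channels, act.size(1) in apply_impl

  // The same relations the hunk above asserts with TORCH_CHECK.
  assert(M == (transpose ? groups : 1) * weight_size[out_ch_idx]);
  assert(C == (transpose ? 1 : groups) * weight_size[1 - out_ch_idx]);
  return 0;
}
```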
@@ -574,8 +655,15 @@ at::Tensor PackedConvWeightsQnnp<kSpatialDim>::apply_impl(
   }
 
   TORCH_INTERNAL_ASSERT(pack_w != nullptr, "Packed Weights are NULL");
-  const auto output_shape = MakeConvOutputShape<kSpatialDim>(
-      N, M, {H, W}, kernel_, stride_, padding_, dilation_);
+  at::SmallVector<int64_t, kSpatialDim + 2> output_shape;
+  if (transpose()) {
+    output_shape = MakeDeConvOutputShape<kSpatialDim>(N, M, {H, W},
+        kernel_, stride(), padding(), output_padding(), dilation());
+  } else {
+    output_shape = MakeConvOutputShape<kSpatialDim>(N, M, {H, W},
+        kernel_, stride(), padding(), dilation());
+  }
+
   if (act_nhwc.numel() > 0) {
     TORCH_CHECK(
         std::all_of(
@@ -596,22 +684,42 @@ at::Tensor PackedConvWeightsQnnp<kSpatialDim>::apply_impl(
       output_zero_point,
       c10::nullopt);
 
-  const pytorch_qnnp_status run_status = qnnpack::qnnpackConv(
-      conv_p,
-      convolution_op.get(),
-      pack_w->getPackedWeights(),
-      N,
-      H,
-      W,
-      act_nhwc.q_zero_point(),
-      reinterpret_cast<uint8_t*>(act_nhwc.template data_ptr<c10::quint8>()),
-      w_zero_points.data(),
-      requantization_scales.data(),
-      output.q_zero_point(),
-      output_min,
-      output_max,
-      reinterpret_cast<uint8_t*>(output.template data_ptr<c10::quint8>()),
-      caffe2::pthreadpool_());
+  pytorch_qnnp_status run_status;
+  if (transpose()) {
+    run_status = qnnpack::qnnpackDeConv(
+        conv_p,
+        convolution_op.get(),
+        pack_w->getPackedWeights(),
+        N,
+        H,
+        W,
+        act_nhwc.q_zero_point(),
+        reinterpret_cast<uint8_t*>(act_nhwc.template data_ptr<c10::quint8>()),
+        w_zero_points.data(),
+        requantization_scales.data(),
+        output.q_zero_point(),
+        output_min,
+        output_max,
+        reinterpret_cast<uint8_t*>(output.template data_ptr<c10::quint8>()),
+        caffe2::pthreadpool_());
+  } else {
+    run_status = qnnpack::qnnpackConv(
+        conv_p,
+        convolution_op.get(),
+        pack_w->getPackedWeights(),
+        N,
+        H,
+        W,
+        act_nhwc.q_zero_point(),
+        reinterpret_cast<uint8_t*>(act_nhwc.template data_ptr<c10::quint8>()),
+        w_zero_points.data(),
+        requantization_scales.data(),
+        output.q_zero_point(),
+        output_min,
+        output_max,
+        reinterpret_cast<uint8_t*>(output.template data_ptr<c10::quint8>()),
+        caffe2::pthreadpool_());
+  }
 
   TORCH_INTERNAL_ASSERT(
       run_status == pytorch_qnnp_status_success,
@@ -753,11 +861,19 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
   m.impl("conv2d_relu", QConvInt8ForBC<2, true>::run);
   m.impl("conv3d", QConvInt8ForBC<3, false>::run);
   m.impl("conv3d_relu", QConvInt8ForBC<3, true>::run);
+
+  // transpose
+  m.impl("conv_transpose1d", QConv1dInt8<false>::run);
+  m.impl("conv_transpose2d", QConvInt8<2, false>::run);
 }
 
 TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) {
   m.impl("conv2d", QConvInt8<2, false>::run);
   m.impl("conv2d_relu", QConvInt8<2, true>::run);
+
+  // transpose
+  m.impl("conv_transpose1d", QConv1dInt8<false>::run);
+  m.impl("conv_transpose2d", QConvInt8<2, false>::run);
 }
 
 } // namespace

aten/src/ATen/native/quantized/library.cpp

Lines changed: 2 additions & 0 deletions
@@ -90,6 +90,8 @@ TORCH_LIBRARY(quantized, m) {
   m.def("conv3d_dilation(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]");
   m.def("conv3d_groups(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int");
   // conv_transpose
+  m.def("conv_transpose1d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor");
+  m.def("conv_transpose2d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor");
   m.def("conv_transpose1d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase");
   m.def("conv_transpose2d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase");
   m.def("conv_transpose1d_unpack(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> (Tensor unpacked_weights, Tensor? B_origin)");
