Skip to content

Commit de14d4a

Browse files
author
Zafar
committed
[quant] Prep for conv_transpose packing
ghstack-source-id: 939b91c Pull Request resolved: #39714
1 parent 3040572 commit de14d4a

File tree

5 files changed

+135
-35
lines changed

5 files changed

+135
-35
lines changed

aten/src/ATen/native/quantized/cpu/conv_packed_params.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ struct ConvPackedParamsBase : public torch::jit::CustomClassHolder {
1818

1919
virtual torch::List<int64_t> stride() const = 0;
2020
virtual torch::List<int64_t> padding() const = 0;
21+
virtual torch::List<int64_t> output_padding() const = 0;
2122
virtual torch::List<int64_t> dilation() const = 0;
2223
virtual int64_t groups() const = 0;
24+
virtual bool transpose() const = 0;
2325
};

aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@ CAFFE2_API torch::jit::class_<ConvPackedParamsBase<kSpatialDim>> register_conv_p
296296
}
297297
groups = scalars_tensor[1].item<int64_t>();
298298
switch (version) {
299-
case 2: break;
299+
case 2: break; // V2 is already covered, skipping output_padding
300300
case 3: {
301301
for (; idx < 4 * kSpatialDim; ++idx) {
302302
at::Tensor p = params1_tensor[idx];
@@ -319,8 +319,10 @@ CAFFE2_API torch::jit::class_<ConvPackedParamsBase<kSpatialDim>> register_conv_p
319319
bias,
320320
stride,
321321
padding,
322+
/*output_padding=*/padding,
322323
dilation,
323-
groups);
324+
groups,
325+
/*transpose=*/false);
324326
}
325327
#endif // USE_FBGEMM
326328
#ifdef USE_PYTORCH_QNNPACK
@@ -334,8 +336,10 @@ CAFFE2_API torch::jit::class_<ConvPackedParamsBase<kSpatialDim>> register_conv_p
334336
bias,
335337
stride,
336338
padding,
339+
/*output_padding=*/padding,
337340
dilation,
338-
groups);
341+
groups,
342+
/*transpose=*/false);
339343
}
340344
#endif // USE_PYTORCH_QNNPACK
341345
TORCH_CHECK(

aten/src/ATen/native/quantized/cpu/fbgemm_utils.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,10 @@ struct CAFFE2_API PackedConvWeight : public ConvPackedParamsBase<kSpatialDim> {
123123
c10::optional<at::Tensor> bias,
124124
torch::List<int64_t> stride,
125125
torch::List<int64_t> padding,
126+
torch::List<int64_t> output_padding,
126127
torch::List<int64_t> dilation,
127128
int64_t groups,
129+
uint8_t transpose,
128130
std::vector<int32_t> col_offsets,
129131
std::vector<int64_t> kernel,
130132
std::vector<float> w_scale,
@@ -134,8 +136,10 @@ struct CAFFE2_API PackedConvWeight : public ConvPackedParamsBase<kSpatialDim> {
134136
bias(std::move(bias)),
135137
stride_(std::move(stride)),
136138
padding_(std::move(padding)),
139+
output_padding_(std::move(output_padding)),
137140
dilation_(std::move(dilation)),
138141
groups_(groups),
142+
transpose_(transpose),
139143
col_offsets(std::move(col_offsets)),
140144
kernel(std::move(kernel)),
141145
w_scale(std::move(w_scale)),
@@ -146,8 +150,10 @@ struct CAFFE2_API PackedConvWeight : public ConvPackedParamsBase<kSpatialDim> {
146150
c10::optional<at::Tensor> bias;
147151
torch::List<int64_t> stride_;
148152
torch::List<int64_t> padding_;
153+
torch::List<int64_t> output_padding_;
149154
torch::List<int64_t> dilation_;
150155
int64_t groups_;
156+
uint8_t transpose_;
151157
std::vector<int32_t> col_offsets;
152158
std::vector<int64_t> kernel;
153159
std::vector<float> w_scale;
@@ -171,8 +177,10 @@ struct CAFFE2_API PackedConvWeight : public ConvPackedParamsBase<kSpatialDim> {
171177
c10::optional<at::Tensor> bias,
172178
torch::List<int64_t> stride,
173179
torch::List<int64_t> padding,
180+
torch::List<int64_t> output_padding,
174181
torch::List<int64_t> dilation,
175-
int64_t groups);
182+
int64_t groups,
183+
bool transpose);
176184

177185
const float* GetBiasData(at::Tensor* bias);
178186

@@ -190,6 +198,10 @@ struct CAFFE2_API PackedConvWeight : public ConvPackedParamsBase<kSpatialDim> {
190198
return padding_;
191199
}
192200

201+
torch::List<int64_t> output_padding() const override {
202+
return output_padding_;
203+
}
204+
193205
torch::List<int64_t> dilation() const override {
194206
return dilation_;
195207
}
@@ -198,6 +210,10 @@ struct CAFFE2_API PackedConvWeight : public ConvPackedParamsBase<kSpatialDim> {
198210
return groups_;
199211
}
200212

213+
bool transpose() const override {
214+
return (bool)transpose_;
215+
}
216+
201217
private:
202218
template <bool ReluFused>
203219
at::Tensor apply_impl(

aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp

Lines changed: 86 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,16 @@ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeight<
1919
c10::optional<at::Tensor> bias,
2020
torch::List<int64_t> stride,
2121
torch::List<int64_t> padding,
22+
torch::List<int64_t> output_padding,
2223
torch::List<int64_t> dilation,
23-
int64_t groups) {
24+
int64_t groups,
25+
bool transpose) {
26+
TORCH_CHECK(!transpose, "FBGEMM doesn't support conv_transpose yet.")
2427
TORCH_CHECK(
2528
weight.ndimension() == kSpatialDim + 2,
2629
"Weights are expected to have ",
2730
kSpatialDim + 2,
2831
" dimensions");
29-
3032
TORCH_CHECK(
3133
stride.size() == kSpatialDim,
3234
"stride should contain ",
@@ -45,7 +47,8 @@ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeight<
4547
" elements for ",
4648
kSpatialDim,
4749
"D convolution.");
48-
const int output_channels = weight.size(0);
50+
const int output_channels_idx = transpose ? 1 : 0;
51+
const int output_channels = weight.size(output_channels_idx);
4952
const int input_channels_per_group = weight.size(1);
5053
const int kernel_d = kSpatialDim == 2 ? 1 : weight.size(2);
5154
const int kernel_h = weight.size(kSpatialDim);
@@ -143,8 +146,10 @@ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeight<
143146
bias_contig,
144147
stride,
145148
padding,
149+
output_padding,
146150
dilation,
147151
groups,
152+
transpose,
148153
col_offsets,
149154
kSpatialDim == 2 ? std::vector<int64_t>{kernel_h, kernel_w}
150155
: std::vector<int64_t>{kernel_d, kernel_h, kernel_w},
@@ -166,28 +171,42 @@ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeightsQnnp<
166171
c10::optional<at::Tensor> bias_in,
167172
torch::List<int64_t> stride,
168173
torch::List<int64_t> padding,
174+
torch::List<int64_t> output_padding,
169175
torch::List<int64_t> dilation,
170-
int64_t groups) {
176+
int64_t groups,
177+
bool transpose) {
178+
TORCH_CHECK(kSpatialDim == 2, "QNNPACK packing only supports 2D ",
179+
"convolution.");
180+
TORCH_CHECK(
181+
weight.ndimension() == kSpatialDim + 2,
182+
"quantized::conv_prepack (qnnpack): Weights are expected to have ",
183+
kSpatialDim + 2, " dimensions");
171184
TORCH_CHECK(
172-
weight.ndimension() == 4,
173-
"quantized::conv2d_prepack (qnnpack): Weights are expected to have 4 "
174-
"dimensions");
185+
stride.size() == kSpatialDim,
186+
"quantized::conv_prepack (qnnpack): ",
187+
kSpatialDim, "D convolution expects stride to have ",
188+
kSpatialDim, " elements.");
175189
TORCH_CHECK(
176-
stride.size() == 2,
177-
"quantized::conv2d_prepack (qnnpack): 2D convolution only");
190+
padding.size() == kSpatialDim,
191+
"quantized::conv_prepack (qnnpack): Specify top/left input padding "
192+
"only. bottom/right padding assumed to be equal to top/left");
178193
TORCH_CHECK(
179-
padding.size() == 2,
180-
"quantized::conv2d_prepack (qnnpack): Specify top/left padding only. "
181-
"bottom/right padding assumed to be equal to top/left");
194+
output_padding.size() == kSpatialDim,
195+
"quantized::conv_prepack (qnnpack): Specify top/left output padding "
196+
"only. bottom/right padding assumed to be equal to top/left");
182197
TORCH_CHECK(
183-
dilation.size() == 2,
184-
" quantized::conv2d_prepack (qnnpack): 2D convolution only");
198+
dilation.size() == kSpatialDim,
199+
"quantized::conv_prepack (qnnpack): ",
200+
kSpatialDim, "D convolution expects dilation to have ",
201+
kSpatialDim, " elements.");
185202

186203
at::native::initQNNPACK();
187204

188205
// QNNPACK expects weights to be of the format {out_c, kH, kW, in_c/groups},
189206
// but PyTorch lays them out as {out_c, in_c/groups, kH, kW}
190-
const size_t out_ch = weight.size(0);
207+
// (or for ConvTranspose {in_c, out_c/groups, kH, kW})
208+
const size_t out_ch_idx = transpose ? 1 : 0;
209+
const size_t out_ch = weight.size(out_ch_idx);
191210
const uint32_t kernel_h = weight.size(2);
192211
const uint32_t kernel_w = weight.size(3);
193212

@@ -228,8 +247,10 @@ c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeightsQnnp<
228247
bias_fp32.contiguous(), /* fp32 bias */
229248
stride,
230249
padding,
250+
output_padding,
231251
dilation,
232252
groups,
253+
transpose,
233254
c10::nullopt, /* input_scale */
234255
{kernel_h, kernel_w},
235256
w_scales,
@@ -248,18 +269,38 @@ namespace {
248269
template <int kSpatialDim = 2>
249270
class QConvPackWeightInt8 final {
250271
public:
251-
static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> run(
272+
static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> run_conv(
252273
Tensor weight,
253274
c10::optional<Tensor> bias,
254275
torch::List<int64_t> stride,
255276
torch::List<int64_t> padding,
256277
torch::List<int64_t> dilation,
257278
int64_t groups) {
279+
torch::List<int64_t> output_padding;
280+
output_padding.reserve(kSpatialDim);
281+
for (int idx = 0; idx < kSpatialDim; ++idx) {
282+
output_padding.push_back((int64_t)0);
283+
}
284+
return _run(weight, bias, stride, padding, output_padding, dilation, groups,
285+
false);
286+
}
287+
288+
private:
289+
static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> _run(
290+
Tensor weight,
291+
c10::optional<Tensor> bias,
292+
torch::List<int64_t> stride,
293+
torch::List<int64_t> padding,
294+
torch::List<int64_t> output_padding,
295+
torch::List<int64_t> dilation,
296+
int64_t groups,
297+
bool transpose) {
258298
auto& ctx = at::globalContext();
259299
#ifdef USE_FBGEMM
260300
if (ctx.qEngine() == at::QEngine::FBGEMM) {
261301
return PackedConvWeight<kSpatialDim>::prepack(
262-
weight, bias, stride, padding, dilation, groups);
302+
weight, bias, stride, padding, output_padding, dilation, groups,
303+
transpose);
263304
}
264305
#endif
265306

@@ -270,7 +311,8 @@ class QConvPackWeightInt8 final {
270311
"quantized::conv_prepack (qnnpack): QNNPACK only supports Conv1d "
271312
"and Conv2d now.");
272313
return PackedConvWeightsQnnp<kSpatialDim>::prepack(
273-
weight, bias, stride, padding, dilation, groups);
314+
weight, bias, stride, padding, output_padding, dilation, groups,
315+
transpose);
274316
}
275317
#endif
276318

@@ -283,31 +325,49 @@ class QConvPackWeightInt8 final {
283325

284326
class QConv1dPackWeightInt8 final {
285327
public:
286-
static c10::intrusive_ptr<ConvPackedParamsBase<2>> run(
328+
static c10::intrusive_ptr<ConvPackedParamsBase<2>> run_conv(
287329
Tensor weight,
288330
c10::optional<Tensor> bias,
289331
torch::List<int64_t> stride,
290332
torch::List<int64_t> padding,
291333
torch::List<int64_t> dilation,
292334
int64_t groups) {
335+
const torch::List<int64_t> output_padding({0});
336+
return _run(weight, bias, stride, padding, output_padding, dilation, groups,
337+
false);
338+
}
339+
340+
private:
341+
static c10::intrusive_ptr<ConvPackedParamsBase<2>> _run(
342+
Tensor weight,
343+
c10::optional<Tensor> bias,
344+
torch::List<int64_t> stride,
345+
torch::List<int64_t> padding,
346+
torch::List<int64_t> output_padding,
347+
torch::List<int64_t> dilation,
348+
int64_t groups,
349+
bool transpose) {
293350
auto& ctx = at::globalContext();
294351
if (weight.dim() == 3) {
295352
weight = weight.unsqueeze(quant_utils::kConv1dSqueezeDim + 2);
296353
}
297354
stride = quant_utils::MakeArgForConv1d(stride, 1);
298355
padding = quant_utils::MakeArgForConv1d(padding, 0);
356+
output_padding = quant_utils::MakeArgForConv1d(output_padding, 0);
299357
dilation = quant_utils::MakeArgForConv1d(dilation, 1);
300358
#ifdef USE_FBGEMM
301359
if (ctx.qEngine() == at::QEngine::FBGEMM) {
302360
return PackedConvWeight<2>::prepack(
303-
weight, bias, stride, padding, dilation, groups);
361+
weight, bias, stride, padding, output_padding, dilation, groups,
362+
transpose);
304363
}
305364
#endif
306365

307366
#ifdef USE_PYTORCH_QNNPACK
308367
if (ctx.qEngine() == at::QEngine::QNNPACK) {
309368
return PackedConvWeightsQnnp<2>::prepack(
310-
weight, bias, stride, padding, dilation, groups);
369+
weight, bias, stride, padding, output_padding, dilation, groups,
370+
transpose);
311371
}
312372
#endif
313373
TORCH_CHECK(
@@ -319,14 +379,14 @@ class QConv1dPackWeightInt8 final {
319379

320380
TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
321381
// conv_prepack is deprecated, please use conv2d_prepack for 2D conv.
322-
m.impl("conv_prepack", TORCH_FN(QConvPackWeightInt8<2>::run));
323-
m.impl("conv1d_prepack", TORCH_FN(QConv1dPackWeightInt8::run));
324-
m.impl("conv2d_prepack", TORCH_FN(QConvPackWeightInt8<2>::run));
325-
m.impl("conv3d_prepack", TORCH_FN(QConvPackWeightInt8<3>::run));
382+
m.impl("conv_prepack", TORCH_FN(QConvPackWeightInt8<2>::run_conv));
383+
m.impl("conv1d_prepack", TORCH_FN(QConv1dPackWeightInt8::run_conv));
384+
m.impl("conv2d_prepack", TORCH_FN(QConvPackWeightInt8<2>::run_conv));
385+
m.impl("conv3d_prepack", TORCH_FN(QConvPackWeightInt8<3>::run_conv));
326386
}
327387

328388
TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) {
329-
m.impl("conv2d_prepack", TORCH_FN(QConvPackWeightInt8<2>::run));
389+
m.impl("conv2d_prepack", TORCH_FN(QConvPackWeightInt8<2>::run_conv));
330390
}
331391

332392
} // namespace

0 commit comments

Comments
 (0)