
Commit 88d20eb

Author: Zafar (committed)

[quant] Prep for conv_transpose packing

ghstack-source-id: 2c7d3eb
Pull Request resolved: #39714

1 parent eace053 · commit 88d20eb

6 files changed: +173 −44 lines changed


aten/src/ATen/native/quantized/cpu/conv_packed_params.h

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,8 @@ struct ConvPackedParamsBase : public torch::jit::CustomClassHolder {
 
   virtual torch::List<int64_t> stride() const = 0;
   virtual torch::List<int64_t> padding() const = 0;
+  virtual torch::List<int64_t> output_padding() const = 0;
   virtual torch::List<int64_t> dilation() const = 0;
   virtual int64_t groups() const = 0;
+  virtual bool transpose() const = 0;
 };
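
Every packed-params backend now has to implement these two additional pure virtuals. As an illustration only (a hypothetical helper, not part of this commit), downstream code could use them for a sanity check, assuming a non-transposed convolution is packed with an all-zero output_padding:

#include <ATen/native/quantized/cpu/conv_packed_params.h>
#include <c10/util/Exception.h>

// Hypothetical consistency check built on the new accessors.
template <int kSpatialDim>
void check_transpose_consistency(
    const ConvPackedParamsBase<kSpatialDim>& params) {
  if (!params.transpose()) {
    // Assumption: regular (non-transposed) convs carry output_padding == 0.
    for (int64_t op : params.output_padding()) {
      TORCH_CHECK(
          op == 0, "non-transposed conv should have zero output_padding");
    }
  }
}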

aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp

Lines changed: 41 additions & 12 deletions
@@ -212,16 +212,20 @@ Tensor ConvertToChannelsLast3dTensor(const Tensor& src) {
 
 template <int kSpatialDim = 2>
 CAFFE2_API torch::jit::class_<ConvPackedParamsBase<kSpatialDim>> register_conv_params() {
+  // Note: SerializationType order should be fixed.
+  // See onnx/unpack_quantized_weights.cpp
   using SerializationType = std::tuple<
-      at::Tensor,
-      c10::optional<at::Tensor>,
+      at::Tensor /*weight*/,
+      c10::optional<at::Tensor> /*bias*/,
       // these are meant to be torch::List<int64_t> but
       // it's not supported by onnx, so we'll use Tensor as
       // a workaround
-      torch::List<at::Tensor>,
-      torch::List<at::Tensor>,
-      torch::List<at::Tensor>,
-      at::Tensor>;
+      torch::List<at::Tensor> /*stride*/,
+      torch::List<at::Tensor> /*padding*/,
+      torch::List<at::Tensor> /*dilation*/,
+      at::Tensor /*groups*/,
+      at::Tensor /*transpose*/,
+      torch::List<at::Tensor> /*output_padding*/>;
   static auto register_conv_params =
       torch::jit::class_<ConvPackedParamsBase<kSpatialDim>>(
          "quantized", "Conv" + c10::to_string(kSpatialDim) + "dPackedParamsBase")
@@ -233,46 +237,67 @@ CAFFE2_API torch::jit::class_<ConvPackedParamsBase<kSpatialDim>> register_conv_p
            std::tie(weight, bias) = params->unpack();
            torch::List<at::Tensor> stride;
            torch::List<at::Tensor> padding;
+            torch::List<at::Tensor> output_padding;
            torch::List<at::Tensor> dilation;
            at::Tensor groups;
+            at::Tensor transpose;
            for (int64_t s : params->stride()) {
              stride.emplace_back(at::tensor(s));
            }
            for (int64_t p : params->padding()) {
              padding.emplace_back(at::tensor(p));
            }
+            for (int64_t p : params->output_padding()) {
+              output_padding.emplace_back(at::tensor(p));
+            }
            for (int64_t d : params->dilation()) {
              dilation.emplace_back(at::tensor(d));
            }
            groups = at::tensor(params->groups());
+            transpose = at::tensor((uint8_t)params->transpose());
            return std::make_tuple(
                std::move(weight),
                std::move(bias),
                stride,
                padding,
                dilation,
-                groups);
+                groups,
+                transpose,
+                output_padding);
          },
          [](SerializationType state)
              -> c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> { // __setstate__
            at::Tensor weight;
            c10::optional<at::Tensor> bias;
            torch::List<at::Tensor> stride_tensor, padding_tensor,
-                dilation_tensor;
+                output_padding_tensor, dilation_tensor;
            at::Tensor groups_tensor;
-            torch::List<int64_t> stride, padding, dilation;
+            at::Tensor transpose_tensor;
+            torch::List<int64_t> stride, padding, output_padding, dilation;
            int64_t groups;
-            std::tie(weight, bias, stride_tensor, padding_tensor, dilation_tensor, groups_tensor) = state;
+            uint8_t transpose;
+            std::tie(weight,
+                     bias,
+                     stride_tensor,
+                     padding_tensor,
+                     dilation_tensor,
+                     groups_tensor,
+                     transpose_tensor,
+                     output_padding_tensor) = state;
            for (at::Tensor s : stride_tensor) {
              stride.emplace_back(s[0].item<int64_t>());
            }
            for (at::Tensor p : padding_tensor) {
              padding.emplace_back(p[0].item<int64_t>());
            }
+            for (at::Tensor p : output_padding_tensor) {
+              output_padding.emplace_back(p[0].item<int64_t>());
+            }
            for (at::Tensor d : dilation_tensor) {
              dilation.emplace_back(d[0].item<int64_t>());
            }
            groups = groups_tensor[0].item<int64_t>();
+            transpose = transpose_tensor[0].item<uint8_t>();
            auto& ctx = at::globalContext();
 
 #ifdef USE_FBGEMM
@@ -282,8 +307,10 @@ CAFFE2_API torch::jit::class_<ConvPackedParamsBase<kSpatialDim>> register_conv_p
                  bias,
                  stride,
                  padding,
+                  output_padding,
                  dilation,
-                  groups);
+                  groups,
+                  transpose);
            }
 #endif // USE_FBGEMM
 #ifdef USE_PYTORCH_QNNPACK
@@ -297,8 +324,10 @@ CAFFE2_API torch::jit::class_<ConvPackedParamsBase<kSpatialDim>> register_conv_p
                  bias,
                  stride,
                  padding,
+                  output_padding,
                  dilation,
-                  groups);
+                  groups,
+                  transpose);
            }
 #endif // USE_PYTORCH_QNNPACK
            TORCH_CHECK(
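
The pickling lambdas above box every integer field as a tensor because ONNX cannot consume torch::List<int64_t> (see the SerializationType comment). The round-trip can be sketched in isolation; this is a standalone illustration, not code from the commit, and uses c10::List, which torch::List aliases:

#include <ATen/ATen.h>
#include <ATen/core/List.h>

// Box a list of ints the way __getstate__ does: one 1-element tensor per value.
c10::List<at::Tensor> box_int_list(const c10::List<int64_t>& src) {
  c10::List<at::Tensor> dst;
  for (int64_t v : src) {
    dst.emplace_back(at::tensor(v));
  }
  return dst;
}

// Unbox the way __setstate__ does: read element 0 back as an int64_t.
c10::List<int64_t> unbox_int_list(const c10::List<at::Tensor>& src) {
  c10::List<int64_t> dst;
  for (at::Tensor t : src) {
    dst.emplace_back(t[0].item<int64_t>());
  }
  return dst;
}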

aten/src/ATen/native/quantized/cpu/fbgemm_utils.h

Lines changed: 17 additions & 1 deletion
@@ -123,8 +123,10 @@ struct CAFFE2_API PackedConvWeight : public ConvPackedParamsBase<kSpatialDim> {
      c10::optional<at::Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
+      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
      int64_t groups,
+      uint8_t transpose,
      std::vector<int32_t> col_offsets,
      std::vector<int64_t> kernel,
      std::vector<float> w_scale,
@@ -134,8 +136,10 @@ struct CAFFE2_API PackedConvWeight : public ConvPackedParamsBase<kSpatialDim> {
        bias(std::move(bias)),
        stride_(std::move(stride)),
        padding_(std::move(padding)),
+        output_padding_(std::move(output_padding)),
        dilation_(std::move(dilation)),
        groups_(groups),
+        transpose_(transpose),
        col_offsets(std::move(col_offsets)),
        kernel(std::move(kernel)),
        w_scale(std::move(w_scale)),
@@ -146,8 +150,10 @@ struct CAFFE2_API PackedConvWeight : public ConvPackedParamsBase<kSpatialDim> {
  c10::optional<at::Tensor> bias;
  torch::List<int64_t> stride_;
  torch::List<int64_t> padding_;
+  torch::List<int64_t> output_padding_;
  torch::List<int64_t> dilation_;
  int64_t groups_;
+  uint8_t transpose_;
  std::vector<int32_t> col_offsets;
  std::vector<int64_t> kernel;
  std::vector<float> w_scale;
@@ -171,8 +177,10 @@ struct CAFFE2_API PackedConvWeight : public ConvPackedParamsBase<kSpatialDim> {
      c10::optional<at::Tensor> bias,
      torch::List<int64_t> stride,
      torch::List<int64_t> padding,
+      torch::List<int64_t> output_padding,
      torch::List<int64_t> dilation,
-      int64_t groups);
+      int64_t groups,
+      bool transpose);
 
  const float* GetBiasData(at::Tensor* bias);
 
@@ -190,6 +198,10 @@ struct CAFFE2_API PackedConvWeight : public ConvPackedParamsBase<kSpatialDim> {
    return padding_;
  }
 
+  torch::List<int64_t> output_padding() const override {
+    return output_padding_;
+  }
+
  torch::List<int64_t> dilation() const override {
    return dilation_;
  }
@@ -198,6 +210,10 @@ struct CAFFE2_API PackedConvWeight : public ConvPackedParamsBase<kSpatialDim> {
    return groups_;
  }
 
+  bool transpose() const override {
+    return (bool)transpose_;
+  }
+
 private:
  template <bool ReluFused>
  at::Tensor apply_impl(
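
PackedConvWeight now stores output_padding_ and transpose_ alongside the other conv parameters because a transposed convolution cannot be run from the packed weights alone: its output shape depends on output_padding. As an illustrative aside (standard ConvTranspose size formula, not code from this commit), the per-dimension output size is:

#include <cstdint>

// Output length of one spatial dimension of a transposed convolution.
int64_t conv_transpose_output_size(
    int64_t input_size,
    int64_t kernel,
    int64_t stride,
    int64_t padding,
    int64_t output_padding,
    int64_t dilation) {
  return (input_size - 1) * stride - 2 * padding +
      dilation * (kernel - 1) + output_padding + 1;
}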
