@@ -212,16 +212,20 @@ Tensor ConvertToChannelsLast3dTensor(const Tensor& src) {
 
 template <int kSpatialDim = 2>
 CAFFE2_API torch::jit::class_<ConvPackedParamsBase<kSpatialDim>> register_conv_params() {
+  // Note: SerializationType order should be fixed.
+  // See onnx/unpack_quantized_weights.cpp
   using SerializationType = std::tuple<
-    at::Tensor,
-    c10::optional<at::Tensor>,
+    at::Tensor /* weight */,
+    c10::optional<at::Tensor> /* bias */,
     // these are meant to be torch::List<int64_t> but
     // it's not supported by onnx, so we'll use Tensor as
     // a workaround
-    torch::List<at::Tensor>,
-    torch::List<at::Tensor>,
-    torch::List<at::Tensor>,
-    at::Tensor>;
+    torch::List<at::Tensor> /* stride */,
+    torch::List<at::Tensor> /* padding */,
+    torch::List<at::Tensor> /* dilation */,
+    at::Tensor /* groups */,
+    at::Tensor /* transpose */,
+    torch::List<at::Tensor> /* output_padding */>;
   static auto register_conv_params =
       torch::jit::class_<ConvPackedParamsBase<kSpatialDim>>(
           "quantized", "Conv" + c10::to_string(kSpatialDim) + "dPackedParamsBase")
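The ONNX workaround noted in the comments above is why every int64_t hyperparameter list is shipped as a torch::List<at::Tensor>. A minimal standalone sketch of that round trip, assuming at::tensor(int64_t) yields a one-element tensor, which is what the s[0].item<int64_t>() reads in __setstate__ rely on:

#include <torch/torch.h>

int main() {
  // Hypothetical stride list, standing in for params->stride() above.
  torch::List<int64_t> stride({2, 2});

  // __getstate__ direction: wrap each scalar in a tensor so the tuple
  // only carries types that ONNX can represent.
  torch::List<at::Tensor> stride_tensor;
  for (int64_t s : stride) {
    stride_tensor.emplace_back(at::tensor(s));
  }

  // __setstate__ direction: read each scalar back out of its one-element tensor.
  torch::List<int64_t> restored;
  for (at::Tensor s : stride_tensor) {
    restored.emplace_back(s[0].item<int64_t>());
  }

  TORCH_CHECK(restored.get(0) == 2 && restored.get(1) == 2);
  return 0;
}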
@@ -233,46 +237,67 @@ CAFFE2_API torch::jit::class_<ConvPackedParamsBase<kSpatialDim>> register_conv_params() {
           std::tie(weight, bias) = params->unpack();
           torch::List<at::Tensor> stride;
           torch::List<at::Tensor> padding;
+          torch::List<at::Tensor> output_padding;
           torch::List<at::Tensor> dilation;
           at::Tensor groups;
+          at::Tensor transpose;
           for (int64_t s : params->stride()) {
             stride.emplace_back(at::tensor(s));
           }
           for (int64_t p : params->padding()) {
             padding.emplace_back(at::tensor(p));
           }
+          for (int64_t p : params->output_padding()) {
+            output_padding.emplace_back(at::tensor(p));
+          }
           for (int64_t d : params->dilation()) {
             dilation.emplace_back(at::tensor(d));
           }
           groups = at::tensor(params->groups());
+          transpose = at::tensor((uint8_t)params->transpose());
           return std::make_tuple(
               std::move(weight),
               std::move(bias),
               stride,
               padding,
               dilation,
-              groups);
+              groups,
+              transpose,
+              output_padding);
         },
         [](SerializationType state)
         -> c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> { // __setstate__
           at::Tensor weight;
           c10::optional<at::Tensor> bias;
           torch::List<at::Tensor> stride_tensor, padding_tensor,
-              dilation_tensor;
+              output_padding_tensor, dilation_tensor;
           at::Tensor groups_tensor;
-          torch::List<int64_t> stride, padding, dilation;
+          at::Tensor transpose_tensor;
+          torch::List<int64_t> stride, padding, output_padding, dilation;
           int64_t groups;
-          std::tie(weight, bias, stride_tensor, padding_tensor, dilation_tensor, groups_tensor) = state;
+          uint8_t transpose;
+          std::tie(weight,
+                   bias,
+                   stride_tensor,
+                   padding_tensor,
+                   dilation_tensor,
+                   groups_tensor,
+                   transpose_tensor,
+                   output_padding_tensor) = state;
           for (at::Tensor s : stride_tensor) {
             stride.emplace_back(s[0].item<int64_t>());
           }
           for (at::Tensor p : padding_tensor) {
             padding.emplace_back(p[0].item<int64_t>());
           }
+          for (at::Tensor p : output_padding_tensor) {
+            output_padding.emplace_back(p[0].item<int64_t>());
+          }
           for (at::Tensor d : dilation_tensor) {
             dilation.emplace_back(d[0].item<int64_t>());
           }
           groups = groups_tensor[0].item<int64_t>();
+          transpose = transpose_tensor[0].item<uint8_t>();
           auto& ctx = at::globalContext();
 
 #ifdef USE_FBGEMM
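groups and transpose ride along the same way, as single one-element tensors rather than plain scalars (presumably for the same ONNX reason as the lists); the transpose flag is cast to uint8_t on the way in and read back with item<uint8_t>() on the way out. A standalone sketch of that scalar round trip, under the same one-element-tensor assumption as above:

#include <torch/torch.h>

int main() {
  bool transpose_in = true;  // stands in for params->transpose() above

  // __getstate__ direction: flag -> uint8_t -> one-element tensor.
  at::Tensor transpose_tensor = at::tensor((uint8_t)transpose_in);

  // __setstate__ direction: tensor -> uint8_t, then used as a flag again.
  uint8_t transpose_out = transpose_tensor[0].item<uint8_t>();
  TORCH_CHECK(static_cast<bool>(transpose_out) == transpose_in);
  return 0;
}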
@@ -282,8 +307,10 @@ CAFFE2_API torch::jit::class_<ConvPackedParamsBase<kSpatialDim>> register_conv_params() {
                 bias,
                 stride,
                 padding,
+                output_padding,
                 dilation,
-                groups);
+                groups,
+                transpose);
           }
 #endif // USE_FBGEMM
 #ifdef USE_PYTORCH_QNNPACK
@@ -297,8 +324,10 @@ CAFFE2_API torch::jit::class_<ConvPackedParamsBase<kSpatialDim>> register_conv_params() {
                 bias,
                 stride,
                 padding,
+                output_padding,
                 dilation,
-                groups);
+                groups,
+                transpose);
           }
 #endif // USE_PYTORCH_QNNPACK
           TORCH_CHECK(