Changes from all commits (33 commits)
7b2f3a2  [quant] conv_transpose2d  (Jun 22, 2020)
99ceb00  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Jun 22, 2020)
d0ba851  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Jun 22, 2020)
b90009c  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Jun 23, 2020)
fae6ded  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Jun 26, 2020)
7010baa  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Jun 28, 2020)
5bc0586  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Jun 30, 2020)
81cad4e  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Jul 1, 2020)
32d7da5  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Jul 5, 2020)
20afb0e  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Jul 12, 2020)
7e75b40  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Jul 17, 2020)
726558d  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Jul 18, 2020)
3a9ab5a  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Jul 21, 2020)
6e37f81  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Jul 25, 2020)
c147750  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Jul 30, 2020)
fae001c  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Jul 30, 2020)
99f439d  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Aug 1, 2020)
c71d664  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Aug 5, 2020)
1bd2bb5  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Aug 7, 2020)
f08aa35  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Aug 12, 2020)
db5190f  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Aug 18, 2020)
b5919ab  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Aug 18, 2020)
d245f9e  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Aug 18, 2020)
af987ee  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Aug 25, 2020)
035edb2  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Aug 25, 2020)
96eddb6  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Sep 2, 2020)
b3da271  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Sep 3, 2020)
f336f85  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Sep 10, 2020)
b1b2208  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Sep 11, 2020)
e300acd  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Sep 11, 2020)
a2ec6fc  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Sep 11, 2020)
d3580fa  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Sep 12, 2020)
22e676a  Update on "[quant] conv_transpose1d / conv_transpose2d"  (Sep 13, 2020)
184 changes: 150 additions & 34 deletions aten/src/ATen/native/quantized/cpu/qconv.cpp
@@ -12,43 +12,113 @@
 #include <ATen/native/quantized/cpu/conv_packed_params.h>
 #include <caffe2/utils/threadpool/pthreadpool-cpp.h>
 
+namespace {
+// Sanity check for the maximum output dimension.
+constexpr int64_t kReasonableMaxDim = 1000000;
+}
+
 template <int kSpatialDim = 2>
 bool ConvDimChecks(
     int64_t act_dims,
     int64_t stride_dims,
     int64_t padding_dims,
-    int64_t dilation_dims) {
+    int64_t output_padding_dims,
+    int64_t dilation_dims,
+    std::string func_name,
+    bool transpose = false) {
   TORCH_CHECK(
       act_dims == kSpatialDim + 2,
-      "quantized::conv",
+      func_name,
       kSpatialDim,
       "d(): Expected activation tensor to have ",
       kSpatialDim + 2,
-      " dimensions.");
+      " dimensions, got ",
+      act_dims);
   TORCH_CHECK(
       stride_dims == kSpatialDim,
-      "quantized::conv",
+      func_name,
       kSpatialDim,
       "d(): Expected stride tensor to have ",
       kSpatialDim,
-      " dimensions.");
+      " dimensions, got ",
+      stride_dims);
   TORCH_CHECK(
       padding_dims == kSpatialDim,
-      "quantized::conv",
+      func_name,
       kSpatialDim,
       "d(): Expected padding tensor to have ",
       kSpatialDim,
-      " dimensions.");
+      " dimensions, got ",
+      padding_dims);
+  TORCH_CHECK(
+      !transpose || (output_padding_dims == kSpatialDim),
+      func_name,
+      kSpatialDim,
+      "d(): Expected output padding tensor to have ",
+      kSpatialDim,
+      " dimensions, got ",
+      output_padding_dims);
   TORCH_CHECK(
       dilation_dims == kSpatialDim,
-      "quantized::conv",
+      func_name,
       kSpatialDim,
       "d(): Expected dilation tensor to have ",
       kSpatialDim,
-      " dimensions.");
+      " dimensions, got ",
+      dilation_dims);
   return true;
 }
 
+inline int64_t compute_deconv_shape(int64_t input,
+                                    int64_t kernel,
+                                    int64_t stride,
+                                    int64_t input_padding,
+                                    int64_t output_padding,
+                                    int64_t dilation) {
+  int64_t out = (input - 1) * stride - 2 * input_padding
+                + dilation * (kernel - 1) + output_padding + 1;
+  return out;
+}
+
+template <int64_t kSpatialDim>
+at::SmallVector<int64_t, kSpatialDim + 2> MakeDeConvOutputShape(
+    int64_t N, int64_t M,
+    const std::vector<int64_t>& input_shape,
+    const std::vector<int64_t>& kernel,
+    const torch::List<int64_t>& stride,
+    const torch::List<int64_t>& input_padding,
+    const torch::List<int64_t>& output_padding,
+    const torch::List<int64_t>& dilation) {
+  at::SmallVector<int64_t, kSpatialDim + 2> output_shape;
+  output_shape.resize(kSpatialDim + 2);
+  output_shape[0] = N;  // Batch size
+  output_shape[1] = M;  // Output channels
+  for (int64_t idx = 0; idx < kSpatialDim; ++idx) {
+    output_shape[idx + 2] = compute_deconv_shape(input_shape[idx],
+                                                 kernel[idx],
+                                                 stride[idx],
+                                                 input_padding[idx],
+                                                 output_padding[idx],
+                                                 dilation[idx]);
+    TORCH_CHECK(output_shape[idx + 2] > 0,
+                "Output dimension is zero or negative for axis ", idx,
+                "; kernel: ", kernel[idx],
+                ", stride: ", stride[idx],
+                ", input padding: ", input_padding[idx],
+                ", output padding: ", output_padding[idx],
+                ", dilation: ", dilation[idx]);
+    TORCH_CHECK(output_shape[idx + 2] < kReasonableMaxDim,
+                "Output dimension is beyond the reasonable maximum for axis ", idx,
+                "; kernel: ", kernel[idx],
+                ", stride: ", stride[idx],
+                ", input padding: ", input_padding[idx],
+                ", output padding: ", output_padding[idx],
+                ", dilation: ", dilation[idx]);
+  }
+  return output_shape;
+}
+
 #ifdef USE_FBGEMM
 
 template <int kSpatialDim = 2>
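
compute_deconv_shape above is the standard transposed-convolution size formula: out = (in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1. As a quick sanity check of the arithmetic, here is a minimal standalone C++ sketch (not part of this diff; the helper name deconv_out is ours):

#include <cassert>
#include <cstdint>

// Standalone mirror of compute_deconv_shape, for checking the formula by hand.
int64_t deconv_out(int64_t input, int64_t kernel, int64_t stride,
                   int64_t input_padding, int64_t output_padding,
                   int64_t dilation) {
  return (input - 1) * stride - 2 * input_padding +
         dilation * (kernel - 1) + output_padding + 1;
}

int main() {
  // kernel=3, stride=2, padding=1, output_padding=1, dilation=1 doubles the
  // spatial size: (4 - 1) * 2 - 2 + 1 * (3 - 1) + 1 + 1 = 8.
  assert(deconv_out(4, 3, 2, 1, 1, 1) == 8);
  return 0;
}
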
@@ -203,10 +273,13 @@ at::Tensor PackedConvWeight<kSpatialDim>::apply_impl(
   //
   // This might change when full memory format support lands
   // See https://github.com/pytorch/pytorch/issues/23403
+  const std::string func_name = transpose() ? "quantized::conv_transpose"
+                                            : "quantized::conv";
   TORCH_CHECK(
       fbgemm::fbgemmSupportedCPU(), "Your CPU does not support FBGEMM.");
   ConvDimChecks<kSpatialDim>(
-      act.ndimension(), stride_.size(), padding_.size(), dilation_.size());
+      act.ndimension(), stride().size(), padding().size(),
+      output_padding().size(), dilation().size(), func_name, transpose());
 
   const int N = act.size(0);
   const int C = act.size(1);
@@ -466,17 +539,25 @@ at::Tensor PackedConvWeightsQnnp<kSpatialDim>::apply_impl(
     const at::Tensor& act,
     double output_scale,
     int64_t output_zero_point) {
+  const std::string func_name = transpose() ? "quantized::conv_transpose"
+                                            : "quantized::conv";
+  TORCH_CHECK(!(kReluFused && transpose()),
+      func_name, kSpatialDim,
+      "d (qnnpack): ConvTranspose cannot be fused with ReLU.");
   TORCH_CHECK(
       kSpatialDim == 2,
-      "quantized::conv2d (qnnpack): QNNPACK only supports Conv2d "
-      "now.");
+      func_name, kSpatialDim,
+      "d (qnnpack): QNNPACK only supports Conv2d now.");
   ConvDimChecks<kSpatialDim>(
-      act.ndimension(), stride_.size(), padding_.size(), dilation_.size());
+      act.ndimension(), stride().size(), padding().size(),
+      output_padding().size(), dilation().size(), func_name, transpose());
 
   auto* pack_w = w.get();
 
   // TODO Can be replaced with packB->getOutputChannels() once pre-pack
   // actually does the packing.
+  const int out_ch_idx = transpose() ? 1 : 0;
   const auto out_ch = bias.size(0);
   // inputs are in semantic NCHW format
   const int N = act.size(0);
@@ -500,10 +581,10 @@ at::Tensor PackedConvWeightsQnnp<kSpatialDim>::apply_impl(
 
   // Re-quantizing the bias based on input scale and weight scale.
   if (!input_scale.has_value() || input_scale.value() != act_input_scale) {
-    TORCH_CHECK( M == orig_weight.size(0),
-        "Output channel size of weight and bias must match.");
-    TORCH_CHECK( C == groups_ * orig_weight.size(1),
+    TORCH_CHECK(M == (transpose() ? groups() : 1) * orig_weight.size(out_ch_idx),
+        "Output channel size of weight and bias must match.");
+    TORCH_CHECK(C == (transpose() ? 1 : groups()) * orig_weight.size(1 - out_ch_idx),
         "Input channel size of weight and bias must match.");
 
     // Get the original weight and adjust it to uint8 from int8
     auto weight_contig =
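
The reworked channel checks account for the transposed weight layout: regular conv weights are stored as [C_out, C_in / groups, kH, kW], while ConvTranspose weights are stored as [C_in, C_out / groups, kH, kW], so the output-channel count sits at index 1 (out_ch_idx) and must be scaled back up by groups. A small standalone sketch of that bookkeeping, with assumed example shapes (not part of this diff):

#include <array>
#include <cassert>
#include <cstdint>

int main() {
  const int64_t groups = 2;
  // Hypothetical ConvTranspose2d weight: C_in = 4, C_out = 6, 3x3 kernel,
  // groups = 2, stored as [C_in, C_out / groups, kH, kW].
  const std::array<int64_t, 4> weight_shape = {4, 6 / groups, 3, 3};
  const bool transpose = true;
  const int out_ch_idx = transpose ? 1 : 0;
  // Mirrors the checks above: recover full channel counts from the weight.
  const int64_t M = (transpose ? groups : 1) * weight_shape[out_ch_idx];
  const int64_t C = (transpose ? 1 : groups) * weight_shape[1 - out_ch_idx];
  assert(M == 6);  // must equal bias.size(0)
  assert(C == 4);  // must equal the activation's channel dimension
  return 0;
}
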
@@ -574,8 +655,15 @@ at::Tensor PackedConvWeightsQnnp<kSpatialDim>::apply_impl(
   }
 
   TORCH_INTERNAL_ASSERT(pack_w != nullptr, "Packed Weights are NULL");
-  const auto output_shape = MakeConvOutputShape<kSpatialDim>(
-      N, M, {H, W}, kernel_, stride_, padding_, dilation_);
+  at::SmallVector<int64_t, kSpatialDim + 2> output_shape;
+  if (transpose()) {
+    output_shape = MakeDeConvOutputShape<kSpatialDim>(N, M, {H, W},
+        kernel_, stride(), padding(), output_padding(), dilation());
+  } else {
+    output_shape = MakeConvOutputShape<kSpatialDim>(N, M, {H, W},
+        kernel_, stride(), padding(), dilation());
+  }
+
   if (act_nhwc.numel() > 0) {
     TORCH_CHECK(
         std::all_of(
@@ -596,22 +684,42 @@ at::Tensor PackedConvWeightsQnnp<kSpatialDim>::apply_impl(
       output_zero_point,
       c10::nullopt);
 
-  const pytorch_qnnp_status run_status = qnnpack::qnnpackConv(
-      conv_p,
-      convolution_op.get(),
-      pack_w->getPackedWeights(),
-      N,
-      H,
-      W,
-      act_nhwc.q_zero_point(),
-      reinterpret_cast<uint8_t*>(act_nhwc.template data_ptr<c10::quint8>()),
-      w_zero_points.data(),
-      requantization_scales.data(),
-      output.q_zero_point(),
-      output_min,
-      output_max,
-      reinterpret_cast<uint8_t*>(output.template data_ptr<c10::quint8>()),
-      caffe2::pthreadpool_());
+  pytorch_qnnp_status run_status;
+  if (transpose()) {
+    run_status = qnnpack::qnnpackDeConv(
+        conv_p,
+        convolution_op.get(),
+        pack_w->getPackedWeights(),
+        N,
+        H,
+        W,
+        act_nhwc.q_zero_point(),
+        reinterpret_cast<uint8_t*>(act_nhwc.template data_ptr<c10::quint8>()),
+        w_zero_points.data(),
+        requantization_scales.data(),
+        output.q_zero_point(),
+        output_min,
+        output_max,
+        reinterpret_cast<uint8_t*>(output.template data_ptr<c10::quint8>()),
+        caffe2::pthreadpool_());
+  } else {
+    run_status = qnnpack::qnnpackConv(
+        conv_p,
+        convolution_op.get(),
+        pack_w->getPackedWeights(),
+        N,
+        H,
+        W,
+        act_nhwc.q_zero_point(),
+        reinterpret_cast<uint8_t*>(act_nhwc.template data_ptr<c10::quint8>()),
+        w_zero_points.data(),
+        requantization_scales.data(),
+        output.q_zero_point(),
+        output_min,
+        output_max,
+        reinterpret_cast<uint8_t*>(output.template data_ptr<c10::quint8>()),
+        caffe2::pthreadpool_());
+  }
 
   TORCH_INTERNAL_ASSERT(
       run_status == pytorch_qnnp_status_success,
@@ -753,11 +861,19 @@ TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
   m.impl("conv2d_relu", QConvInt8ForBC<2, true>::run);
   m.impl("conv3d", QConvInt8ForBC<3, false>::run);
   m.impl("conv3d_relu", QConvInt8ForBC<3, true>::run);
+
+  // transpose
+  m.impl("conv_transpose1d", QConv1dInt8<false>::run);
+  m.impl("conv_transpose2d", QConvInt8<2, false>::run);
 }
 
 TORCH_LIBRARY_IMPL(_quantized, QuantizedCPU, m) {
   m.impl("conv2d", QConvInt8<2, false>::run);
   m.impl("conv2d_relu", QConvInt8<2, true>::run);
+
+  // transpose
+  m.impl("conv_transpose1d", QConv1dInt8<false>::run);
+  m.impl("conv_transpose2d", QConvInt8<2, false>::run);
 }
 
 } // namespace
2 changes: 2 additions & 0 deletions aten/src/ATen/native/quantized/library.cpp
@@ -90,6 +90,8 @@ TORCH_LIBRARY(quantized, m) {
   m.def("conv3d_dilation(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int[]");
   m.def("conv3d_groups(__torch__.torch.classes.quantized.Conv3dPackedParamsBase packed_weights) -> int");
   // conv_transpose
+  m.def("conv_transpose1d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor");
+  m.def("conv_transpose2d(Tensor qx, __torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weight, float output_scale, int output_zero_point) -> Tensor");
   m.def("conv_transpose1d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase");
   m.def("conv_transpose2d_prepack(Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] output_padding, int[] dilation, int groups) -> __torch__.torch.classes.quantized.Conv2dPackedParamsBase");
   m.def("conv_transpose1d_unpack(__torch__.torch.classes.quantized.Conv2dPackedParamsBase packed_weights) -> (Tensor unpacked_weights, Tensor? B_origin)");
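
With the schemas registered in library.cpp and the CPU kernels bound in qconv.cpp, the new ops become reachable through the dispatcher (from Python, typically as torch.ops.quantized.conv_transpose2d). A hedged C++ sketch of a call site follows; the include set and the typed signature are assumptions, not taken from this diff:

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/native/quantized/cpu/conv_packed_params.h>

// Sketch only: resolve quantized::conv_transpose2d through the dispatcher and
// call it with an already-quantized input and pre-packed weights.
at::Tensor run_qconv_transpose2d(
    const at::Tensor& qx,
    const c10::intrusive_ptr<ConvPackedParamsBase<2>>& packed_weight,
    double output_scale,
    int64_t output_zero_point) {
  static const auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("quantized::conv_transpose2d", "")
          .typed<at::Tensor(
              at::Tensor,
              c10::intrusive_ptr<ConvPackedParamsBase<2>>,
              double,
              int64_t)>();
  return op.call(qx, packed_weight, output_scale, output_zero_point);
}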