Changes from all commits (31 commits):
811b0ee
Change quantizer to account for input tensor's memory format.
kimishpatel Jul 28, 2020
42d035e
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 5, 2020
afe7be8
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 6, 2020
4ff59ef
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 7, 2020
b979075
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 10, 2020
ac9786a
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 11, 2020
072ec1a
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 11, 2020
ec085e5
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 12, 2020
65bab4e
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 13, 2020
fb5186f
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 13, 2020
9c6e33b
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 13, 2020
d3d24ed
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 14, 2020
45d2aa0
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 14, 2020
434019f
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 14, 2020
619da69
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 15, 2020
ba30e54
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 17, 2020
99b25ff
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 18, 2020
1e394c7
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 18, 2020
f9d796e
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 18, 2020
56aaf2e
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 18, 2020
7087dde
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 18, 2020
0a52957
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 18, 2020
2d38a48
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 18, 2020
049d119
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 18, 2020
802c7b3
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 18, 2020
5a6167b
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 19, 2020
4e0d517
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 19, 2020
05dcb8c
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 19, 2020
e831f8e
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 19, 2020
c349702
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 21, 2020
ecb3848
Update on "Change quantizer to account for input tensor's memory form…
kimishpatel Aug 21, 2020
133 changes: 108 additions & 25 deletions aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp
@@ -24,6 +24,16 @@ namespace at {
namespace native {
namespace {

void check_tensor_memory_format(const Tensor& ref, const Tensor& other) {
TORCH_CHECK(
ref.is_contiguous(ref.suggest_memory_format()),
"Quantized tensor should be contiguous");
TORCH_CHECK(
other.is_contiguous(ref.suggest_memory_format()),
"Float tensor should be contiguous "
"in same memory format as quantizd tensor");
}

// ****************** HEY YOU! YES YOU! Read this! ********************
//
// Please read the README.md in this directory before editing this file
@@ -2329,6 +2339,7 @@ void quantize_tensor_per_tensor_affine_cpu(
int64_t zero_point) {
AT_DISPATCH_QINT_TYPES(
qtensor.scalar_type(), "quantize_tensor_per_tensor_affine_cpu", [&]() {
check_tensor_memory_format(rtensor, qtensor);
const float* rd = rtensor.data_ptr<float>();
auto qd = reinterpret_cast<underlying_t*>(qtensor.data_ptr<scalar_t>());
fbgemm::TensorQuantizationParams qparams;
@@ -2357,6 +2368,7 @@ void dequantize_tensor_per_tensor_affine_cpu(
int64_t zero_point) {
AT_DISPATCH_QINT_TYPES(
qtensor.scalar_type(), "dequantize_tensor_per_tensor_affine_cpu", [&]() {
check_tensor_memory_format(qtensor, rtensor);
const auto* qd =
reinterpret_cast<const underlying_t*>(qtensor.data_ptr<scalar_t>());
fbgemm::TensorQuantizationParams qparams;
@@ -2479,8 +2491,7 @@ void quantize_tensor_per_tensor_affine_cpu(
#if defined(__ARM_NEON__) || defined(__aarch64__)
AT_DISPATCH_QINT_TYPES(
qtensor.scalar_type(), "quantize_tensor_per_tensor_affine_cpu", [&]() {
TORCH_CHECK(
rtensor.is_contiguous(), "Float tensor should be contiguous");
check_tensor_memory_format(rtensor, qtensor);
const float* const rdata = rtensor.data_ptr<float>();
quantize_tensor_arm<scalar_t>(
rdata, qtensor, rtensor.numel(), scale, zero_point);
@@ -2489,8 +2500,7 @@
// Fallback path
AT_DISPATCH_QINT_TYPES(
qtensor.scalar_type(), "quantize_tensor_per_tensor_affine_cpu", [&]() {
TORCH_CHECK(
rtensor.is_contiguous(), "Float tensor should be contiguous");
check_tensor_memory_format(rtensor, qtensor);
const float* const rdata = rtensor.data_ptr<float>();
auto qdata = qtensor.data_ptr<scalar_t>();
auto numel = rtensor.numel();
@@ -2508,6 +2518,7 @@ void dequantize_tensor_per_tensor_affine_cpu(
int64_t zero_point) {
AT_DISPATCH_QINT_TYPES(
qtensor.scalar_type(), "dequantize_tensor_per_tensor_affine_cpu", [&]() {
check_tensor_memory_format(qtensor, rtensor);
const auto* qd = qtensor.data_ptr<scalar_t>();
float* rd = rtensor.data_ptr<float>();
auto numel = qtensor.numel();
@@ -2525,6 +2536,15 @@ void quantize_tensor_per_channel_affine_cpu(
Tensor scales,
Tensor zero_points,
int64_t axis) {
// TODO: the channels_last kernel can be made faster.
// For contiguous tensors, e.g. NCHW, an arbitrary axis can be used.
// For channels_last/3d, however, only axis == 0 or 1 is supported.
// Since the current channels_last implementation does not cover
// per-channel quantization with an arbitrary axis value, it is better
// to check and fail.
TORCH_CHECK(rtensor.is_contiguous() || (axis <= 1),
"If tensor is channels_last contig then per channel quantization "
"is supported only for axis = 0 or 1.");
AT_DISPATCH_QINT_TYPES(
qtensor.scalar_type(), "quantize_tensor_per_channel_affine_cpu", [&]() {
int64_t batches = size_to_dim_(axis, rtensor.sizes());
@@ -2533,15 +2553,33 @@
int64_t channel = rtensor.size(axis);
auto scales_data = scales.data_ptr<double>();
auto zero_points_data = zero_points.data_ptr<int64_t>();
check_tensor_memory_format(rtensor, qtensor);
const float* rdata = rtensor.data_ptr<float>();
auto qdata = qtensor.data_ptr<scalar_t>();
for (auto b = 0; b < batches; ++b) {
for (auto c = 0; c < channel; ++c) {
if (axis == 1 && (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
// This branch handles per-channel quantization when axis == 1 and
// the tensor is channels_last contiguous.
// When axis == 0 and the tensor is channels_last contiguous, the
// channels-first (NCHW) implementation below works as-is.
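// For channels_last (NHWC) the channel is the innermost dimension in
// memory, so the linear offset of logical element (b, c, e) is
// b * channel * elements_per_channel + e * channel + c, which is the
// indexing used below.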
for (auto b = 0; b < batches; ++b) {
for (auto e = 0; e < elements_per_channel; ++e) {
auto i = b * channel * elements_per_channel +
c * elements_per_channel + e;
qdata[i] = quantize_val<scalar_t>(
scales_data[c], zero_points_data[c], rdata[i]);
for (auto c = 0; c < channel; ++c) {
auto i = b * channel * elements_per_channel + e * channel + c;
qdata[i] = quantize_val<scalar_t>(
scales_data[c], zero_points_data[c], rdata[i]);
}
}
}
} else {
for (auto b = 0; b < batches; ++b) {
for (auto c = 0; c < channel; ++c) {
for (auto e = 0; e < elements_per_channel; ++e) {
auto i = b * channel * elements_per_channel +
c * elements_per_channel + e;
qdata[i] = quantize_val<scalar_t>(
scales_data[c], zero_points_data[c], rdata[i]);
}
}
}
}
@@ -2556,15 +2594,37 @@ void dequantize_per_channel_affine_kernel(
Tensor zero_points,
int64_t axis) {

int64_t batches = size_to_dim_(axis, rtensor.sizes());
int64_t elements_per_channel =
size_from_dim_(axis + 1, rtensor.sizes());
int64_t channel = rtensor.size(axis);
auto scales_data = scales.data_ptr<T>();
auto zero_points_data = zero_points.data_ptr<N>();
const auto* qd = qtensor.data_ptr<Q>();
float* rd = rtensor.data_ptr<float>();
// TODO: use parallel_for
// For contiguous tensors, e.g. NCHW, an arbitrary axis can be used.
// For channels_last/3d, however, only axis == 0 or 1 is supported.
// Since the current channels_last implementation does not cover
// per-channel quantization with an arbitrary axis value, it is better
// to check and fail.
TORCH_CHECK(rtensor.is_contiguous() || (axis <= 1),
"If tensor is channels_last contig then per channel quantization "
"is supported only for axis = 0 or 1.");
int64_t batches = size_to_dim_(axis, rtensor.sizes());
int64_t elements_per_channel =
size_from_dim_(axis + 1, rtensor.sizes());
int64_t channel = rtensor.size(axis);
auto scales_data = scales.data_ptr<T>();
auto zero_points_data = zero_points.data_ptr<N>();
check_tensor_memory_format(qtensor, rtensor);
const auto* qd = qtensor.data_ptr<Q>();
float* rd = rtensor.data_ptr<float>();
if (axis == 1 && (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
for (auto b = 0; b < batches; ++b) {
for (auto e = 0; e < elements_per_channel; ++e) {
for (auto c = 0; c < channel; ++c) {
auto i = b * channel * elements_per_channel + e * channel + c;
// We need to convert the qint8 value to float to ensure the
// subtraction subexpression returns a float
rd[i] = (static_cast<float>(qd[i].val_) - zero_points_data[c]) *
scales_data[c];
}
}
}
} else {
for (auto b = 0; b < batches; ++b) {
for (auto c = 0; c < channel; ++c) {
for (auto e = 0; e < elements_per_channel; ++e) {
@@ -2577,6 +2637,7 @@
}
}
}
}
}

void dequantize_tensor_per_channel_affine_cpu(
@@ -2598,6 +2659,14 @@ void quantize_tensor_per_channel_float_qparams_cpu(
Tensor scales,
Tensor zero_points,
int64_t axis) {
// For contiguous tensors, e.g. NCHW, an arbitrary axis can be used.
// For channels_last/3d, however, only axis == 0 or 1 is supported.
// Since the current channels_last implementation does not cover
// per-channel quantization with an arbitrary axis value, it is better
// to check and fail.
TORCH_CHECK(rtensor.is_contiguous() || (axis <= 1),
"If tensor is channels_last contig then per channel quantization "
"is supported only for axis = 0 or 1.");
AT_DISPATCH_QINT_TYPES(
qtensor.scalar_type(), "quantize_tensor_per_channel_float_qparams_cpu", [&]() {
int64_t batches = size_to_dim_(axis, rtensor.sizes());
@@ -2606,15 +2675,29 @@ void quantize_tensor_per_channel_float_qparams_cpu(
int64_t channel = rtensor.size(axis);
auto scales_data = scales.data_ptr<float>();
auto zero_points_data = zero_points.data_ptr<float>();
check_tensor_memory_format(rtensor, qtensor);
const float* rdata = rtensor.data_ptr<float>();
auto qdata = qtensor.data_ptr<scalar_t>();
for (auto b = 0; b < batches; ++b) {
for (auto c = 0; c < channel; ++c) {
if (axis == 1 && (rtensor.is_contiguous(MemoryFormat::ChannelsLast) ||
rtensor.is_contiguous(MemoryFormat::ChannelsLast3d))) {
for (auto b = 0; b < batches; ++b) {
for (auto e = 0; e < elements_per_channel; ++e) {
auto i = b * channel * elements_per_channel +
c * elements_per_channel + e;
qdata[i] = quantize_val_float_qparams<scalar_t>(
scales_data[c], zero_points_data[c], rdata[i]);
for (auto c = 0; c < channel; ++c) {
auto i = b * channel * elements_per_channel + e * channel + c;
qdata[i] = quantize_val_float_qparams<scalar_t>(
scales_data[c], zero_points_data[c], rdata[i]);
}
}
}
} else {
for (auto b = 0; b < batches; ++b) {
for (auto c = 0; c < channel; ++c) {
for (auto e = 0; e < elements_per_channel; ++e) {
auto i = b * channel * elements_per_channel +
c * elements_per_channel + e;
qdata[i] = quantize_val_float_qparams<scalar_t>(
scales_data[c], zero_points_data[c], rdata[i]);
}
}
}
}
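For context, here is a minimal sketch of the round trip the updated per-tensor kernels are meant to preserve. It loosely follows the new test_qtensor_quant_dequant case further below and assumes a build that includes this change; the shapes and quantization parameters are just illustrative.

import torch

scale, zero_point = 0.02, 2
r = torch.rand(3, 2, 4, 5, dtype=torch.float) * 4       # values in [0, 4)
r = r.contiguous(memory_format=torch.channels_last)     # NHWC input
qr = torch.quantize_per_tensor(r, scale, zero_point, torch.quint8)
rqr = qr.dequantize()
# With this change the quantized tensor is allocated in the input's suggested
# memory format, so the round trip keeps the channels_last layout.
print(rqr.is_contiguous(memory_format=torch.channels_last))  # expected: True
print((r - rqr).abs().max() <= scale)                        # rounding error stays within one step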
28 changes: 20 additions & 8 deletions aten/src/ATen/quantized/Quantizer.cpp
@@ -120,18 +120,24 @@ Tensor PerTensorAffineQuantizer::quantize(Tensor rtensor) {
// quantizer that can be reused, so I'm using intrusive_from_this here
Tensor qtensor = new_qtensor(
rtensor.sizes(),
rtensor.options().dtype(scalar_type_),
rtensor.options()
.dtype(scalar_type_)
.memory_format(rtensor.suggest_memory_format()),
intrusive_from_this());

rtensor = rtensor.contiguous();
rtensor = rtensor.contiguous(rtensor.suggest_memory_format());
native::quantize_tensor_per_tensor_affine(
rtensor, qtensor, scale_, zero_point_);
return qtensor;
}

Tensor PerTensorAffineQuantizer::dequantize(Tensor qtensor) {
Tensor rtensor = at::empty(qtensor.sizes(), qtensor.options().dtype(at::kFloat));
qtensor = qtensor.contiguous();
Tensor rtensor = at::empty(
qtensor.sizes(),
qtensor.options()
.dtype(at::kFloat)
.memory_format(qtensor.suggest_memory_format()));
qtensor = qtensor.contiguous(qtensor.suggest_memory_format());
native::dequantize_tensor_per_tensor_affine(
qtensor, rtensor, scale_, zero_point_);
return rtensor;
@@ -142,17 +148,23 @@ Tensor PerChannelAffineQuantizer::quantize(Tensor rtensor) {
// quantizer that can be reused, so I'm using intrusive_from_this here
Tensor qtensor = new_qtensor(
rtensor.sizes(),
rtensor.options().dtype(scalar_type_),
rtensor.options()
.dtype(scalar_type_)
.memory_format(rtensor.suggest_memory_format()),
intrusive_from_this());
rtensor = rtensor.contiguous();
rtensor = rtensor.contiguous(rtensor.suggest_memory_format());
native::quantize_tensor_per_channel_affine(
rtensor, qtensor, scales_, zero_points_, axis_);
return qtensor;
}

Tensor PerChannelAffineQuantizer::dequantize(Tensor qtensor) {
Tensor rtensor = at::empty(qtensor.sizes(), qtensor.options().dtype(at::kFloat));
qtensor = qtensor.contiguous();
Tensor rtensor = at::empty(
qtensor.sizes(),
qtensor.options()
.dtype(at::kFloat)
.memory_format(qtensor.suggest_memory_format()));
qtensor = qtensor.contiguous(qtensor.suggest_memory_format());
native::dequantize_tensor_per_channel_affine(
qtensor, rtensor, scales_, zero_points_, axis_);
return rtensor;
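Similarly, a small sketch of per-channel quantization on a channels_last tensor with axis = 1, the case the new channels_last kernel branch handles. Again this assumes the change is present, and the values are only illustrative.

import torch

r = torch.rand(3, 2, 4, 5, dtype=torch.float).contiguous(memory_format=torch.channels_last)
scales = torch.tensor([0.2, 0.03], dtype=torch.double)
zero_points = torch.tensor([5, 10], dtype=torch.long)
qr = torch.quantize_per_channel(r, scales, zero_points, 1, torch.quint8)
rqr = qr.dequantize()
# The dequantized output is allocated with the input's memory format as well.
print(rqr.is_contiguous(memory_format=torch.channels_last))  # expected: True
print((r - rqr).abs().max())  # bounded by half the largest per-channel step (0.2)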
2 changes: 1 addition & 1 deletion test/onnx/test_pytorch_onnx_caffe2_quantized.py
@@ -277,7 +277,7 @@ def forward(self, x):
x = self.act1(x)
x = self.conv2(x)
x = self.dequant(x)
x = x.view(-1, 72).contiguous()
x = x.reshape(-1, 72).contiguous()
x = self.fc(x)
return x

79 changes: 78 additions & 1 deletion test/quantization/test_quantized_tensor.py
@@ -132,7 +132,16 @@ def test_qtensor_quant_dequant(self):
scale = 0.02
zero_point = 2
for device in get_supported_device_types():
r = torch.rand(3, 2, dtype=torch.float, device=device) * 4 - 2
r = torch.rand(3, 2, 4, 5, dtype=torch.float, device=device) * 4 - 2
for memory_format in [torch.contiguous_format, torch.channels_last]:
r = r.contiguous(memory_format=memory_format)
for dtype in [torch.qint8, torch.quint8, torch.qint32]:
qr = torch.quantize_per_tensor(r, scale, zero_point, dtype)
rqr = qr.dequantize()
self.assertTrue(np.allclose(r.cpu().numpy(), rqr.cpu().numpy(), atol=2 / scale))
# Also check 5D tensors work.
for device in get_supported_device_types():
r = torch.rand(3, 2, 4, 5, 6, dtype=torch.float, device=device) * 4 - 2
for dtype in [torch.qint8, torch.quint8, torch.qint32]:
qr = torch.quantize_per_tensor(r, scale, zero_point, dtype)
rqr = qr.dequantize()
@@ -217,6 +226,35 @@ def test_qtensor_dtypes(self):
rqr = qr.dequantize()
self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / scale))

def _test_quantize_per_channel(self, r, scales, zero_points, axis, float_params):

def _quantize_per_channel_ref_nd(data, scales, zero_points, float_params):
dims = data.size()
data = data.view(-1, dims[axis], np.prod(dims[axis + 1:]))
res = torch.empty_like(data)
quant_min, quant_max = 0, 255
for i in range(res.size()[0]):
for j in range(res.size()[1]):
for k in range(res.size()[2]):
if float_params:
inv_scale = 1.0 / scales[j]
res[i][j][k] = np.clip(
np.round(data[i][j][k] * inv_scale + zero_points[j]), quant_min, quant_max)
else:
res[i][j][k] = np.clip(
np.round(data[i][j][k] / scales[j]) + zero_points[j], quant_min, quant_max)
res = res.view(*dims)
return res

contig_format = torch.channels_last if r.ndim == 4 else torch.channels_last_3d
for memory_format in [torch.contiguous_format, contig_format]:
ref_res = _quantize_per_channel_ref_nd(r, scales, zero_points, float_params)
r_contig = r.contiguous(memory_format=memory_format)
qr = torch.quantize_per_channel(r_contig, scales, zero_points, axis, torch.quint8)
rqr = qr.dequantize()
self.assertTrue(np.allclose(qr.int_repr(), ref_res))
self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / np.min(scales.numpy())))

def test_qtensor_quantize_per_channel(self):
r = torch.rand(3, 2, dtype=torch.float) * 4 - 2
scales = torch.tensor([0.2, 0.03], dtype=torch.double)
@@ -235,6 +273,26 @@ def quantize_c(data, scales, zero_points):
self.assertTrue(np.allclose(qr.int_repr(), quantize_c(r, scales, zero_points)))
self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / np.min(scales.numpy())))

# Check 4D tensor with 2 different memory formats.
r = torch.rand(3, 2, 4, 5, dtype=torch.float) * 4 - 2
scales = torch.tensor([0.2, 0.03], dtype=torch.double)
zero_points = torch.tensor([5, 10], dtype=torch.long)
self._test_quantize_per_channel(r, scales, zero_points, 1, False)

scales = torch.tensor([0.2, 0.03, 0.5], dtype=torch.double)
zero_points = torch.tensor([5, 10, 7], dtype=torch.long)
self._test_quantize_per_channel(r, scales, zero_points, 0, False)

# Check 5D tensor.
r = torch.rand(3, 2, 4, 5, 7, dtype=torch.float) * 4 - 2
scales = torch.tensor([0.2, 0.03], dtype=torch.double)
zero_points = torch.tensor([5, 10], dtype=torch.long)
self._test_quantize_per_channel(r, scales, zero_points, 1, False)

scales = torch.tensor([0.2, 0.03, 0.5], dtype=torch.double)
zero_points = torch.tensor([5, 10, 7], dtype=torch.long)
self._test_quantize_per_channel(r, scales, zero_points, 0, False)
Comment on lines +276 to +294
Contributor: nit: these can be in a loop as well, with r, scale, zero_points, axis being configurable
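For illustration, one way the suggested loop could look. This is only a rough sketch of the reviewer's idea, not part of the PR; the local cases table is a hypothetical stand-in holding the same shapes, scales, zero points, and axes as the calls above.

cases = [
    ((3, 2, 4, 5), [0.2, 0.03], [5, 10], 1),
    ((3, 2, 4, 5), [0.2, 0.03, 0.5], [5, 10, 7], 0),
    ((3, 2, 4, 5, 7), [0.2, 0.03], [5, 10], 1),
    ((3, 2, 4, 5, 7), [0.2, 0.03, 0.5], [5, 10, 7], 0),
]
for shape, scales, zero_points, axis in cases:
    r = torch.rand(*shape, dtype=torch.float) * 4 - 2
    self._test_quantize_per_channel(
        r,
        torch.tensor(scales, dtype=torch.double),
        torch.tensor(zero_points, dtype=torch.long),
        axis,
        False)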


def test_quantize_per_channel_float_qparams(self):
r = torch.rand(3, 2, dtype=torch.float) * 4
scales = torch.tensor([0.2, 0.03], dtype=torch.float)
@@ -257,6 +315,25 @@ def quantize_ref(data, scales, zero_points):
self.assertTrue(np.allclose(qr.int_repr(), ref))
self.assertTrue(np.allclose(r.numpy(), dequant_tensor.numpy(), atol=1))

# Check 4D tensor with 2 different memory formats.
Contributor: same here, I think maybe you can also merge this test with previous test
Contributor Author: Introduces unrelated changes. We should merge with previous one in a separate PR if we want to do that.
Contributor: sure, sounds good

r = torch.rand(3, 2, 4, 5, dtype=torch.float) * 4
scales = torch.tensor([0.2, 0.03], dtype=torch.float)
zero_points = torch.tensor([0.1, 0.2], dtype=torch.float)
self._test_quantize_per_channel(r, scales, zero_points, 1, True)

scales = torch.tensor([0.2, 0.03, 0.5], dtype=torch.float)
zero_points = torch.tensor([0.1, 0.2, 1.], dtype=torch.float)
self._test_quantize_per_channel(r, scales, zero_points, 0, True)

# Check 5D tensor.
r = torch.rand(3, 2, 4, 5, 7, dtype=torch.float) * 4 - 2
scales = torch.tensor([0.2, 0.03], dtype=torch.float)
zero_points = torch.tensor([0.1, 0.2], dtype=torch.float)
self._test_quantize_per_channel(r, scales, zero_points, 1, True)

scales = torch.tensor([0.2, 0.03, 0.5], dtype=torch.float)
zero_points = torch.tensor([0.1, 0.2, 1.], dtype=torch.float)
self._test_quantize_per_channel(r, scales, zero_points, 0, True)

def test_qtensor_permute(self):
scale = 0.02