Skip to content

Commit cca923c

Browse files
jerryzh168 authored and
facebook-github-bot committed
Add dequantize_linear for JIT pass (#20107)
Summary: Pull Request resolved: #20107 att Reviewed By: nishantpdce Differential Revision: D15202187 fbshipit-source-id: 7d6274a67fcca695c0425587f35046fecbc2ccdc
1 parent cc02a1a commit cca923c

File tree

10 files changed

+86
-7
lines changed

10 files changed

+86
-7
lines changed

aten/src/ATen/Dispatch.h

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include <ATen/core/Tensor.h>
4+
#include <c10/macros/Macros.h>
45
#include <c10/util/Half.h>
56
#include <c10/util/Exception.h>
67
#include <ATen/core/DeprecatedTypeProperties.h>
@@ -11,6 +12,13 @@
1112
return __VA_ARGS__(); \
1213
}
1314

15+
// Switch-case arm for one quantized dtype: binds the quantized scalar type as
// `scalar_t` AND its underlying integral representation as `underlying_t`
// (e.g. quint8 -> uint8_t) before invoking the lambda body, so callers of
// AT_DISPATCH_QINT_TYPES can access the raw integer data directly.
#define AT_QINT_PRIVATE_CASE_TYPE(enum_type, type, underlying_type, ...) \
16+
case enum_type: { \
17+
using scalar_t C10_UNUSED = type; \
18+
using underlying_t C10_UNUSED = underlying_type; \
19+
return __VA_ARGS__(); \
20+
}
21+
1422
namespace detail {
1523

1624
template <at::ScalarType N>
@@ -211,14 +219,14 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {}
211219
#define AT_DISPATCH_QINT_TYPES(TYPE, NAME, ...) \
212220
[&] { \
213221
switch (TYPE) { \
214-
AT_PRIVATE_CASE_TYPE( \
215-
at::ScalarType::QInt8, qint8, __VA_ARGS__) \
216-
AT_PRIVATE_CASE_TYPE( \
217-
at::ScalarType::QUInt8, quint8, __VA_ARGS__) \
218-
AT_PRIVATE_CASE_TYPE( \
219-
at::ScalarType::QInt32, qint32, __VA_ARGS__) \
222+
AT_QINT_PRIVATE_CASE_TYPE( \
223+
at::ScalarType::QInt8, qint8, int8_t, __VA_ARGS__) \
224+
AT_QINT_PRIVATE_CASE_TYPE( \
225+
at::ScalarType::QUInt8, quint8, uint8_t, __VA_ARGS__) \
226+
AT_QINT_PRIVATE_CASE_TYPE( \
227+
at::ScalarType::QInt32, qint32, int, __VA_ARGS__) \
220228
default: \
221-
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
229+
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
222230
} \
223231
}()
224232

aten/src/ATen/core/Tensor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,7 @@ class CAFFE2_API Tensor {
583583
Tensor to_mkldnn() const;
584584
Tensor quantize_linear(double scale, int64_t zero_point, ScalarType dtype) const;
585585
Tensor dequantize() const;
586+
Tensor dequantize_linear(double scale, int64_t zero_point, ScalarType dtype) const;
586587
Scalar q_scale() const;
587588
Scalar q_zero_point() const;
588589
Tensor int_repr() const;

aten/src/ATen/core/TensorMethods.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,9 @@ inline Tensor Tensor::quantize_linear(double scale, int64_t zero_point, ScalarTy
804804
inline Tensor Tensor::dequantize() const {
805805
return dispatch_type().dequantize(*this);
806806
}
807+
inline Tensor Tensor::dequantize_linear(double scale, int64_t zero_point, ScalarType dtype) const {
  // Delegate to the type-dispatched implementation for this tensor's backend.
  auto&& type = dispatch_type();
  return type.dequantize_linear(*this, scale, zero_point, dtype);
}
807810
inline Scalar Tensor::q_scale() const {
808811
return dispatch_type().q_scale(*this);
809812
}

aten/src/ATen/core/Type.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,7 @@ struct CAFFE2_API Type {
393393
virtual Tensor to_mkldnn(const Tensor & self) const = 0;
394394
virtual Tensor quantize_linear(const Tensor & self, double scale, int64_t zero_point, ScalarType dtype) const = 0;
395395
virtual Tensor dequantize(const Tensor & self) const = 0;
396+
virtual Tensor dequantize_linear(const Tensor & self, double scale, int64_t zero_point, ScalarType dtype) const = 0;
396397
virtual Scalar q_scale(const Tensor & self) const = 0;
397398
virtual Scalar q_zero_point(const Tensor & self) const = 0;
398399
virtual Tensor int_repr(const Tensor & self) const = 0;

aten/src/ATen/native/native_functions.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2551,6 +2551,11 @@
25512551
dispatch:
25522552
QuantizedCPU: dequantize_quant
25532553

2554+
- func: dequantize_linear(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor
2555+
variants: function, method
2556+
dispatch:
2557+
CPU: dequantize_linear_cpu
2558+
25542559
- func: q_scale(Tensor self) -> Scalar
25552560
variants: function, method
25562561
dispatch:

aten/src/ATen/native/quantized/QTensor.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,21 @@ Tensor dequantize_quant(const Tensor& self) {
1616
return get_qtensorimpl(self)->quantizer()->dequantize(self);
1717
}
1818

19+
// Dequantizes an integer CPU Tensor holding affine-quantized values:
//   f[i] = (q[i] - zero_point) * scale
// `self` must have an integral scalar type that maps to a quantized type
// (Byte/Char/Int -> QUInt8/QInt8/QInt32); `dtype` must be Float.
// Returns a newly allocated float Tensor of the same shape.
Tensor dequantize_linear_cpu(const Tensor& self, double scale, int64_t zero_point, ScalarType dtype) {
  AT_CHECK(isQIntType(toQIntType(self.scalar_type())),
           "Scalar type for quantized Tensor must have same underlying type as input.");
  AT_CHECK(dtype == ScalarType::Float, "ScalarType for target Tensor must be float.");
  Tensor f = at::empty(self.sizes(), self.options().dtype(dtype));
  AT_DISPATCH_QINT_TYPES(
      toQIntType(self.scalar_type()), "dequantize_linear_cpu", [&]() {
        const underlying_t* qdata = self.data<underlying_t>();
        float* fdata = f.data<float>();
        // numel() returns int64_t; an `int` counter would overflow (and mix
        // signedness/width in the comparison) for tensors with > 2^31
        // elements.  Hoist the call out of the loop as well.
        const int64_t numel = self.numel();
        for (int64_t i = 0; i < numel; ++i) {
          fdata[i] = (static_cast<float>(qdata[i]) - zero_point) * scale;
        }
      });
  return f;
}
33+
1934
Scalar q_scale_quant(const Tensor& self) {
2035
auto quantizer = get_qtensorimpl(self)->quantizer();
2136
AT_ASSERT(quantizer->qscheme() == kPerTensorAffine);

c10/core/ScalarType.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,36 @@ static inline bool isQIntType(ScalarType t) {
234234
return t == ScalarType:: QInt8 || t == ScalarType::QUInt8 || t == ScalarType::QInt32;
235235
}
236236

237+
// Maps a plain integral dtype to its quantized counterpart
// (Byte -> QUInt8, Char -> QInt8, Int -> QInt32); any other
// dtype is returned unchanged.
static inline ScalarType toQIntType(ScalarType t) {
  if (t == ScalarType::Byte) {
    return ScalarType::QUInt8;
  }
  if (t == ScalarType::Char) {
    return ScalarType::QInt8;
  }
  if (t == ScalarType::Int) {
    return ScalarType::QInt32;
  }
  return t;
}
249+
250+
// Inverse of toQIntType: maps a quantized dtype to the integral type of its
// underlying representation (QUInt8 -> Byte, QInt8 -> Char, QInt32 -> Int);
// any other dtype is returned unchanged.
static inline ScalarType toUnderlying(ScalarType t) {
  if (t == ScalarType::QUInt8) {
    return ScalarType::Byte;
  }
  if (t == ScalarType::QInt8) {
    return ScalarType::Char;
  }
  if (t == ScalarType::QInt32) {
    return ScalarType::Int;
  }
  return t;
}
262+
263+
// True iff `type` is the underlying integral representation of the
// quantized dtype `qtype` (e.g. Byte is the underlying type of QUInt8).
static inline bool isUnderlying(ScalarType type, ScalarType qtype) {
  return toUnderlying(qtype) == type;
}
266+
237267
static inline ScalarType promoteTypes(ScalarType a, ScalarType b) {
238268
// This is generated according to NumPy's promote_types
239269
constexpr auto u1 = ScalarType::Byte;

docs/source/tensors.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ view of a storage and defines numeric operations on it.
209209
.. automethod:: cumsum
210210
.. automethod:: data_ptr
211211
.. automethod:: dequantize
212+
.. automethod:: dequantize_linear
212213
.. automethod:: det
213214
.. automethod:: dense_dim
214215
.. automethod:: detach

test/test_torch.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2804,6 +2804,12 @@ def test_qtensor_dtypes(self):
28042804
rqr = qr.dequantize()
28052805
self.assertTrue(np.allclose(r.numpy(), rqr.numpy(), atol=2 / scale))
28062806

2807+
def test_qtensor_dequantize_linear(self):
    # Verify affine dequantization of a raw int8 tensor:
    #   f[i] = (q[i] - zero_point) * scale
    # The original test called dequantize_linear but asserted nothing,
    # so it could never fail on wrong values.
    t = torch.arange(-10, 10, dtype=torch.int8)
    scale = 3
    zero_point = 2
    qt = torch.dequantize_linear(t, scale, zero_point, torch.float)
    expected = (t.numpy().astype(np.float32) - zero_point) * scale
    self.assertTrue(np.allclose(qt.numpy(), expected))
2812+
28072813

28082814
@unittest.skipIf(torch.cuda.device_count() < 2, 'fewer than 2 GPUs detected')
28092815
def test_device_guard(self):

torch/_tensor_docs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3031,6 +3031,15 @@ def callable(a, b) -> number
30313031
See :func:`torch.det`
30323032
""")
30333033

3034+
# Docstring for Tensor.dequantize_linear.  The documented signature previously
# omitted the required `dtype` argument present in the native schema
# (dequantize_linear(Tensor self, float scale, int zero_point, ScalarType dtype)).
add_docstr_all('dequantize_linear',
               r"""
dequantize_linear(int_tensor, scale, zero_point, dtype) -> Tensor

Dequantizes an int Tensor that represents the underlying quantized data
using an affine quantization scheme with the given :attr:`scale` and
:attr:`zero_point`. :attr:`dtype` is the scalar type of the returned
Tensor (currently must be ``torch.float``).
""")
3042+
30343043
add_docstr_all('where',
30353044
r"""
30363045
where(condition, y) -> Tensor

0 commit comments

Comments
 (0)