test/quantization/test_workflow_module.py (22 changes: 19 additions & 3 deletions)
@@ -281,11 +281,14 @@ def test_per_tensor_dynamic_quant_observers(self, X, reduce_range):
self.assertEqual(ref[0], qparams[0])
self.assertEqual(ref[1], qparams[1])


@given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
qscheme=st.sampled_from((torch.per_channel_affine, torch.per_channel_symmetric)),
qscheme=st.sampled_from((torch.per_channel_affine, torch.per_channel_symmetric, torch.per_channel_affine_float_qparams)),
ch_axis=st.sampled_from((0, 1, 2, 3)), reduce_range=st.booleans())
def test_per_channel_observers(self, qdtype, qscheme, ch_axis, reduce_range):
# reduce_range is not exercised for float qparams, and cannot be true for symmetric quantization with uint8
if qscheme == torch.per_channel_affine_float_qparams:
reduce_range = False
if qdtype == torch.quint8 and qscheme == torch.per_channel_symmetric:
reduce_range = False
ObserverList = [PerChannelMinMaxObserver(reduce_range=reduce_range,
@@ -338,13 +341,22 @@ def test_per_channel_observers(self, qdtype, qscheme, ch_axis, reduce_range):
[-26, -128],
[-35, -58],
]
per_channel_affine_float_qparams_ref_scales = [
[0.0196, 0.0471],
[0.0353, 0.0196],
[0.0392, 0.0235],
[0.0431, 0.0431],
]
per_channel_affine_quint8_zp = [[0, 85], [113, 0], [102, 0], [93, 70]]

self.assertEqual(myobs.min_vals, ref_min_vals[ch_axis])
self.assertEqual(myobs.max_vals, ref_max_vals[ch_axis])
if qscheme == torch.per_channel_symmetric:
ref_scales = per_channel_symmetric_ref_scales[ch_axis]
ref_zero_points = [0, 0] if qdtype is torch.qint8 else [128, 128]
elif qscheme == torch.per_channel_affine_float_qparams:
ref_scales = per_channel_affine_float_qparams_ref_scales[ch_axis]
ref_zero_points = [-1 * ref_min_vals[ch_axis][i] / ref_scales[i] for i in range(len(ref_scales))]
else:
ref_scales = per_channel_affine_ref_scales[ch_axis]
ref_zero_points = (
@@ -356,9 +368,12 @@ def test_per_channel_observers(self, qdtype, qscheme, ch_axis, reduce_range):
if reduce_range:
ref_scales = [s * 255 / 127 for s in ref_scales]
ref_zero_points = [math.floor(z / 2) for z in ref_zero_points]
self.assertTrue(torch.allclose(qparams[0], torch.tensor(ref_scales, dtype=qparams[0].dtype), atol=0.0001))
if qscheme == torch.per_channel_affine_float_qparams:
self.assertTrue(torch.allclose(qparams[1], torch.tensor(ref_zero_points, dtype=qparams[1].dtype), atol=1))
else:
self.assertTrue(torch.allclose(qparams[1], torch.tensor(ref_zero_points, dtype=qparams[1].dtype)))

self.assertTrue(torch.allclose(qparams[0], torch.tensor(ref_scales, dtype=qparams[0].dtype)))
self.assertTrue(torch.allclose(qparams[1], torch.tensor(ref_zero_points, dtype=qparams[1].dtype)))

# Test for serializability
state_dict = myobs.state_dict()
@@ -375,6 +390,7 @@ def test_per_channel_observers(self, qdtype, qscheme, ch_axis, reduce_range):
self.assertEqual(myobs.max_vals, loaded_obs.max_vals)
self.assertEqual(myobs.calculate_qparams(), loaded_obs.calculate_qparams())


def test_observer_scriptable(self):
obs_list = [MinMaxObserver(), MovingAverageMinMaxObserver(), MinMaxDynamicQuantObserver()]
for obs in obs_list:
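Note on the per_channel_affine_float_qparams branch of the test above: the reference zero points are not hard-coded; they are derived from the reference minimums and scales as zero_point = -min / scale and stay floating point. A minimal sketch of that relationship, using made-up per-channel values (the actual ref_min_vals / ref_scales fixtures from the test are not reproduced here):

import torch

# Hypothetical per-channel statistics, for illustration only.
ref_min_vals = torch.tensor([-1.0, -2.0])
ref_scales = torch.tensor([0.0196, 0.0471])

# Float-qparams zero point, mirroring the list comprehension in the test:
# zero_point[i] = -min[i] / scale[i]
ref_zero_points = -1 * ref_min_vals / ref_scales
print(ref_zero_points)  # approximately tensor([51.02, 42.46])

Because these zero points are floats produced by a division, the test compares them with atol=1 instead of requiring an exact match.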
torch/quantization/observer.py (39 changes: 25 additions & 14 deletions)
@@ -108,9 +108,10 @@ def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
torch.per_tensor_symmetric,
torch.per_channel_affine,
torch.per_channel_symmetric,
torch.per_channel_affine_float_qparams,
), "Default Observer only works for per_tensor_affine, \
per_tensor_symmetric, per_channel_affine and \
per_channel_symmetric quantization scheme"
per_tensor_symmetric, per_channel_affine, \
per_channel_symmetric and per_channel_affine_float_qparams quantization schemes"
assert self.dtype in (
torch.qint8,
torch.quint8,
@@ -213,39 +214,49 @@ def _calculate_qparams(self, min_val, max_val):
)

qmin, qmax = self._calculate_qmin_qmax()
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))

min_val = torch.min(min_val, torch.zeros_like(min_val))
max_val = torch.max(max_val, torch.zeros_like(max_val))

scale = torch.ones(min_val.size(), dtype=torch.float32)
zero_point = torch.zeros(min_val.size(), dtype=torch.int64)
device = 'cuda' if min_val.is_cuda else 'cpu'
scale = torch.ones(min_val_neg.size(), dtype=torch.float32)
zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64)
device = 'cuda' if min_val_neg.is_cuda else 'cpu'

if self.qscheme == torch.per_tensor_symmetric or self.qscheme == torch.per_channel_symmetric:
max_val = torch.max(-min_val, max_val)
scale = max_val / (float(qmax - qmin) / 2)
max_val_pos = torch.max(-min_val_neg, max_val_pos)
scale = max_val_pos / (float(qmax - qmin) / 2)
scale = torch.max(scale, torch.tensor(self.eps, device=device, dtype=scale.dtype))
if self.dtype == torch.quint8:
if self.is_dynamic_qrange:
# When dynamic quantization range is used, down-rounded midpoint of the range is chosen.
zero_point = zero_point.new_full(zero_point.size(), (qmin + qmax) // 2)
else:
zero_point = zero_point.new_full(zero_point.size(), 128)
else:
elif self.qscheme == torch.per_channel_affine_float_qparams:
scale = (max_val - min_val) / float(qmax - qmin)
scale = torch.where(scale > self.eps, scale, torch.ones_like(scale))
# We use the quantize function
# Xq = Round(Xf * inv_scale + zero_point);
# setting zero_point to (-1 * min * inv_scale) gives
# Xq = Round((Xf - min) * inv_scale)
zero_point = -1 * min_val / scale
else:
scale = (max_val_pos - min_val_neg) / float(qmax - qmin)
scale = torch.max(scale, torch.tensor(self.eps, device=device, dtype=scale.dtype))
zero_point = qmin - torch.round(min_val / scale)
zero_point = qmin - torch.round(min_val_neg / scale)
zero_point = torch.max(zero_point, torch.tensor(qmin, device=device, dtype=zero_point.dtype))
zero_point = torch.min(zero_point, torch.tensor(qmax, device=device, dtype=zero_point.dtype))

# For scalar values, cast them to Tensors of size 1 to keep the shape
# consistent with default values in FakeQuantize.
if len(scale.shape) == 0:
# TODO: switch to scale.item() after adding JIT support
scale = torch.tensor([float(scale)], dtype=scale.dtype)
scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device)
if len(zero_point.shape) == 0:
# TODO: switch to zero_point.item() after adding JIT support
zero_point = torch.tensor([int(zero_point)], dtype=zero_point.dtype)
zero_point = torch.tensor([int(zero_point)], dtype=zero_point.dtype, device=device)
if self.qscheme == torch.per_channel_affine_float_qparams:
zero_point = torch.tensor([float(zero_point)], dtype=zero_point.dtype, device=device)


return scale, zero_point

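For reference, the per_channel_affine_float_qparams branch of _calculate_qparams maps the observed range [min, max] linearly onto [qmin, qmax] and keeps the zero point as a float. The following is an illustrative, self-contained sketch of that arithmetic with plain tensors, not the observer API; the function name float_qparams_sketch, the sample values, and the quint8 range and eps defaults are assumptions for the example:

import torch

def float_qparams_sketch(min_val, max_val, qmin=0, qmax=255, eps=1e-7):
    # Scale spans the full observed range; guard against a zero-width range.
    scale = (max_val - min_val) / float(qmax - qmin)
    scale = torch.where(scale > eps, scale, torch.ones_like(scale))
    # Float zero point: zp = -min / scale, so Xq = Round((Xf - min) / scale).
    zero_point = -1 * min_val / scale
    return scale, zero_point

x = torch.tensor([-1.0, 0.0, 2.5, 4.0])
scale, zp = float_qparams_sketch(x.min(), x.max())
xq = torch.clamp(torch.round(x / scale + zp), 0, 255)   # quantize
x_hat = (xq - zp) * scale                                # dequantize, close to x up to rounding

With the sample values above, the scale works out to 5/255 (about 0.0196) and the zero point lands near 51, so -1.0 quantizes to 0 and 4.0 to 255, which is the full quint8 range.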