38 changes: 3 additions & 35 deletions test/test_quantization.py
@@ -1,5 +1,3 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import unittest
import torch
import torch.nn as nn
@@ -11,7 +9,7 @@
QConfig_dynamic, default_weight_observer, dump_tensor,\
quantize, prepare, convert, prepare_qat, quantize_qat, fuse_modules, \
quantize_dynamic, default_qconfig, default_debug_qconfig, default_qat_qconfig, \
default_dynamic_qconfig, QuantWrapper, TensorObserver, MinMaxObserver, HistogramObserver
default_dynamic_qconfig, MinMaxObserver, TensorObserver, QuantWrapper

from common_utils import run_tests
from common_quantization import QuantizationTestCase, SingleLayerLinearModel, \
@@ -777,8 +775,8 @@ def test_minmax_observer(self, qdtype, qscheme, reduce_range):
self.assertEqual(qparams[1].item(), ref_zero_point)
self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5)

@given(obs=st.sampled_from((torch.quantization.default_observer()(), HistogramObserver(bins=10))))
def test_observer_scriptable(self, obs):
def test_observer_scriptable(self):
obs = torch.quantization.default_observer()()
scripted = torch.jit.script(obs)

x = torch.rand(3, 4)
@@ -829,35 +827,5 @@ def test_tensor_observer_scriptable(self, qdtype, qscheme):
loaded = torch.jit.load(buf)
self.assertTrue(torch.equal(obs.get_tensor_value()[0], loaded.get_tensor_value()[0]))

@given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)),
reduce_range=st.booleans())
def test_histogram_observer(self, qdtype, qscheme, reduce_range):
myobs = HistogramObserver(bins=10, dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range)
x = torch.tensor([1.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0])
y = torch.tensor([4.0, 5.0, 5.0, 6.0, 7.0, 8.0])
myobs(x)
myobs(y)
self.assertEqual(myobs.min_val, -1.5)
self.assertEqual(myobs.max_val, 8.5)
self.assertEqual(myobs.histogram, [0., 0., 1., 2., 1., 2., 3., 2., 1., 1.])
qparams = myobs.calculate_qparams()
if reduce_range:
if qscheme == torch.per_tensor_symmetric:
ref_scale = 0.066666 * 255 / 127
ref_zero_point = 0 if qdtype is torch.qint8 else 128
else:
ref_scale = 0.0333333 * 255 / 127
ref_zero_point = -64 if qdtype is torch.qint8 else 0
else:
if qscheme == torch.per_tensor_symmetric:
ref_scale = 0.066666
ref_zero_point = 0 if qdtype is torch.qint8 else 128
else:
ref_scale = 0.0333333
ref_zero_point = -128 if qdtype is torch.qint8 else 0
self.assertEqual(qparams[1].item(), ref_zero_point)
self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5)

if __name__ == '__main__':
run_tests()
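
Note: the simplified test above scripts only the default observer instead of sampling observers with hypothesis. A minimal sketch of the round trip that test_observer_scriptable exercises, assuming the torch.quantization APIs visible in this diff (default_observer returns a factory; calling it builds a MinMaxObserver):

    import io
    import torch
    import torch.quantization

    # default_observer() returns a partial; calling that constructs the observer.
    obs = torch.quantization.default_observer()()
    obs(torch.rand(3, 4))              # record min/max statistics on sample data
    scripted = torch.jit.script(obs)   # the observer must stay TorchScript-compatible

    # Round-trip through an in-memory buffer, as the test does.
    buf = io.BytesIO()
    torch.jit.save(scripted, buf)
    buf.seek(0)
    loaded = torch.jit.load(buf)
    scale, zero_point = loaded.calculate_qparams()
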
100 changes: 11 additions & 89 deletions torch/quantization/observer.py
@@ -1,16 +1,14 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import warnings
import torch
import torch.nn as nn
from abc import ABCMeta, abstractmethod
from functools import partial
import warnings

from torch._jit_internal import Optional, List
import torch
import torch.nn as nn


ABC = ABCMeta(str("ABC"), (object,), {}) # compatible with Python 2 *and* 3:

ABC = ABCMeta(str('ABC'), (object,), {}) # compatible with Python 2 *and* 3:

class ObserverBase(ABC, nn.Module):
r"""Observer base Module
@@ -22,9 +20,7 @@ class ObserverBase(ABC, nn.Module):
the collected statistics.
"""

def __init__(
self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False
):
def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False):
super(ObserverBase, self).__init__()
self.dtype = dtype
self.qscheme = qscheme
@@ -56,10 +52,8 @@ def _calculate_qparams(self, min_val, max_val):
"""

if max_val is None or min_val is None:
warnings.warn(
"must run observer before calling calculate_qparams.\
Returning default scale and zero point "
)
warnings.warn("must run observer before calling calculate_qparams.\
Returning default scale and zero point ")
return torch.tensor([1.0]), torch.tensor([0])

assert min_val <= max_val, "min {} should be less than max {}".format(
@@ -108,10 +102,7 @@ class MinMaxObserver(ObserverBase):
calculate_qparams will calculate scale and zero_point
"""

__annotations__ = {
"min_val": Optional[torch.Tensor],
"max_val": Optional[torch.Tensor],
}
__annotations__ = {'min_val' : Optional[torch.Tensor], 'max_val' : Optional[torch.Tensor]}

def __init__(self, **kwargs):
# For x86 quantized kernels, we need to ensure that the vpmaddubsw instruction
@@ -124,14 +115,8 @@ def __init__(self, **kwargs):
super(MinMaxObserver, self).__init__(**kwargs)
self.min_val = None
self.max_val = None
if (
self.qscheme == torch.per_tensor_symmetric
and self.reduce_range
and self.dtype == torch.quint8
):
raise NotImplementedError(
"Cannot reduce range for symmetric quantization for quint8"
)
if self.qscheme == torch.per_tensor_symmetric and self.reduce_range and self.dtype == torch.quint8:
raise NotImplementedError("Cannot reduce range for symmetric quantization for quint8")

def forward(self, x):
min_val = self.min_val
@@ -152,69 +137,7 @@ def calculate_qparams(self):

@torch.jit.export
def extra_repr(self):
return "min_val={}, max_val={}".format(self.min_val, self.max_val)


class HistogramObserver(ObserverBase):
r"""
The module records the running histogram of tensor values along with
min/max values. calculate_qparams will calculate scale and zero_point
"""

__annotations__ = {
"min_val": Optional[torch.Tensor],
"max_val": Optional[torch.Tensor],
"histogram": Optional[torch.Tensor],
}

def __init__(self, bins=2048, **kwargs):
super(HistogramObserver, self).__init__(**kwargs)
self.bins = bins
self.histogram = None
self.min_val = None
self.max_val = None

def forward(self, x):
min_val = self.min_val
max_val = self.max_val
histogram = self.histogram
if min_val is None or max_val is None or histogram is None:
min_val = torch.min(x)
max_val = torch.max(x)
range = max_val - min_val
self.min_val = min_val - 0.5 * range
self.max_val = max_val + 0.5 * range
self.histogram = torch.histc(
x, self.bins, min=min_val - 0.5 * range, max=max_val + 0.5 * range
)
else:
if min_val < torch.min(x) or max_val > torch.max(x):
warnings.warn("Incoming data is outside the min_val/max_val range.")
new_histogram = torch.histc(
x, self.bins, min=min_val, max=max_val
)
self.histogram = new_histogram + histogram

@torch.jit.export
def calculate_qparams(self):
min_val = self.min_val
max_val = self.max_val
histogram = self.histogram

if min_val is None or max_val is None or histogram is None:
return self._calculate_qparams(None, None)
else:
histogram_mask = torch.gt(histogram, 0).to(torch.int8)
c = torch.cumsum(histogram_mask, 0)
# Last non-zero bin
max_bin = torch.argmax(histogram_mask)
# Only one entry is non-zero, find it.
min_bin = torch.argmax(torch.eq(c, 1))
bin_width = (max_val - min_val) / histogram.size()[0]
new_min = min_val + min_bin * bin_width
new_max = min_val + (max_bin + 1) * bin_width
return self._calculate_qparams(new_min, new_max)

return 'min_val={}, max_val={}'.format(self.min_val, self.max_val)


class TensorObserver(ObserverBase):
@@ -245,7 +168,6 @@ def get_tensor_value(self):
def observer(observer_cls, **kwargs):
return partial(observer_cls, **kwargs)


def default_observer(**kwargs):
# Restrict activations to be in the range (0,127)
kwargs.setdefault("reduce_range", True)
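
For reference, the HistogramObserver deleted above narrowed the observed [min_val, max_val] range to the span of non-empty histogram bins before handing it to _calculate_qparams. A rough, hypothetical re-sketch of that trimming step in plain torch (the deleted code located the end bins with a cumsum/argmax trick; this version uses torch.nonzero to the same effect and is an illustration of the removed logic, not a drop-in replacement):

    import torch

    def occupied_range(histogram, min_val, max_val):
        # Clip [min_val, max_val] to the first and last non-empty bins.
        nonzero = torch.nonzero(histogram > 0).flatten()
        first_bin, last_bin = nonzero[0], nonzero[-1]
        bin_width = (max_val - min_val) / histogram.numel()
        new_min = min_val + first_bin * bin_width
        new_max = min_val + (last_bin + 1) * bin_width
        return new_min, new_max

With the inputs from the removed test_histogram_observer (histogram [0., 0., 1., 2., 1., 2., 3., 2., 1., 1.] over [-1.5, 8.5], so bin_width = 1.0), this trims the range to [0.5, 8.5], consistent with the reference scales asserted in that test.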