Skip to content

Commit 1563fdb

Browse files
Haixin Liu authored and facebook-github-bot committed
Add histogram observer (#23959)
Summary: Pull Request resolved: #23959 Add histogram observer that records the running histogram of tensor values along with min/max values. ghstack-source-id: 90076996 Test Plan: Added a test test_histogram_observer buck test mode/dev caffe2/test:quantization -- 'test_histogram_observer' buck test mode/dev caffe2/test:quantization -- 'test_observer_scriptable' Differential Revision: D16692835 fbshipit-source-id: 0f047d3349cb9770fad4a2b6cb346c51d9e99cd4
1 parent c6b75ce commit 1563fdb

File tree

2 files changed

+124
-14
lines changed

2 files changed

+124
-14
lines changed

test/test_quantization.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import absolute_import, division, print_function, unicode_literals
2+
13
import unittest
24
import torch
35
import torch.nn as nn
@@ -9,7 +11,7 @@
911
QConfig_dynamic, default_weight_observer, dump_tensor,\
1012
quantize, prepare, convert, prepare_qat, quantize_qat, fuse_modules, \
1113
quantize_dynamic, default_qconfig, default_debug_qconfig, default_qat_qconfig, \
12-
default_dynamic_qconfig, MinMaxObserver, TensorObserver, QuantWrapper
14+
default_dynamic_qconfig, QuantWrapper, TensorObserver, MinMaxObserver, HistogramObserver
1315

1416
from common_utils import run_tests
1517
from common_quantization import QuantizationTestCase, SingleLayerLinearModel, \
@@ -775,8 +777,8 @@ def test_minmax_observer(self, qdtype, qscheme, reduce_range):
775777
self.assertEqual(qparams[1].item(), ref_zero_point)
776778
self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5)
777779

778-
def test_observer_scriptable(self):
779-
obs = torch.quantization.default_observer()()
780+
@given(obs=st.sampled_from((torch.quantization.default_observer()(), HistogramObserver(bins=10))))
781+
def test_observer_scriptable(self, obs):
780782
scripted = torch.jit.script(obs)
781783

782784
x = torch.rand(3, 4)
@@ -827,5 +829,35 @@ def test_tensor_observer_scriptable(self, qdtype, qscheme):
827829
loaded = torch.jit.load(buf)
828830
self.assertTrue(torch.equal(obs.get_tensor_value()[0], loaded.get_tensor_value()[0]))
829831

832+
@given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
       qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)),
       reduce_range=st.booleans())
def test_histogram_observer(self, qdtype, qscheme, reduce_range):
    """Check the range, histogram and qparams recorded by HistogramObserver.

    Two small batches are fed to a 10-bin observer; the recorded range,
    the accumulated per-bin counts and the derived (scale, zero_point)
    are compared against hand-computed reference values for every
    combination of dtype, qscheme and reduce_range.
    """
    observer_under_test = HistogramObserver(
        bins=10, dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range
    )
    first_batch = torch.tensor([1.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    second_batch = torch.tensor([4.0, 5.0, 5.0, 6.0, 7.0, 8.0])
    for batch in (first_batch, second_batch):
        observer_under_test(batch)

    # Range is fixed by the first batch; the histogram accumulates both.
    self.assertEqual(observer_under_test.min_val, -1.5)
    self.assertEqual(observer_under_test.max_val, 8.5)
    self.assertEqual(
        observer_under_test.histogram,
        [0., 0., 1., 2., 1., 2., 3., 2., 1., 1.],
    )

    qparams = observer_under_test.calculate_qparams()

    is_symmetric = qscheme == torch.per_tensor_symmetric
    is_signed = qdtype is torch.qint8
    if is_symmetric:
        ref_scale = 0.066666
        ref_zero_point = 0 if is_signed else 128
    else:
        ref_scale = 0.0333333
        if is_signed:
            ref_zero_point = -64 if reduce_range else -128
        else:
            ref_zero_point = 0
    if reduce_range:
        # Reduced range uses 127 instead of 255 quantization steps.
        ref_scale = ref_scale * 255 / 127

    self.assertEqual(qparams[1].item(), ref_zero_point)
    self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5)
861+
830862
if __name__ == '__main__':
831863
run_tests()

torch/quantization/observer.py

Lines changed: 89 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
from __future__ import absolute_import, division, print_function, unicode_literals
22

3-
import torch
4-
import torch.nn as nn
3+
import warnings
54
from abc import ABCMeta, abstractmethod
65
from functools import partial
7-
import warnings
86

97
from torch._jit_internal import Optional, List
8+
import torch
9+
import torch.nn as nn
10+
11+
12+
ABC = ABCMeta(str("ABC"), (object,), {}) # compatible with Python 2 *and* 3:
1013

11-
ABC = ABCMeta(str('ABC'), (object,), {}) # compatible with Python 2 *and* 3:
1214

1315
class ObserverBase(ABC, nn.Module):
1416
r"""Observer base Module
@@ -20,7 +22,9 @@ class ObserverBase(ABC, nn.Module):
2022
the collected statistics.
2123
"""
2224

23-
def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False):
25+
def __init__(
26+
self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False
27+
):
2428
super(ObserverBase, self).__init__()
2529
self.dtype = dtype
2630
self.qscheme = qscheme
@@ -52,8 +56,10 @@ def _calculate_qparams(self, min_val, max_val):
5256
"""
5357

5458
if max_val is None or min_val is None:
55-
warnings.warn("must run observer before calling calculate_qparams.\
56-
Returning default scale and zero point ")
59+
warnings.warn(
60+
"must run observer before calling calculate_qparams.\
61+
Returning default scale and zero point "
62+
)
5763
return torch.tensor([1.0]), torch.tensor([0])
5864

5965
assert min_val <= max_val, "min {} should be less than max {}".format(
@@ -102,7 +108,10 @@ class MinMaxObserver(ObserverBase):
102108
calculate_qparams will calculate scale and zero_point
103109
"""
104110

105-
__annotations__ = {'min_val' : Optional[torch.Tensor], 'max_val' : Optional[torch.Tensor]}
111+
__annotations__ = {
112+
"min_val": Optional[torch.Tensor],
113+
"max_val": Optional[torch.Tensor],
114+
}
106115

107116
def __init__(self, **kwargs):
108117
# For x86 quantized kernels, we need to ensure that the vpmaddubsw instruction
@@ -115,8 +124,14 @@ def __init__(self, **kwargs):
115124
super(MinMaxObserver, self).__init__(**kwargs)
116125
self.min_val = None
117126
self.max_val = None
118-
if self.qscheme == torch.per_tensor_symmetric and self.reduce_range and self.dtype == torch.quint8:
119-
raise NotImplementedError("Cannot reduce range for symmetric quantization for quint8")
127+
if (
128+
self.qscheme == torch.per_tensor_symmetric
129+
and self.reduce_range
130+
and self.dtype == torch.quint8
131+
):
132+
raise NotImplementedError(
133+
"Cannot reduce range for symmetric quantization for quint8"
134+
)
120135

121136
def forward(self, x):
122137
min_val = self.min_val
@@ -137,7 +152,69 @@ def calculate_qparams(self):
137152

138153
@torch.jit.export
139154
def extra_repr(self):
140-
return 'min_val={}, max_val={}'.format(self.min_val, self.max_val)
155+
return "min_val={}, max_val={}".format(self.min_val, self.max_val)
156+
157+
158+
class HistogramObserver(ObserverBase):
    r"""
    The module records the running histogram of tensor values along with
    min/max values. calculate_qparams will calculate scale and zero_point

    The first tensor passed to ``forward`` fixes the histogram range: the
    observed [min, max] interval is widened by half of its width on each
    side and stored in ``min_val`` / ``max_val``. Every later call adds its
    counts to the same bins; that range is never re-computed.

    Args:
        bins: number of histogram bins.
        **kwargs: forwarded unchanged to ``ObserverBase.__init__``.
    """

    __annotations__ = {
        "min_val": Optional[torch.Tensor],
        "max_val": Optional[torch.Tensor],
        "histogram": Optional[torch.Tensor],
    }

    def __init__(self, bins=2048, **kwargs):
        super(HistogramObserver, self).__init__(**kwargs)
        self.bins = bins
        self.histogram = None
        self.min_val = None
        self.max_val = None

    def forward(self, x):
        r"""Accumulate the values of ``x`` into the running histogram.

        Emits a warning when ``x`` has values outside the range that was
        fixed by the first observed tensor.
        """
        min_val = self.min_val
        max_val = self.max_val
        histogram = self.histogram
        if min_val is None or max_val is None or histogram is None:
            x_min = torch.min(x)
            x_max = torch.max(x)
            # Pad the observed interval by half its width on both sides so
            # later batches have some head-room.
            margin = 0.5 * (x_max - x_min)
            lower = x_min - margin
            upper = x_max + margin
            self.min_val = lower
            self.max_val = upper
            self.histogram = torch.histc(x, self.bins, min=lower, max=upper)
        else:
            # Warn when the *data* leaves the recorded range (the original
            # comparison was inverted and warned for in-range data).
            if torch.min(x) < min_val or torch.max(x) > max_val:
                warnings.warn("Incoming data is outside the min_val/max_val range.")
            new_histogram = torch.histc(x, self.bins, min=min_val, max=max_val)
            self.histogram = new_histogram + histogram

    @torch.jit.export
    def calculate_qparams(self):
        r"""Compute (scale, zero_point) from the occupied part of the histogram.

        The range handed to ``_calculate_qparams`` spans from the left edge
        of the first non-empty bin to the right edge of the last non-empty
        bin. If the observer has not been run yet, ``_calculate_qparams`` is
        called with ``(None, None)``.
        """
        min_val = self.min_val
        max_val = self.max_val
        histogram = self.histogram

        if min_val is None or max_val is None or histogram is None:
            return self._calculate_qparams(None, None)
        else:
            # int64 so the running count below cannot overflow, whatever
            # integer promotion cumsum applies, even with many bins.
            occupied = torch.gt(histogram, 0).to(torch.int64)
            running_count = torch.cumsum(occupied, 0)
            # argmax reports the first maximal entry. On the 0/1 mask that
            # is the first non-empty bin ...
            min_bin = torch.argmax(occupied)
            # ... and on the running count it is the bin where the count
            # first reaches its total, i.e. the last non-empty bin.
            max_bin = torch.argmax(running_count)
            bin_width = (max_val - min_val) / histogram.size()[0]
            new_min = min_val + min_bin * bin_width
            new_max = min_val + (max_bin + 1) * bin_width
            return self._calculate_qparams(new_min, new_max)
217+
141218

142219

143220
class TensorObserver(ObserverBase):
@@ -168,6 +245,7 @@ def get_tensor_value(self):
168245
def observer(observer_cls, **kwargs):
    r"""Bind keyword arguments to an observer class for later construction.

    Args:
        observer_cls: the callable (normally an observer class) to wrap.
        **kwargs: keyword arguments to pre-bind to ``observer_cls``.

    Returns:
        A ``functools.partial`` object; calling it invokes ``observer_cls``
        with the bound ``kwargs`` plus any arguments given at call time.
    """
    return partial(observer_cls, **kwargs)
170247

248+
171249
def default_observer(**kwargs):
172250
# Restrict activations to be in the range (0,127)
173251
kwargs.setdefault("reduce_range", True)

0 commit comments

Comments
 (0)