
Commit 9f6b6b8

smessmer authored and facebook-github-bot committed
Back out "[quant][observer] Add histogram observer" (#26236)
Summary:
Pull Request resolved: #26236

Original diff broke oss CI. Reverting.

Original commit changeset: 0f047d3349cb
ghstack-source-id: 90125990

Test Plan: testinprod

Reviewed By: hx89

Differential Revision: D17385490

fbshipit-source-id: 4258502bbc0e3a6dd6852c8ce01ed05eee618b1a
1 parent 3051e36 commit 9f6b6b8
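
For context, the revert removes HistogramObserver from both the observer module and the public torch.quantization import surface exercised by the tests below. A minimal before/after check (a hypothetical session, not part of the commit; names are taken from the import list in the diff):

    # After this commit, the reverted symbol is gone from torch.quantization (sketch):
    try:
        from torch.quantization import HistogramObserver  # removed by this revert
    except ImportError:
        from torch.quantization import MinMaxObserver     # unaffected by the revert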

File tree: 2 files changed (+14, -124 lines)

2 files changed

+14
-124
lines changed

test/test_quantization.py

Lines changed: 3 additions & 35 deletions
@@ -1,5 +1,3 @@
-from __future__ import absolute_import, division, print_function, unicode_literals
-
 import unittest
 import torch
 import torch.nn as nn
@@ -11,7 +9,7 @@
     QConfig_dynamic, default_weight_observer, dump_tensor,\
     quantize, prepare, convert, prepare_qat, quantize_qat, fuse_modules, \
     quantize_dynamic, default_qconfig, default_debug_qconfig, default_qat_qconfig, \
-    default_dynamic_qconfig, QuantWrapper, TensorObserver, MinMaxObserver, HistogramObserver
+    default_dynamic_qconfig, MinMaxObserver, TensorObserver, QuantWrapper

 from common_utils import run_tests
 from common_quantization import QuantizationTestCase, SingleLayerLinearModel, \
@@ -777,8 +775,8 @@ def test_minmax_observer(self, qdtype, qscheme, reduce_range):
         self.assertEqual(qparams[1].item(), ref_zero_point)
         self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5)

-    @given(obs=st.sampled_from((torch.quantization.default_observer()(), HistogramObserver(bins=10))))
-    def test_observer_scriptable(self, obs):
+    def test_observer_scriptable(self):
+        obs = torch.quantization.default_observer()()
         scripted = torch.jit.script(obs)

         x = torch.rand(3, 4)
@@ -829,35 +827,5 @@ def test_tensor_observer_scriptable(self, qdtype, qscheme):
         loaded = torch.jit.load(buf)
         self.assertTrue(torch.equal(obs.get_tensor_value()[0], loaded.get_tensor_value()[0]))

-    @given(qdtype=st.sampled_from((torch.qint8, torch.quint8)),
-           qscheme=st.sampled_from((torch.per_tensor_affine, torch.per_tensor_symmetric)),
-           reduce_range=st.booleans())
-    def test_histogram_observer(self, qdtype, qscheme, reduce_range):
-        myobs = HistogramObserver(bins=10, dtype=qdtype, qscheme=qscheme, reduce_range=reduce_range)
-        x = torch.tensor([1.0, 2.0, 2.0, 3.0, 4.0, 5.0, 6.0])
-        y = torch.tensor([4.0, 5.0, 5.0, 6.0, 7.0, 8.0])
-        myobs(x)
-        myobs(y)
-        self.assertEqual(myobs.min_val, -1.5)
-        self.assertEqual(myobs.max_val, 8.5)
-        self.assertEqual(myobs.histogram, [0., 0., 1., 2., 1., 2., 3., 2., 1., 1.])
-        qparams = myobs.calculate_qparams()
-        if reduce_range:
-            if qscheme == torch.per_tensor_symmetric:
-                ref_scale = 0.066666 * 255 / 127
-                ref_zero_point = 0 if qdtype is torch.qint8 else 128
-            else:
-                ref_scale = 0.0333333 * 255 / 127
-                ref_zero_point = -64 if qdtype is torch.qint8 else 0
-        else:
-            if qscheme == torch.per_tensor_symmetric:
-                ref_scale = 0.066666
-                ref_zero_point = 0 if qdtype is torch.qint8 else 128
-            else:
-                ref_scale = 0.0333333
-                ref_zero_point = -128 if qdtype is torch.qint8 else 0
-        self.assertEqual(qparams[1].item(), ref_zero_point)
-        self.assertAlmostEqual(qparams[0].item(), ref_scale, delta=1e-5)
-
 if __name__ == '__main__':
     run_tests()
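
The reference values in the removed test_histogram_observer follow from standard 8-bit qparams arithmetic once the observer has settled on an effective range of [0, 8.5] for the affine scheme (clamped to include zero) and [-8.5, 8.5] for the symmetric scheme. A hedged re-derivation of those constants (a sketch in plain Python; the effective-range assumption is inferred from the deleted assertions above, not stated in the diff):

    # Sketch: reproduce the ref_scale / ref_zero_point constants asserted in the
    # deleted test, assuming an effective range of min=0.0, max=8.5.
    def ref_qparams(symmetric, is_qint8, reduce_range, min_val=0.0, max_val=8.5):
        if symmetric:
            scale = 2 * max(abs(min_val), abs(max_val)) / 255   # ~0.066666
            zero_point = 0 if is_qint8 else 128
        else:
            scale = (max_val - min_val) / 255                   # ~0.033333
            zero_point = -128 if is_qint8 else 0
        if reduce_range:                  # 7-bit range: scale stretches by 255/127
            scale *= 255 / 127
            if not symmetric:
                zero_point = -64 if is_qint8 else 0
        return scale, zero_point

    assert abs(ref_qparams(True, True, False)[0] - 0.066666) < 1e-4
    assert abs(ref_qparams(False, False, True)[0] - 0.0333333 * 255 / 127) < 1e-4
    assert ref_qparams(False, True, True)[1] == -64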

torch/quantization/observer.py

Lines changed: 11 additions & 89 deletions
@@ -1,16 +1,14 @@
 from __future__ import absolute_import, division, print_function, unicode_literals

-import warnings
+import torch
+import torch.nn as nn
 from abc import ABCMeta, abstractmethod
 from functools import partial
+import warnings

 from torch._jit_internal import Optional, List
-import torch
-import torch.nn as nn
-
-
-ABC = ABCMeta(str("ABC"), (object,), {})  # compatible with Python 2 *and* 3:

+ABC = ABCMeta(str('ABC'), (object,), {})  # compatible with Python 2 *and* 3:

 class ObserverBase(ABC, nn.Module):
     r"""Observer base Module
@@ -22,9 +20,7 @@ class ObserverBase(ABC, nn.Module):
     the collected statistics.
     """

-    def __init__(
-        self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False
-    ):
+    def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine, reduce_range=False):
         super(ObserverBase, self).__init__()
         self.dtype = dtype
         self.qscheme = qscheme
@@ -56,10 +52,8 @@ def _calculate_qparams(self, min_val, max_val):
         """

         if max_val is None or min_val is None:
-            warnings.warn(
-                "must run observer before calling calculate_qparams.\
-                            Returning default scale and zero point "
-            )
+            warnings.warn("must run observer before calling calculate_qparams.\
+                            Returning default scale and zero point ")
             return torch.tensor([1.0]), torch.tensor([0])

         assert min_val <= max_val, "min {} should be less than max {}".format(
@@ -108,10 +102,7 @@ class MinMaxObserver(ObserverBase):
     calculate_qparams will calculate scale and zero_point
     """

-    __annotations__ = {
-        "min_val": Optional[torch.Tensor],
-        "max_val": Optional[torch.Tensor],
-    }
+    __annotations__ = {'min_val' : Optional[torch.Tensor], 'max_val' : Optional[torch.Tensor]}

     def __init__(self, **kwargs):
         # For x86 quantized kernels, we need to ensure that the vpmaddubsw instruction
@@ -124,14 +115,8 @@ def __init__(self, **kwargs):
         super(MinMaxObserver, self).__init__(**kwargs)
         self.min_val = None
         self.max_val = None
-        if (
-            self.qscheme == torch.per_tensor_symmetric
-            and self.reduce_range
-            and self.dtype == torch.quint8
-        ):
-            raise NotImplementedError(
-                "Cannot reduce range for symmetric quantization for quint8"
-            )
+        if self.qscheme == torch.per_tensor_symmetric and self.reduce_range and self.dtype == torch.quint8:
+            raise NotImplementedError("Cannot reduce range for symmetric quantization for quint8")

     def forward(self, x):
         min_val = self.min_val
@@ -152,69 +137,7 @@ def calculate_qparams(self):

     @torch.jit.export
     def extra_repr(self):
-        return "min_val={}, max_val={}".format(self.min_val, self.max_val)
-
-
-class HistogramObserver(ObserverBase):
-    r"""
-    The module records the running histogram of tensor values along with
-    min/max values. calculate_qparams will calculate scale and zero_point
-    """
-
-    __annotations__ = {
-        "min_val": Optional[torch.Tensor],
-        "max_val": Optional[torch.Tensor],
-        "histogram": Optional[torch.Tensor],
-    }
-
-    def __init__(self, bins=2048, **kwargs):
-        super(HistogramObserver, self).__init__(**kwargs)
-        self.bins = bins
-        self.histogram = None
-        self.min_val = None
-        self.max_val = None
-
-    def forward(self, x):
-        min_val = self.min_val
-        max_val = self.max_val
-        histogram = self.histogram
-        if min_val is None or max_val is None or histogram is None:
-            min_val = torch.min(x)
-            max_val = torch.max(x)
-            range = max_val - min_val
-            self.min_val = min_val - 0.5 * range
-            self.max_val = max_val + 0.5 * range
-            self.histogram = torch.histc(
-                x, self.bins, min=min_val - 0.5 * range, max=max_val + 0.5 * range
-            )
-        else:
-            if min_val < torch.min(x) or max_val > torch.max(x):
-                warnings.warn("Incoming data is outside the min_val/max_val range.")
-            new_histogram = torch.histc(
-                x, self.bins, min=min_val, max=max_val
-            )
-            self.histogram = new_histogram + histogram
-
-    @torch.jit.export
-    def calculate_qparams(self):
-        min_val = self.min_val
-        max_val = self.max_val
-        histogram = self.histogram
-
-        if min_val is None or max_val is None or histogram is None:
-            return self._calculate_qparams(None, None)
-        else:
-            histogram_mask = torch.gt(histogram, 0).to(torch.int8)
-            c = torch.cumsum(histogram_mask, 0)
-            # Last non-zero bin
-            max_bin = torch.argmax(histogram_mask)
-            # Only one entry is non-zero, find it.
-            min_bin = torch.argmax(torch.eq(c, 1))
-            bin_width = (max_val - min_val) / histogram.size()[0]
-            new_min = min_val + min_bin * bin_width
-            new_max = min_val + (max_bin + 1) * bin_width
-            return self._calculate_qparams(new_min, new_max)
-
+        return 'min_val={}, max_val={}'.format(self.min_val, self.max_val)


 class TensorObserver(ObserverBase):
@@ -245,7 +168,6 @@ def get_tensor_value(self):
 def observer(observer_cls, **kwargs):
     return partial(observer_cls, **kwargs)

-
 def default_observer(**kwargs):
     # Restrict activations to be in the range (0,127)
     kwargs.setdefault("reduce_range", True)