11from __future__ import absolute_import , division , print_function , unicode_literals
22
3- import warnings
3+ import torch
4+ import torch .nn as nn
45from abc import ABCMeta , abstractmethod
56from functools import partial
7+ import warnings
68
79from torch ._jit_internal import Optional , List
8- import torch
9- import torch .nn as nn
10-
11-
12- ABC = ABCMeta (str ("ABC" ), (object ,), {}) # compatible with Python 2 *and* 3:
1310
11+ ABC = ABCMeta (str ('ABC' ), (object ,), {}) # compatible with Python 2 *and* 3:
1412
1513class ObserverBase (ABC , nn .Module ):
1614 r"""Observer base Module
@@ -22,9 +20,7 @@ class ObserverBase(ABC, nn.Module):
2220 the collected statistics.
2321 """
2422
25- def __init__ (
26- self , dtype = torch .quint8 , qscheme = torch .per_tensor_affine , reduce_range = False
27- ):
23+ def __init__ (self , dtype = torch .quint8 , qscheme = torch .per_tensor_affine , reduce_range = False ):
2824 super (ObserverBase , self ).__init__ ()
2925 self .dtype = dtype
3026 self .qscheme = qscheme
@@ -56,10 +52,8 @@ def _calculate_qparams(self, min_val, max_val):
5652 """
5753
5854 if max_val is None or min_val is None :
59- warnings .warn (
60- "must run observer before calling calculate_qparams.\
61- Returning default scale and zero point "
62- )
55+ warnings .warn ("must run observer before calling calculate_qparams.\
56+ Returning default scale and zero point " )
6357 return torch .tensor ([1.0 ]), torch .tensor ([0 ])
6458
6559 assert min_val <= max_val , "min {} should be less than max {}" .format (
@@ -108,10 +102,7 @@ class MinMaxObserver(ObserverBase):
108102 calculate_qparams will calculate scale and zero_point
109103 """
110104
111- __annotations__ = {
112- "min_val" : Optional [torch .Tensor ],
113- "max_val" : Optional [torch .Tensor ],
114- }
105+ __annotations__ = {'min_val' : Optional [torch .Tensor ], 'max_val' : Optional [torch .Tensor ]}
115106
116107 def __init__ (self , ** kwargs ):
117108 # For x86 quantized kernels, we need to ensure that the vpmaddubsw instruction
@@ -124,14 +115,8 @@ def __init__(self, **kwargs):
124115 super (MinMaxObserver , self ).__init__ (** kwargs )
125116 self .min_val = None
126117 self .max_val = None
127- if (
128- self .qscheme == torch .per_tensor_symmetric
129- and self .reduce_range
130- and self .dtype == torch .quint8
131- ):
132- raise NotImplementedError (
133- "Cannot reduce range for symmetric quantization for quint8"
134- )
118+ if self .qscheme == torch .per_tensor_symmetric and self .reduce_range and self .dtype == torch .quint8 :
119+ raise NotImplementedError ("Cannot reduce range for symmetric quantization for quint8" )
135120
136121 def forward (self , x ):
137122 min_val = self .min_val
@@ -152,69 +137,7 @@ def calculate_qparams(self):
152137
153138 @torch .jit .export
154139 def extra_repr (self ):
155- return "min_val={}, max_val={}" .format (self .min_val , self .max_val )
156-
157-
158- class HistogramObserver (ObserverBase ):
159- r"""
160- The module records the running histogram of tensor values along with
161- min/max values. calculate_qparams will calculate scale and zero_point
162- """
163-
164- __annotations__ = {
165- "min_val" : Optional [torch .Tensor ],
166- "max_val" : Optional [torch .Tensor ],
167- "histogram" : Optional [torch .Tensor ],
168- }
169-
170- def __init__ (self , bins = 2048 , ** kwargs ):
171- super (HistogramObserver , self ).__init__ (** kwargs )
172- self .bins = bins
173- self .histogram = None
174- self .min_val = None
175- self .max_val = None
176-
177- def forward (self , x ):
178- min_val = self .min_val
179- max_val = self .max_val
180- histogram = self .histogram
181- if min_val is None or max_val is None or histogram is None :
182- min_val = torch .min (x )
183- max_val = torch .max (x )
184- range = max_val - min_val
185- self .min_val = min_val - 0.5 * range
186- self .max_val = max_val + 0.5 * range
187- self .histogram = torch .histc (
188- x , self .bins , min = min_val - 0.5 * range , max = max_val + 0.5 * range
189- )
190- else :
191- if min_val < torch .min (x ) or max_val > torch .max (x ):
192- warnings .warn ("Incoming data is outside the min_val/max_val range." )
193- new_histogram = torch .histc (
194- x , self .bins , min = min_val , max = max_val
195- )
196- self .histogram = new_histogram + histogram
197-
198- @torch .jit .export
199- def calculate_qparams (self ):
200- min_val = self .min_val
201- max_val = self .max_val
202- histogram = self .histogram
203-
204- if min_val is None or max_val is None or histogram is None :
205- return self ._calculate_qparams (None , None )
206- else :
207- histogram_mask = torch .gt (histogram , 0 ).to (torch .int8 )
208- c = torch .cumsum (histogram_mask , 0 )
209- # Last non-zero bin
210- max_bin = torch .argmax (histogram_mask )
211- # Only one entry is non-zero, find it.
212- min_bin = torch .argmax (torch .eq (c , 1 ))
213- bin_width = (max_val - min_val ) / histogram .size ()[0 ]
214- new_min = min_val + min_bin * bin_width
215- new_max = min_val + (max_bin + 1 ) * bin_width
216- return self ._calculate_qparams (new_min , new_max )
217-
140+ return 'min_val={}, max_val={}' .format (self .min_val , self .max_val )
218141
219142
220143class TensorObserver (ObserverBase ):
@@ -245,7 +168,6 @@ def get_tensor_value(self):
245168def observer (observer_cls , ** kwargs ):
246169 return partial (observer_cls , ** kwargs )
247170
248-
249171def default_observer (** kwargs ):
250172 # Restrict activations to be in the range (0,127)
251173 kwargs .setdefault ("reduce_range" , True )
0 commit comments