11from __future__ import absolute_import , division , print_function , unicode_literals
22
3- import torch
4- import torch .nn as nn
3+ import warnings
54from abc import ABCMeta , abstractmethod
65from functools import partial
7- import warnings
86
97from torch ._jit_internal import Optional , List
8+ import torch
9+ import torch .nn as nn
10+
11+
12+ ABC = ABCMeta (str ("ABC" ), (object ,), {}) # compatible with Python 2 *and* 3:
1013
11- ABC = ABCMeta (str ('ABC' ), (object ,), {}) # compatible with Python 2 *and* 3:
1214
1315class ObserverBase (ABC , nn .Module ):
1416 r"""Observer base Module
@@ -20,7 +22,9 @@ class ObserverBase(ABC, nn.Module):
2022 the collected statistics.
2123 """
2224
23- def __init__ (self , dtype = torch .quint8 , qscheme = torch .per_tensor_affine , reduce_range = False ):
25+ def __init__ (
26+ self , dtype = torch .quint8 , qscheme = torch .per_tensor_affine , reduce_range = False
27+ ):
2428 super (ObserverBase , self ).__init__ ()
2529 self .dtype = dtype
2630 self .qscheme = qscheme
@@ -52,8 +56,10 @@ def _calculate_qparams(self, min_val, max_val):
5256 """
5357
5458 if max_val is None or min_val is None :
55- warnings .warn ("must run observer before calling calculate_qparams.\
56- Returning default scale and zero point " )
59+ warnings .warn (
60+ "must run observer before calling calculate_qparams.\
61+ Returning default scale and zero point "
62+ )
5763 return torch .tensor ([1.0 ]), torch .tensor ([0 ])
5864
5965 assert min_val <= max_val , "min {} should be less than max {}" .format (
@@ -102,7 +108,10 @@ class MinMaxObserver(ObserverBase):
102108 calculate_qparams will calculate scale and zero_point
103109 """
104110
105- __annotations__ = {'min_val' : Optional [torch .Tensor ], 'max_val' : Optional [torch .Tensor ]}
111+ __annotations__ = {
112+ "min_val" : Optional [torch .Tensor ],
113+ "max_val" : Optional [torch .Tensor ],
114+ }
106115
107116 def __init__ (self , ** kwargs ):
108117 # For x86 quantized kernels, we need to ensure that the vpmaddubsw instruction
@@ -115,8 +124,14 @@ def __init__(self, **kwargs):
115124 super (MinMaxObserver , self ).__init__ (** kwargs )
116125 self .min_val = None
117126 self .max_val = None
118- if self .qscheme == torch .per_tensor_symmetric and self .reduce_range and self .dtype == torch .quint8 :
119- raise NotImplementedError ("Cannot reduce range for symmetric quantization for quint8" )
127+ if (
128+ self .qscheme == torch .per_tensor_symmetric
129+ and self .reduce_range
130+ and self .dtype == torch .quint8
131+ ):
132+ raise NotImplementedError (
133+ "Cannot reduce range for symmetric quantization for quint8"
134+ )
120135
121136 def forward (self , x ):
122137 min_val = self .min_val
@@ -137,7 +152,69 @@ def calculate_qparams(self):
137152
138153 @torch .jit .export
139154 def extra_repr (self ):
140- return 'min_val={}, max_val={}' .format (self .min_val , self .max_val )
155+ return "min_val={}, max_val={}" .format (self .min_val , self .max_val )
156+
157+
158+ class HistogramObserver (ObserverBase ):
159+ r"""
160+ The module records the running histogram of tensor values along with
161+ min/max values. calculate_qparams will calculate scale and zero_point
162+ """
163+
164+ __annotations__ = {
165+ "min_val" : Optional [torch .Tensor ],
166+ "max_val" : Optional [torch .Tensor ],
167+ "histogram" : Optional [torch .Tensor ],
168+ }
169+
170+ def __init__ (self , bins = 2048 , ** kwargs ):
171+ super (HistogramObserver , self ).__init__ (** kwargs )
172+ self .bins = bins
173+ self .histogram = None
174+ self .min_val = None
175+ self .max_val = None
176+
177+ def forward (self , x ):
178+ min_val = self .min_val
179+ max_val = self .max_val
180+ histogram = self .histogram
181+ if min_val is None or max_val is None or histogram is None :
182+ min_val = torch .min (x )
183+ max_val = torch .max (x )
184+ range = max_val - min_val
185+ self .min_val = min_val - 0.5 * range
186+ self .max_val = max_val + 0.5 * range
187+ self .histogram = torch .histc (
188+ x , self .bins , min = min_val - 0.5 * range , max = max_val + 0.5 * range
189+ )
190+ else :
191+ if min_val < torch .min (x ) or max_val > torch .max (x ):
192+ warnings .warn ("Incoming data is outside the min_val/max_val range." )
193+ new_histogram = torch .histc (
194+ x , self .bins , min = min_val , max = max_val
195+ )
196+ self .histogram = new_histogram + histogram
197+
198+ @torch .jit .export
199+ def calculate_qparams (self ):
200+ min_val = self .min_val
201+ max_val = self .max_val
202+ histogram = self .histogram
203+
204+ if min_val is None or max_val is None or histogram is None :
205+ return self ._calculate_qparams (None , None )
206+ else :
207+ histogram_mask = torch .gt (histogram , 0 ).to (torch .int8 )
208+ c = torch .cumsum (histogram_mask , 0 )
209+ # Last non-zero bin
210+ max_bin = torch .argmax (histogram_mask )
211+ # Only one entry is non-zero, find it.
212+ min_bin = torch .argmax (torch .eq (c , 1 ))
213+ bin_width = (max_val - min_val ) / histogram .size ()[0 ]
214+ new_min = min_val + min_bin * bin_width
215+ new_max = min_val + (max_bin + 1 ) * bin_width
216+ return self ._calculate_qparams (new_min , new_max )
217+
141218
142219
143220class TensorObserver (ObserverBase ):
@@ -168,6 +245,7 @@ def get_tensor_value(self):
168245def observer (observer_cls , ** kwargs ):
169246 return partial (observer_cls , ** kwargs )
170247
248+
171249def default_observer (** kwargs ):
172250 # Restrict activations to be in the range (0,127)
173251 kwargs .setdefault ("reduce_range" , True )
0 commit comments