@@ -96,6 +96,15 @@ class _ObserverBase(ObserverBase):
9696 - ``torch.per_channel_affine``
9797 - ``torch.per_channel_symmetric``
9898 """
99+
100+ # Version 1/None
101+ # self
102+ #
103+ # Version 2
104+ # self
105+ # |--- eps : Tensor
106+ _version = 2
107+
99108 def __init__ (self , dtype = torch .quint8 , qscheme = torch .per_tensor_affine ,
100109 reduce_range = False , quant_min = None , quant_max = None ):
101110 super (_ObserverBase , self ).__init__ (dtype = dtype )
@@ -106,7 +115,7 @@ def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
106115 reduce_range will be deprecated in a future release of PyTorch."
107116 )
108117 self .reduce_range = reduce_range
109- self .eps = torch .finfo (torch .float32 ).eps
118+ self .register_buffer ( ' eps' , torch .tensor ([ torch . finfo (torch .float32 ).eps ]))
110119 assert self .qscheme in (
111120 torch .per_tensor_affine ,
112121 torch .per_tensor_symmetric ,
@@ -126,6 +135,19 @@ def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
126135 self .quant_min = quant_min
127136 self .quant_max = quant_max
128137
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    """Load this observer's state, back-filling ``eps`` for old checkpoints.

    ``eps`` became a registered buffer in version 2. Checkpoints whose
    metadata reports version 1, or carries no version at all, may not
    contain an ``eps`` entry, so a default one is inserted before
    delegating to the parent implementation.

    Args:
        state_dict: mapping of parameter/buffer names to values; may be
            mutated in place to add the ``eps`` entry.
        prefix: key prefix identifying this module's entries in
            ``state_dict``.
        local_metadata: metadata dict for this module; its ``'version'``
            entry (if any) selects the compatibility path.
        strict, missing_keys, unexpected_keys, error_msgs: forwarded
            unchanged to the parent ``_load_from_state_dict``.
    """
    version = local_metadata.get('version', None)

    if version is None or version == 1:
        # eps was moved to a buffer in version 2
        eps_key = prefix + 'eps'
        # Only back-fill when the entry is really absent: a checkpoint
        # without version metadata may still carry a saved eps, which
        # must not be overwritten by the default.
        if eps_key not in state_dict:
            state_dict[eps_key] = torch.tensor([torch.finfo(torch.float32).eps])

    # Start the MRO lookup from this class (as __init__ does) so that no
    # parent-class override of _load_from_state_dict is skipped.
    super(_ObserverBase, self)._load_from_state_dict(
        state_dict, prefix, local_metadata, strict,
        missing_keys, unexpected_keys, error_msgs)
150+
129151 @torch .jit .export
130152 def _validate_qmin_qmax (self , quant_min , quant_max ):
131153 # type: (int, int) -> None
@@ -228,7 +250,7 @@ def _calculate_qparams(self, min_val, max_val):
228250 if self .qscheme == torch .per_tensor_symmetric or self .qscheme == torch .per_channel_symmetric :
229251 max_val_pos = torch .max (- min_val_neg , max_val_pos )
230252 scale = max_val_pos / (float (quant_max - quant_min ) / 2 )
231- scale = torch .max (scale , torch . tensor ( self .eps , device = device , dtype = scale . dtype ) )
253+ scale = torch .max (scale , self .eps )
232254 if self .dtype == torch .quint8 :
233255 if self .has_customized_qrange :
234256 # When customized quantization range is used, down-rounded midpoint of the range is chosen.
@@ -245,7 +267,7 @@ def _calculate_qparams(self, min_val, max_val):
245267 zero_point = - 1 * min_val / scale
246268 else :
247269 scale = (max_val_pos - min_val_neg ) / float (quant_max - quant_min )
248- scale = torch .max (scale , torch . tensor ( self .eps , device = device , dtype = scale . dtype ) )
270+ scale = torch .max (scale , self .eps )
249271 zero_point = quant_min - torch .round (min_val_neg / scale )
250272 zero_point = torch .max (zero_point , torch .tensor (quant_min , device = device , dtype = zero_point .dtype ))
251273 zero_point = torch .min (zero_point , torch .tensor (quant_max , device = device , dtype = zero_point .dtype ))
0 commit comments