Skip to content

Commit a5dfba0

Browse files
vkuzo authored and facebook-github-bot committed
observers: make eps a buffer (#43149)
Summary: Pull Request resolved: #43149 This value doesn't change, making it a buffer to only pay the cost of creating a tensor once. Test Plan: Imported from OSS Reviewed By: jerryzh168 Differential Revision: D23170428 fbshipit-source-id: 6b963951a573efcc5b5a57649c814590b448dd72
1 parent 5aa61af commit a5dfba0

File tree

1 file changed

+25
-3
lines changed

1 file changed

+25
-3
lines changed

torch/quantization/observer.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,15 @@ class _ObserverBase(ObserverBase):
9696
- ``torch.per_channel_affine``
9797
- ``torch.per_channel_symmetric``
9898
"""
99+
100+
# Version 1/None
101+
# self
102+
#
103+
# Version 2
104+
# self
105+
# |--- eps : Tensor
106+
_version = 2
107+
99108
def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
100109
reduce_range=False, quant_min=None, quant_max=None):
101110
super(_ObserverBase, self).__init__(dtype=dtype)
@@ -106,7 +115,7 @@ def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
106115
reduce_range will be deprecated in a future release of PyTorch."
107116
)
108117
self.reduce_range = reduce_range
109-
self.eps = torch.finfo(torch.float32).eps
118+
self.register_buffer('eps', torch.tensor([torch.finfo(torch.float32).eps]))
110119
assert self.qscheme in (
111120
torch.per_tensor_affine,
112121
torch.per_tensor_symmetric,
@@ -126,6 +135,19 @@ def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
126135
self.quant_min = quant_min
127136
self.quant_max = quant_max
128137

138+
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
139+
missing_keys, unexpected_keys, error_msgs):
140+
141+
version = local_metadata.get('version', None)
142+
143+
if version is None or version == 1:
144+
# eps was moved to a buffer in version 2
145+
eps = torch.tensor([torch.finfo(torch.float32).eps])
146+
state_dict[prefix + 'eps'] = eps
147+
148+
super(ObserverBase, self)._load_from_state_dict(state_dict, prefix, local_metadata, strict,
149+
missing_keys, unexpected_keys, error_msgs)
150+
129151
@torch.jit.export
130152
def _validate_qmin_qmax(self, quant_min, quant_max):
131153
# type: (int, int) -> None
@@ -228,7 +250,7 @@ def _calculate_qparams(self, min_val, max_val):
228250
if self.qscheme == torch.per_tensor_symmetric or self.qscheme == torch.per_channel_symmetric:
229251
max_val_pos = torch.max(-min_val_neg, max_val_pos)
230252
scale = max_val_pos / (float(quant_max - quant_min) / 2)
231-
scale = torch.max(scale, torch.tensor(self.eps, device=device, dtype=scale.dtype))
253+
scale = torch.max(scale, self.eps)
232254
if self.dtype == torch.quint8:
233255
if self.has_customized_qrange:
234256
# When customized quantization range is used, down-rounded midpoint of the range is chosen.
@@ -245,7 +267,7 @@ def _calculate_qparams(self, min_val, max_val):
245267
zero_point = -1 * min_val / scale
246268
else:
247269
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
248-
scale = torch.max(scale, torch.tensor(self.eps, device=device, dtype=scale.dtype))
270+
scale = torch.max(scale, self.eps)
249271
zero_point = quant_min - torch.round(min_val_neg / scale)
250272
zero_point = torch.max(zero_point, torch.tensor(quant_min, device=device, dtype=zero_point.dtype))
251273
zero_point = torch.min(zero_point, torch.tensor(quant_max, device=device, dtype=zero_point.dtype))

0 commit comments

Comments (0)