29 changes: 13 additions & 16 deletions torch/quantization/observer.py
@@ -231,6 +231,13 @@ def _calculate_qparams(self, min_val, max_val):
             return torch.tensor([1.0]), torch.tensor([0])
 
         if min_val.dim() == 0 or max_val.dim() == 0:
+            if min_val == float('inf') and max_val == float('-inf'):
+                warnings.warn(
+                    "must run observer before calling calculate_qparams.\
+                                    Returning default scale and zero point "
+                )
+                return torch.tensor([1.0]), torch.tensor([0])
+
             assert min_val <= max_val, "min {} should be less than max {}".format(
                 min_val, max_val
             )
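
With this guard in place, calling calculate_qparams() on a fresh observer no longer depends on empty buffers: the inf/-inf sentinels identify the uninitialized state, and the observer warns and falls back to the default qparams. A quick REPL sketch of that behavior (assuming this patch is applied; the warning output is omitted here):

>>> import torch
>>> from torch.quantization.observer import MinMaxObserver
>>> obs = MinMaxObserver()
>>> obs.min_val, obs.max_val   # sentinels registered in __init__
(tensor(inf), tensor(-inf))
>>> obs.calculate_qparams()    # warns, then falls back to the defaults
(tensor([1.]), tensor([0]))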
@@ -364,8 +371,8 @@ def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
                                              reduce_range=reduce_range,
                                              quant_min=quant_min,
                                              quant_max=quant_max)
-        self.register_buffer('min_val', torch.tensor([]))
-        self.register_buffer('max_val', torch.tensor([]))
+        self.register_buffer('min_val', torch.tensor(float('inf')))
+        self.register_buffer('max_val', torch.tensor(float('-inf')))
         if self.qscheme == torch.per_tensor_symmetric and \
            self.reduce_range and \
            self.dtype == torch.quint8:

Review thread on the new 'min_val' buffer:

Contributor:
Can we serialize tensors with inf/-inf correctly?

Author:
Seems ok:

>>> m
EmbeddingBag(1, 2, mode=mean)
>>> m.register_buffer("v", torch.tensor([float('inf')]))
>>> m.v
tensor([inf])
>>> ms = torch.jit.script(m)
>>> ms
RecursiveScriptModule(original_name=EmbeddingBag)
>>> ms.v
tensor([inf])
>>> torch.jit.save(ms, "/tmp/test.pt")
>>> ms_2 = torch.jit.load("/tmp/test.pt")
>>> ms_2.v
tensor([inf])
@@ -376,16 +383,8 @@ def forward(self, x_orig):
         r"""Records the running minimum and maximum of ``x``."""
         x = x_orig.detach()  # avoid keeping autograd tape
         x = x.to(self.min_val.dtype)
-        min_val = self.min_val
-        max_val = self.max_val
-        if min_val.numel() == 0 or max_val.numel() == 0:
-            min_val = torch.min(x)
-            max_val = torch.max(x)
-        else:
-            min_val = torch.min(torch.min(x), min_val)
-            max_val = torch.max(torch.max(x), max_val)
-        self.min_val.resize_(min_val.shape)
-        self.max_val.resize_(max_val.shape)
+        min_val = torch.min(torch.min(x), self.min_val)
+        max_val = torch.max(torch.max(x), self.max_val)
         self.min_val.copy_(min_val)
         self.max_val.copy_(max_val)
         return x_orig
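
The data-dependent branch can be dropped because inf and -inf are identity elements for min and max: on the first batch, torch.min(torch.min(x), inf) is just torch.min(x), and since both buffers stay zero-dimensional the resize_ calls become unnecessary as well. A minimal sketch of that invariant (plain torch, no observer needed):

>>> import torch
>>> min_val = torch.tensor(float('inf'))    # buffer state before any data
>>> max_val = torch.tensor(float('-inf'))
>>> x = torch.tensor([0.5, -2.0, 3.0])
>>> torch.min(torch.min(x), min_val)        # inf is the identity for min
tensor(-2.)
>>> torch.max(torch.max(x), max_val)        # -inf is the identity for max
tensor(3.)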
@@ -477,14 +476,12 @@ def forward(self, x_orig):
         x = x.to(self.min_val.dtype)
         min_val = self.min_val
         max_val = self.max_val
-        if min_val.numel() == 0 or max_val.numel() == 0:
+        if min_val == float('inf') and max_val == float('-inf'):
             min_val = torch.min(x)
             max_val = torch.max(x)
         else:
             min_val = min_val + self.averaging_constant * (torch.min(x) - min_val)
             max_val = max_val + self.averaging_constant * (torch.max(x) - max_val)
-        self.min_val.resize_(min_val.shape)
-        self.max_val.resize_(max_val.shape)
         self.min_val.copy_(min_val)
         self.max_val.copy_(max_val)
         return x_orig
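
Unlike the plain min/max observer, this path still needs an explicit first-batch branch: blending the inf sentinel into the moving average would produce inf/nan (inf + c * (min(x) - inf) is nan), so the average must be seeded with the raw min/max and only the sentinel test changes. Once seeded, each batch moves the bounds a fraction averaging_constant of the way toward the new extremes. A rough numeric sketch, assuming the class default of 0.01 for averaging_constant:

>>> import torch
>>> c = 0.01                                # averaging_constant (assumed default)
>>> min_val = torch.tensor(-1.0)            # running minimum so far
>>> min_val + c * (torch.tensor(-3.0) - min_val)
tensor(-1.0200)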
@@ -514,7 +511,7 @@ class MinMaxDynamicQuantObserver(MinMaxObserver):
     def calculate_qparams(self):
         r"""Calculates the quantization parameters."""
 
-        if self.max_val.numel() == 0 or self.min_val.numel() == 0:
+        if self.max_val == float('-inf') and self.min_val == float('inf'):
             return torch.tensor([1.0]), torch.tensor([0])
 
         assert self.min_val <= self.max_val, "min {} should be less than max {}".format(