
Commit a0f1101

mattip authored and facebook-github-bot committed
clamp Categorical logits from -inf to finfo min when calculating entropy (#41002)
Summary: Fixes gh-40553 by clamping logit values when calculating Categorical.entropy

Pull Request resolved: #41002
Reviewed By: mruberry
Differential Revision: D22436432
Pulled By: ngimel
fbshipit-source-id: 08b7c7b0c15ab4e5a56b3a8ec0d0237ad360202e
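For context, a minimal sketch of the failure mode and of what the clamp changes; this is my own illustration with made-up logits, not code from the commit:

    import torch

    # One category with probability 0: its normalized logit is -inf.
    logits = torch.tensor([0.0, 0.0, float('-inf')])
    probs = logits.softmax(-1)                           # tensor([0.5, 0.5, 0.0])
    norm = logits - logits.logsumexp(-1, keepdim=True)   # tensor([-0.6931, -0.6931, -inf])

    # Old behaviour: (-inf) * 0.0 is nan, so the whole entropy becomes nan.
    print(-(norm * probs).sum(-1))                       # tensor(nan)

    # New behaviour: clamp the logits to the smallest finite value of the dtype,
    # so the zero-probability term contributes min_real * 0.0 == 0.0.
    min_real = torch.finfo(norm.dtype).min
    print(-(norm.clamp(min=min_real) * probs).sum(-1))   # tensor(0.6931), i.e. ln 2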
1 parent 359cdc2 commit a0f1101

File tree: 4 files changed, +11 −3 lines changed

aten/src/ATen/core/ivalue.cpp

Lines changed: 1 addition & 1 deletion

@@ -400,7 +400,7 @@ std::ostream& IValue::repr(
     case IValue::Tag::Double: {
       double d = v.toDouble();
       int c = std::fpclassify(d);
-      if (c == FP_NORMAL || c == FP_ZERO) {
+      if ((c == FP_NORMAL || c == FP_ZERO ) && std::abs(d) < 1e10) {
        int64_t i = int64_t(d);
        if (double(i) == d) {
          return out << i << ".";
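My reading, stated as an assumption since the commit message does not spell it out: the 1e10 cutoff stops IValue::repr from printing huge doubles, such as the finfo minima that the entropy clamp now embeds as constants, as integer literals with a trailing dot. Those minima are far beyond the cutoff:

    import torch

    print(torch.finfo(torch.float32).min)   # about -3.4028e+38
    print(torch.finfo(torch.float64).min)   # about -1.7977e+308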

test/test_distributions.py

Lines changed: 5 additions & 0 deletions

@@ -1178,6 +1178,11 @@ def ref_log_prob(idx, val, log_prob):
         # check entropy computation
         self.assertEqual(Categorical(p).entropy(), torch.tensor([1.0114, 1.0297]), atol=1e-4, rtol=0)
         self.assertEqual(Categorical(s).entropy(), torch.tensor([0.0, 0.0]))
+        # issue gh-40553
+        logits = p.log()
+        logits[1, 1] = logits[0, 2] = float('-inf')
+        e = Categorical(logits=logits).entropy()
+        self.assertEqual(e, torch.tensor([0.6365, 0.5983]), atol=1e-4, rtol=0)

     def test_categorical_enumerate_support(self):
         examples = [

torch/csrc/jit/ir/node_hashing.cpp

Lines changed: 1 addition & 1 deletion

@@ -200,7 +200,7 @@ size_t HashNode::operator()(const Node* k) const {
     } else if (
         type->isSubtypeOf(NumberType::get()) &&
         k->kindOf(attr::value) == AttributeKind::f) {
-      constant_hash = std::hash<float>{}(k->f(attr::value));
+      constant_hash = std::hash<double>{}(k->f(attr::value));
     } else if (type->isSubtypeOf(BoolType::get())) {
       constant_hash = std::hash<bool>{}(k->i(attr::value));
     }
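As a hedged illustration (in Python, not from the commit) of why hashing the double-valued f attribute through std::hash<float> was lossy: a double of this magnitude does not survive narrowing to 32-bit float, so distinct doubles can collapse to the same value before hashing.

    import torch

    d = torch.tensor(torch.finfo(torch.float64).min, dtype=torch.float64)
    print(d)                # tensor(-1.7977e+308, dtype=torch.float64)
    print(d.float())        # tensor(-inf): the value overflows float32
    print((d / 2).float())  # tensor(-inf): a different double, the same float32 value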

torch/distributions/categorical.py

Lines changed: 4 additions & 1 deletion

@@ -51,6 +51,7 @@ def __init__(self, probs=None, logits=None, validate_args=None):
         else:
             if logits.dim() < 1:
                 raise ValueError("`logits` parameter must be at least one-dimensional.")
+            # Normalize
             self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
         self._param = self.probs if probs is not None else self.logits
         self._num_events = self._param.size()[-1]

@@ -115,7 +116,9 @@ def log_prob(self, value):
         return log_pmf.gather(-1, value).squeeze(-1)

     def entropy(self):
-        p_log_p = self.logits * self.probs
+        min_real = torch.finfo(self.logits.dtype).min
+        logits = torch.clamp(self.logits, min=min_real)
+        p_log_p = logits * self.probs
         return -p_log_p.sum(-1)

     def enumerate_support(self, expand=True):
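A quick usage sketch of the patched entropy(), using my own numbers rather than the p tensor from the test above:

    import torch
    from torch.distributions import Categorical

    logits = torch.tensor([0.0, 0.0, float('-inf')])   # category 2 has probability 0
    dist = Categorical(logits=logits)
    print(dist.probs)      # tensor([0.5000, 0.5000, 0.0000])
    print(dist.entropy())  # tensor(0.6931), i.e. ln 2; before this commit it was nan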
