Skip to content

Commit 61afbbb

Browse files
andreh7ezyang
authored andcommitted
clamping the return value of uniform.cdf() to [0..1] (#7538)
* fix for #7532: clamping the return value of uniform.cdf() to the range [0,1] * removed whitespace around equals to pass flake8 tests * added a test for uniform.cdf() with arguments outside support
1 parent bccb727 commit 61afbbb

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

test/test_distributions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1126,6 +1126,10 @@ def test_uniform(self):
11261126
self.assertEqual(uniform.log_prob(above_high).item(), -float('inf'), allow_inf=True)
11271127
self.assertEqual(uniform.log_prob(below_low).item(), -float('inf'), allow_inf=True)
11281128

1129+
# check cdf computation when value outside range
1130+
self.assertEqual(uniform.cdf(below_low).item(), 0)
1131+
self.assertEqual(uniform.cdf(above_high).item(), 1)
1132+
11291133
set_rng_seed(1)
11301134
self._gradcheck_log_prob(Uniform, (low, high))
11311135
self._gradcheck_log_prob(Uniform, (low, 1.0))

torch/distributions/uniform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def cdf(self, value):
7171
if self._validate_args:
7272
self._validate_sample(value)
7373
result = (value - self.low) / (self.high - self.low)
74-
return result
74+
return result.clamp(min=0, max=1)
7575

7676
def icdf(self, value):
7777
if self._validate_args:

0 commit comments

Comments
 (0)