Skip to content

Commit a43e09c

Browse files
ezyangpytorchmergebot
authored andcommitted
Implement gamma cdf (#89955)
Authored by tillahoffmann originally at #72518 Implements the cumulative distribution function for the gamma distribution. The tests needed a small adjustment to pass because gradients cannot be evaluated with respect to the first argument of the incomplete gamma function (and they're not needed for the test). Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: #89955 Approved by: https://github.com/wconstab, https://github.com/malfet
1 parent 5167108 commit a43e09c

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

test/distributions/test_distributions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2975,6 +2975,9 @@ def test_cdf_log_prob(self):
29752975
# Tests if the differentiation of the CDF gives the PDF at a given value
29762976
for Dist, params in EXAMPLES:
29772977
for i, param in enumerate(params):
2978+
# We do not need grads wrt params here, e.g. shape of gamma distribution.
2979+
param = {key: value.detach() if isinstance(value, torch.Tensor) else value
2980+
for key, value in param.items()}
29782981
dist = Dist(**param)
29792982
samples = dist.sample()
29802983
if not dist.support.is_discrete:

torch/distributions/gamma.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,8 @@ def _natural_params(self):
8686

8787
def _log_normalizer(self, x, y):
8888
return torch.lgamma(x + 1) + (x + 1) * torch.log(-y.reciprocal())
89+
90+
def cdf(self, value):
91+
if self._validate_args:
92+
self._validate_sample(value)
93+
return torch.special.gammainc(self.concentration, self.rate * value)

0 commit comments

Comments
 (0)