Merged
Changes from all commits
18 changes: 18 additions & 0 deletions docs/source/distributions.rst
@@ -131,6 +131,24 @@ Probability distributions - torch.distributions
:undoc-members:
:show-inheritance:

:hidden:`HalfCauchy`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.half_cauchy
.. autoclass:: HalfCauchy
:members:
:undoc-members:
:show-inheritance:

:hidden:`HalfNormal`
~~~~~~~~~~~~~~~~~~~~~~~

.. currentmodule:: torch.distributions.half_normal
.. autoclass:: HalfNormal
:members:
:undoc-members:
:show-inheritance:

:hidden:`Independent`
~~~~~~~~~~~~~~~~~~~~~

136 changes: 130 additions & 6 deletions test/test_distributions.py
@@ -37,6 +37,7 @@
Cauchy, Chi2, Dirichlet, Distribution,
Exponential, ExponentialFamily,
FisherSnedecor, Gamma, Geometric, Gumbel,
HalfCauchy, HalfNormal,
Independent, Laplace, LogisticNormal,
LogNormal, Multinomial, MultivariateNormal,
Normal, OneHotCategorical, Pareto, Poisson,
@@ -180,6 +181,15 @@ def is_all_nan(tensor):
'scale': torch.tensor(torch.randn(1).abs(), requires_grad=True),
},
]),
Example(HalfCauchy, [
{'scale': 1.0},
{'scale': torch.tensor([[1.0], [1.0]])}
]),
Example(HalfNormal, [
{'scale': torch.tensor(torch.randn(5, 5).abs(), requires_grad=True)},
{'scale': torch.tensor(torch.randn(1).abs(), requires_grad=True)},
{'scale': torch.tensor([1e-5, 1e-5], requires_grad=True)}
]),
Example(Independent, [
{
'base_distribution': Normal(torch.randn(2, 3, requires_grad=True),
@@ -464,6 +474,15 @@ def is_all_nan(tensor):
'scale': torch.tensor([1., -1.], requires_grad=True),
},
]),
Example(HalfCauchy, [
{'scale': -1.0},
{'scale': 0.0},
{'scale': torch.tensor([[-0.000001], [1.0]])}
]),
Example(HalfNormal, [
{'scale': torch.tensor([0., 1.], requires_grad=True)},
{'scale': torch.tensor([1., -1.], requires_grad=True)},
]),
Example(Laplace, [
{
'loc': torch.tensor([1., 1.], requires_grad=True),
@@ -613,7 +632,7 @@ def _check_sampler_sampler(self, torch_dist, ref_dist, message, multivariate=Fal
samples.sort(key=lambda x: x[0])
samples = np.array(samples)[:, 1]

# Aggragate into bins filled with roughly zero-mean unit-variance RVs.
# Aggregate into bins filled with roughly zero-mean unit-variance RVs.
num_bins = 10
samples_per_bin = len(samples) // num_bins
bins = samples.reshape((num_bins, samples_per_bin)).mean(axis=1)
@@ -1175,9 +1194,9 @@ def test_cauchy(self):
self.assertEqual(Cauchy(0.0, 1.0).sample((1,)).size(), (1,))

set_rng_seed(1)
self._gradcheck_log_prob(Uniform, (loc, scale))
self._gradcheck_log_prob(Uniform, (loc, 1.0))
self._gradcheck_log_prob(Uniform, (0.0, scale))
self._gradcheck_log_prob(Cauchy, (loc, scale))
self._gradcheck_log_prob(Cauchy, (loc, 1.0))
self._gradcheck_log_prob(Cauchy, (0.0, scale))

state = torch.get_rng_state()
eps = loc.new(loc.size()).cauchy_()
@@ -1189,6 +1208,73 @@
loc.grad.zero_()
scale.grad.zero_()

def test_halfcauchy(self):
scale = torch.ones(5, 5, requires_grad=True)
scale_1d = torch.ones(1, requires_grad=True)
self.assertTrue(is_all_nan(HalfCauchy(scale_1d).mean))
self.assertEqual(HalfCauchy(scale_1d).variance, float('inf'), allow_inf=True)
self.assertEqual(HalfCauchy(scale).sample().size(), (5, 5))
self.assertEqual(HalfCauchy(scale).sample((7,)).size(), (7, 5, 5))
self.assertEqual(HalfCauchy(scale_1d).sample().size(), (1,))
self.assertEqual(HalfCauchy(scale_1d).sample((1,)).size(), (1, 1))
self.assertEqual(HalfCauchy(1.0).sample((1,)).size(), (1,))

set_rng_seed(1)
self._gradcheck_log_prob(HalfCauchy, (scale,))
self._gradcheck_log_prob(HalfCauchy, (1.0,))

state = torch.get_rng_state()
eps = scale.new(scale.size()).cauchy_().abs_()
torch.set_rng_state(state)
c = HalfCauchy(scale).rsample()
c.backward(torch.ones_like(c))
self.assertEqual(scale.grad, eps)
scale.grad.zero_()

def test_halfnormal(self):
std = torch.tensor(torch.randn(5, 5).abs(), requires_grad=True)
std_1d = torch.randn(1, requires_grad=True)
std_delta = torch.tensor([1e-5, 1e-5])
self.assertEqual(HalfNormal(std).sample().size(), (5, 5))
self.assertEqual(HalfNormal(std).sample((7,)).size(), (7, 5, 5))
self.assertEqual(HalfNormal(std_1d).sample((1,)).size(), (1, 1))
self.assertEqual(HalfNormal(std_1d).sample().size(), (1,))
self.assertEqual(HalfNormal(.6).sample((1,)).size(), (1,))
self.assertEqual(HalfNormal(50.0).sample((1,)).size(), (1,))

# sample check for extreme value of std
set_rng_seed(1)
self.assertEqual(HalfNormal(std_delta).sample(sample_shape=(1, 2)),
torch.tensor([[[0.0, 0.0], [0.0, 0.0]]]),
prec=1e-4)

self._gradcheck_log_prob(HalfNormal, (std,))
self._gradcheck_log_prob(HalfNormal, (1.0,))

# check .log_prob() can broadcast.
dist = HalfNormal(torch.ones(2, 1, 4))
log_prob = dist.log_prob(torch.ones(3, 1))
self.assertEqual(log_prob.shape, (2, 3, 4))

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_halfnormal_logprob(self):
std = torch.tensor(torch.randn(5, 1).abs(), requires_grad=True)

def ref_log_prob(idx, x, log_prob):
s = std.view(-1)[idx].detach()
expected = scipy.stats.halfnorm(scale=s).logpdf(x)
self.assertAlmostEqual(log_prob, expected, places=3)

self._check_log_prob(HalfNormal(std), ref_log_prob)

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_halfnormal_sample(self):
set_rng_seed(0) # see Note [Randomized statistical tests]
for std in [0.1, 1.0, 10.0]:
self._check_sampler_sampler(HalfNormal(std),
scipy.stats.halfnorm(scale=std),
'HalfNormal(scale={})'.format(std))

def test_lognormal(self):
mean = torch.randn(5, 5, requires_grad=True)
std = torch.tensor(torch.randn(5, 5).abs(), requires_grad=True)
@@ -2447,6 +2533,32 @@ def test_cauchy_shape_tensor_params(self):
self.assertRaises(ValueError, cauchy.log_prob, self.tensor_sample_2)
self.assertEqual(cauchy.log_prob(torch.ones(2, 1)).size(), torch.Size((2, 2)))

def test_halfcauchy_shape_scalar_params(self):
halfcauchy = HalfCauchy(1)
self.assertEqual(halfcauchy._batch_shape, torch.Size())
self.assertEqual(halfcauchy._event_shape, torch.Size())
self.assertEqual(halfcauchy.sample().size(), torch.Size())
self.assertEqual(halfcauchy.sample(torch.Size((3, 2))).size(),
torch.Size((3, 2)))
self.assertRaises(ValueError, halfcauchy.log_prob, self.scalar_sample)
self.assertEqual(halfcauchy.log_prob(self.tensor_sample_1).size(),
torch.Size((3, 2)))
self.assertEqual(halfcauchy.log_prob(self.tensor_sample_2).size(),
torch.Size((3, 2, 3)))

def test_halfcauchy_shape_tensor_params(self):
halfcauchy = HalfCauchy(torch.tensor([1., 1.]))
self.assertEqual(halfcauchy._batch_shape, torch.Size((2,)))
self.assertEqual(halfcauchy._event_shape, torch.Size(()))
self.assertEqual(halfcauchy.sample().size(), torch.Size((2,)))
self.assertEqual(halfcauchy.sample(torch.Size((3, 2))).size(),
torch.Size((3, 2, 2)))
self.assertEqual(halfcauchy.log_prob(self.tensor_sample_1).size(),
torch.Size((3, 2)))
self.assertRaises(ValueError, halfcauchy.log_prob, self.tensor_sample_2)
self.assertEqual(halfcauchy.log_prob(torch.ones(2, 1)).size(),
torch.Size((2, 2)))

def test_dirichlet_shape(self):
dist = Dirichlet(torch.tensor([[0.6, 0.3], [1.6, 1.3], [2.6, 2.3]]))
self.assertEqual(dist._batch_shape, torch.Size((3,)))
@@ -2647,6 +2759,7 @@ def __init__(self, probs):
exponential = pairwise(Exponential, [1.0, 2.5, 5.0, 10.0])
gamma = pairwise(Gamma, [1.0, 2.5, 1.0, 2.5], [1.5, 1.5, 3.5, 3.5])
gumbel = pairwise(Gumbel, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
halfnormal = pairwise(HalfNormal, [1.0, 2.0, 1.0, 2.0])
laplace = pairwise(Laplace, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
lognormal = pairwise(LogNormal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
normal = pairwise(Normal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
@@ -2698,6 +2811,7 @@ def __init__(self, probs):
(gamma, normal),
(gumbel, gumbel),
(gumbel, normal),
(halfnormal, halfnormal),
(laplace, laplace),
(lognormal, lognormal),
(laplace, normal),
@@ -3153,6 +3267,14 @@ def setUp(self):
Gumbel(random_var, positive_var2),
scipy.stats.gumbel_r(random_var, positive_var2)
),
(
HalfCauchy(positive_var),
scipy.stats.halfcauchy(scale=positive_var)
),
(
HalfNormal(positive_var2),
scipy.stats.halfnorm(scale=positive_var2)
),
(
Laplace(random_var, positive_var2),
scipy.stats.laplace(random_var, positive_var2)
@@ -3198,7 +3320,8 @@ def setUp(self):

def test_mean(self):
for pytorch_dist, scipy_dist in self.distribution_pairs:
if isinstance(pytorch_dist, Cauchy): # Cauchy distribution's mean is nan, skipping check
if isinstance(pytorch_dist, (Cauchy, HalfCauchy)):
# Cauchy and HalfCauchy means are nan, skipping check
continue
elif isinstance(pytorch_dist, MultivariateNormal):
self.assertEqual(pytorch_dist.mean, scipy_dist.mean, allow_inf=True, message=pytorch_dist)
@@ -3207,7 +3330,8 @@ def test_mean(self):

def test_variance_stddev(self):
for pytorch_dist, scipy_dist in self.distribution_pairs:
if isinstance(pytorch_dist, Cauchy): # Cauchy distribution's standard deviation is nan, skipping check
if isinstance(pytorch_dist, (Cauchy, HalfCauchy)):
# Cauchy and HalfCauchy standard deviations are nan, skipping check
continue
elif isinstance(pytorch_dist, (Multinomial, OneHotCategorical)):
self.assertEqual(pytorch_dist.variance, np.diag(scipy_dist.cov()), message=pytorch_dist)
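
The rsample gradient check in test_halfcauchy above relies on the reparameterization y = |eps * scale| with eps ~ Cauchy(0, 1), which gives dy/dscale = |eps| elementwise. A minimal standalone sketch of that identity (assuming, as the test does, that rsample draws its noise via .cauchy_() so the captured RNG state lines up):

import torch
from torch.distributions import HalfCauchy

scale = torch.ones(5, 5, requires_grad=True)

# Capture the same Cauchy noise that rsample() is about to consume.
state = torch.get_rng_state()
eps = torch.empty(5, 5).cauchy_().abs_()
torch.set_rng_state(state)

# y = |0 + eps * scale|, so dy/dscale = eps * sign(eps * scale) = |eps|.
y = HalfCauchy(scale).rsample()
y.backward(torch.ones_like(y))
assert torch.allclose(scale.grad, eps)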
2 changes: 2 additions & 0 deletions torch/distributions/__init__.py
@@ -86,6 +86,8 @@
from .gamma import Gamma
from .geometric import Geometric
from .gumbel import Gumbel
from .half_cauchy import HalfCauchy
from .half_normal import HalfNormal
from .independent import Independent
from .kl import kl_divergence, register_kl
from .laplace import Laplace
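
With these re-exports both classes are importable from the package root rather than from their submodules; a trivial smoke test:

from torch.distributions import HalfCauchy, HalfNormal

print(HalfCauchy(1.0).sample((3,)))  # three non-negative heavy-tailed draws
print(HalfNormal(1.0).sample((3,)))  # three non-negative half-normal draws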
57 changes: 57 additions & 0 deletions torch/distributions/half_cauchy.py
@@ -0,0 +1,57 @@
import math

from torch.distributions import constraints
from torch.distributions.transforms import AbsTransform
from torch.distributions.cauchy import Cauchy
from torch.distributions.transformed_distribution import TransformedDistribution


class HalfCauchy(TransformedDistribution):
r"""
Creates a half-Cauchy distribution parameterized by `scale` where::

X ~ Cauchy(0, scale)
Y = |X| ~ HalfCauchy(scale)

Example::

>>> m = HalfCauchy(torch.tensor([1.0]))
>>> m.sample() # half-cauchy distributed with scale=1
tensor([ 2.3214])

Args:
scale (float or Tensor): scale of the full Cauchy distribution
"""
arg_constraints = {'scale': constraints.positive}
support = constraints.positive
has_rsample = True

def __init__(self, scale, validate_args=None):
super(HalfCauchy, self).__init__(Cauchy(0, scale), AbsTransform(),
validate_args=validate_args)

@property
def scale(self):
return self.base_dist.scale

@property
def mean(self):
return self.base_dist.mean

@property
def variance(self):
return self.base_dist.variance

def log_prob(self, value):
log_prob = self.base_dist.log_prob(value) + math.log(2)
log_prob[value.expand(log_prob.shape) < 0] = -float('inf')
return log_prob

def cdf(self, value):
return 2 * self.base_dist.cdf(value) - 1

def icdf(self, prob):
return self.base_dist.icdf((prob + 1) / 2)

def entropy(self):
return self.base_dist.entropy() - math.log(2)
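
Since AbsTransform folds X ~ Cauchy(0, scale) onto the positive half-line, the overrides above encode log p(y) = log 2 + log p_base(y) and F(y) = 2 * F_base(y) - 1, with icdf as the exact inverse. A quick sketch checking those identities (not part of the PR):

import math
import torch
from torch.distributions import Cauchy, HalfCauchy

scale = torch.tensor([1.0])
half = HalfCauchy(scale)
base = Cauchy(0.0, scale)

y = torch.tensor([0.5, 1.0, 2.3])
# Folding doubles the density on y >= 0 ...
assert torch.allclose(half.log_prob(y), base.log_prob(y) + math.log(2))
# ... and P(|X| <= y) = F(y) - F(-y) = 2 * F(y) - 1.
assert torch.allclose(half.cdf(y), 2 * base.cdf(y) - 1)
# icdf round-trips through cdf.
p = torch.tensor([0.1, 0.5, 0.9])
assert torch.allclose(half.cdf(half.icdf(p)), p)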
57 changes: 57 additions & 0 deletions torch/distributions/half_normal.py
@@ -0,0 +1,57 @@
import math

from torch.distributions import constraints
from torch.distributions.transforms import AbsTransform
from torch.distributions.normal import Normal
from torch.distributions.transformed_distribution import TransformedDistribution


class HalfNormal(TransformedDistribution):
r"""
Creates a half-normal distribution parameterized by `scale` where::

X ~ Normal(0, scale)
Y = |X| ~ HalfNormal(scale)

Example::

>>> m = HalfNormal(torch.tensor([1.0]))
>>> m.sample() # half-normal distributed with scale=1
tensor([ 0.1046])

Args:
scale (float or Tensor): scale of the full Normal distribution
"""
arg_constraints = {'scale': constraints.positive}
support = constraints.positive
has_rsample = True

def __init__(self, scale, validate_args=None):
super(HalfNormal, self).__init__(Normal(0, scale), AbsTransform(),
validate_args=validate_args)

@property
def scale(self):
return self.base_dist.scale

@property
def mean(self):
return self.scale * math.sqrt(2 / math.pi)

@property
def variance(self):
return self.scale.pow(2) * (1 - 2 / math.pi)

def log_prob(self, value):
log_prob = self.base_dist.log_prob(value) + math.log(2)
log_prob[value.expand(log_prob.shape) < 0] = -float('inf')
return log_prob

def cdf(self, value):
return 2 * self.base_dist.cdf(value) - 1

def icdf(self, prob):
return self.base_dist.icdf((prob + 1) / 2)

def entropy(self):
return self.base_dist.entropy() - math.log(2)
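
HalfNormal overrides mean and variance with the folded-normal closed forms E[Y] = scale * sqrt(2/pi) and Var[Y] = scale**2 * (1 - 2/pi), which agree with scipy.stats.halfnorm, the reference the new tests compare against. A sketch of that cross-check (assumes SciPy is installed):

import scipy.stats
import torch
from torch.distributions import HalfNormal

scale = torch.tensor([0.5, 1.0, 2.0])
dist = HalfNormal(scale)

for s, m, v in zip(scale.tolist(), dist.mean.tolist(), dist.variance.tolist()):
    ref = scipy.stats.halfnorm(scale=s)
    assert abs(m - ref.mean()) < 1e-5  # s * sqrt(2 / pi)
    assert abs(v - ref.var()) < 1e-5   # s**2 * (1 - 2 / pi)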
7 changes: 6 additions & 1 deletion torch/distributions/kl.py
@@ -15,8 +15,8 @@
from .gamma import Gamma
from .geometric import Geometric
from .gumbel import Gumbel
from .half_normal import HalfNormal
from .laplace import Laplace
from .log_normal import LogNormal
from .logistic_normal import LogisticNormal
from .multivariate_normal import MultivariateNormal, _batch_mahalanobis, _batch_diag, _batch_inverse
from .normal import Normal
@@ -273,6 +273,11 @@ def _kl_geometric_geometric(p, q):
return -p.entropy() - torch.log1p(-q.probs) / p.probs - q.logits


@register_kl(HalfNormal, HalfNormal)
def _kl_halfnormal_halfnormal(p, q):
return _kl_normal_normal(p.base_dist, q.base_dist)


@register_kl(Laplace, Laplace)
def _kl_laplace_laplace(p, q):
# From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf
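
The HalfNormal/HalfNormal rule above can delegate to _kl_normal_normal because AbsTransform scales both densities by the same factor of 2 on the half-line, so the log-density ratio, and hence the KL integral over the folded support, is unchanged for the zero-mean base distributions. A quick numerical sketch:

import math
import torch
from torch.distributions import HalfNormal, Normal, kl_divergence

p, q = HalfNormal(torch.tensor(1.0)), HalfNormal(torch.tensor(2.0))
kl = kl_divergence(p, q)

# Matches the KL of the zero-mean base Normals ...
assert torch.allclose(kl, kl_divergence(Normal(0.0, 1.0), Normal(0.0, 2.0)))
# ... whose closed form is log(s_q / s_p) + s_p**2 / (2 * s_q**2) - 1/2.
assert abs(kl.item() - (math.log(2.0) + 1.0 / 8.0 - 0.5)) < 1e-6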