Skip to content

Commit 0aff3cc

Browse files
fritzofacebook-github-bot
authored andcommitted
Fix broadcasting bug in StudentT (#12148)
Summary: This fixes a broadcasting error with the `StudentT` distribution - [x] added a regression test - [x] strengthened parameter broadcasting tests Pull Request resolved: #12148 Differential Revision: D10099226 Pulled By: soumith fbshipit-source-id: 0c5eb14180d158f8fff28ceb9e7cd3471c2bb803
1 parent b0248df commit 0aff3cc

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

test/test_distributions.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2378,12 +2378,20 @@ def test_valid_parameter_broadcasting(self):
23782378
(1, 2)),
23792379
(StudentT(df=torch.tensor([1.]), scale=torch.tensor([[1.]])),
23802380
(1, 1)),
2381+
(StudentT(df=1., loc=torch.zeros(5, 1), scale=torch.ones(3)),
2382+
(5, 3)),
23812383
]
23822384

23832385
for dist, expected_size in valid_examples:
2384-
dist_sample_size = dist.sample().size()
2385-
self.assertEqual(dist_sample_size, expected_size,
2386-
'actual size: {} != expected size: {}'.format(dist_sample_size, expected_size))
2386+
actual_size = dist.sample().size()
2387+
self.assertEqual(actual_size, expected_size,
2388+
'{} actual size: {} != expected size: {}'.format(dist, actual_size, expected_size))
2389+
2390+
sample_shape = torch.Size((2,))
2391+
expected_size = sample_shape + expected_size
2392+
actual_size = dist.sample(sample_shape).size()
2393+
self.assertEqual(actual_size, expected_size,
2394+
'{} actual size: {} != expected size: {}'.format(dist, actual_size, expected_size))
23872395

23882396
def test_invalid_parameter_broadcasting(self):
23892397
# invalid broadcasting cases; should throw error

torch/distributions/studentT.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ def variance(self):
4141

4242
def __init__(self, df, loc=0., scale=1., validate_args=None):
4343
self.df, self.loc, self.scale = broadcast_all(df, loc, scale)
44-
self._chi2 = Chi2(df)
45-
batch_shape = torch.Size() if isinstance(df, Number) else self.df.size()
44+
self._chi2 = Chi2(self.df)
45+
batch_shape = self.df.size()
4646
super(StudentT, self).__init__(batch_shape, validate_args=validate_args)
4747

4848
def expand(self, batch_shape, _instance=None):

0 commit comments

Comments
 (0)