Binomial with vectorized total count #148
Conversation
fritzo left a comment:
Math looks good. Could you update the docstring?
torch/distributions/binomial.py (outdated):

    def sample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape) + (self.total_count,)
        total_count = self._get_homogeneous_count()
Gosh, I guess since we're already summing Bernoullis we could draw inhomogeneous samples via:

    def sample(self, sample_shape=torch.Size()):
        with torch.no_grad():
            max_count = self.total_count.max().item()
            shape = self._extended_shape(sample_shape) + (max_count,)
            # Draw max_count Bernoulli trials for every batch element.
            bernoullis = torch.bernoulli(self.probs.unsqueeze(-1).expand(shape))
            if self.total_count.min() != max_count:
                # Zero out trials beyond each element's own total_count.
                arange = torch.arange(max_count, out=self.total_count.new_empty(max_count))
                bernoullis *= (arange < self.total_count.unsqueeze(-1)).type_as(bernoullis)
            return bernoullis.sum(dim=-1)

WDYT?
I was thinking about it. :) It may lead to us creating some large intermediate tensors, but it should work fine! Let me make that change and then we can get @apaszke's opinion.
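For a sense of the intermediate cost, here is a minimal sketch with made-up sizes (not from the PR); the shapes follow the suggestion above:

    import torch

    # Hypothetical sizes: a batch of 10,000 Binomials whose largest
    # total_count is 1,000. The approach above materializes one
    # Bernoulli draw per (batch element, trial) before summing.
    probs = torch.full((10000,), 0.5)
    max_count = 1000
    bernoullis = torch.bernoulli(probs.unsqueeze(-1).expand(10000, max_count))
    print(bernoullis.numel())  # 10,000,000 intermediate draws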
    bernoullis = torch.bernoulli(self.probs.unsqueeze(-1).expand(shape))
    if self.total_count.min() != max_count:
        arange = torch.arange(max_count, out=self.total_count.new_empty(max_count))
        bernoullis *= (arange < self.total_count.unsqueeze(-1)).type_as(bernoullis)
Super neat trick that you came up with, @fritzo. :)
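To see the masking trick in isolation, a standalone sketch (values made up for illustration):

    import torch

    total_count = torch.tensor([2., 4.])       # inhomogeneous counts
    max_count = int(total_count.max().item())  # 4
    arange = torch.arange(max_count)           # tensor([0, 1, 2, 3])
    # Broadcasting (2, 1) against (4,) keeps only each row's own trials.
    mask = arange < total_count.unsqueeze(-1)
    print(mask)
    # tensor([[ True,  True, False, False],
    #         [ True,  True,  True,  True]])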
@fritzo - could you take another look? I added a few tests, and used your suggestion for vectorized sampling over N.
fritzo left a comment:
Looks good. We could use a couple more tests as commented below.
torch/distributions/binomial.py (outdated):

    def sample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape) + (self.total_count,)
        max_count = int(self.total_count.max().item())
I'd move this into the no_grad() context just to be safe.
        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)

    def test_binomial_shape_vectorized_n(self):
It would also be nice to draw a ton of samples and assert that (sample <= total_count).all(), just to make sure we got the masking correct. Maybe also numerically test the mean of Binomial(total_count=torch.tensor([0,1,2,5,10]), probs=0.5)?
Good idea. We can test it at the boundary itself, i.e. with p=1; that should be sufficient to validate the masking. I did that locally, but I should add it as a test. I'll add the other test for the mean too.
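A minimal sketch of the two suggested checks (sample counts and tolerances are made up, not from the PR):

    import torch
    from torch.distributions import Binomial

    total_count = torch.tensor([0., 1., 2., 5., 10.])

    # Boundary check: with p = 1 every trial succeeds, so each sample
    # must equal its own total_count exactly; any masking bug shows up here.
    bin_p1 = Binomial(total_count, torch.tensor(1.))
    assert (bin_p1.sample() == total_count).all()

    # Mean check: with p = 0.5 the empirical mean over many samples
    # should approach total_count / 2.
    bin_half = Binomial(total_count, torch.tensor(0.5))
    samples = bin_half.sample((100000,))
    assert torch.allclose(samples.mean(dim=0), total_count * 0.5, atol=0.05)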
Looks great! This should also serve as a template for how to support inhomogeneous Multinomial.
    def test_binomial_vectorized_count(self):
        set_rng_seed(0)
        total_count = torch.tensor([[4., 7.], [3., 8.]])
        bin0 = Binomial(total_count, torch.tensor(1.))
With a probs this high, all you need is a single sample 😉
Yup, this is doing an exact match! I am drawing way more samples for bin1 with p=0.5 to make sure that the invariance holds.
cc @1Reinier

Sent upstream to pytorch#6720.
Based on a suggestion from @fritzo prompted by pyro-ppl/pyro#675, this attempts to restrict the requirement of a homogeneous total_count to only the sample and enumerate_support methods of the Binomial distribution, so that we can score samples with a vectorized total_count. Happy to hear suggestions if there is a non-hacky way to implement sampling for a vectorized total_count (numpy has this!).
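A short usage sketch of what this enables (values are illustrative, assuming the API as proposed in this PR):

    import torch
    from torch.distributions import Binomial

    # One distribution object with a different number of trials per element.
    d = Binomial(torch.tensor([10., 20., 30.]), probs=torch.tensor(0.5))

    # log_prob scores each value against its own total_count...
    print(d.log_prob(torch.tensor([5., 10., 15.])))

    # ...and sample() draws per-element counts bounded by each total_count.
    print(d.sample())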