
Conversation

@neerajprad

Based on a suggestion from @fritzo prompted by pyro-ppl/pyro#675, this attempts to restrict the same-total_count requirement to only the sample and enumerate_support methods of the Binomial distribution, so that we can score samples with a vectorized total_count. Happy to hear suggestions if there is a non-hacky way to implement sampling with a vectorized total_count (NumPy supports this!).

@neerajprad neerajprad requested review from alicanb and fritzo April 18, 2018 06:42

@fritzo fritzo left a comment


Math looks good. Could you update the docstring?


def sample(self, sample_shape=torch.Size()):
    shape = self._extended_shape(sample_shape) + (self.total_count,)
    total_count = self._get_homogeneous_count()

Gosh I guess since we're already summing Bernoullis we could draw inhomogeneous samples via

def sample(self, sample_shape=torch.Size()):
    with torch.no_grad():
        max_count = self.total_count.max().item()
        shape = self._extended_shape(sample_shape) + (max_count,)
        bernoullis = torch.bernoulli(self.probs.unsqueeze(-1).expand(shape))
        if self.total_count.min() != max_count:
            arange = torch.arange(max_count, out=self.total_count.new_empty(max_count))
            bernoullis *= (arange < self.total_count.unsqueeze(-1)).type_as(bernoullis)
        return bernoullis.sum(dim=-1)

WDYT?
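A standalone sketch of the masking trick above, with hypothetical tensor values standing in for the distribution's attributes: trials beyond each element's own total_count are zeroed before summing, so every element effectively sees only its own number of Bernoulli draws.

```python
import torch

torch.manual_seed(0)
total_count = torch.tensor([2., 5., 3.])   # hypothetical inhomogeneous counts
probs = torch.tensor([0.5, 0.5, 0.5])

max_count = int(total_count.max().item())  # 5
# Draw max_count Bernoulli trials per element.
bernoullis = torch.bernoulli(probs.unsqueeze(-1).expand(3, max_count))
# Mask out trials beyond each element's own total_count, then sum.
arange = torch.arange(max_count, dtype=total_count.dtype)
mask = (arange < total_count.unsqueeze(-1)).type_as(bernoullis)
samples = (bernoullis * mask).sum(dim=-1)  # shape: (3,)
```

Each resulting sample is guaranteed to lie in [0, total_count] for its element, at the cost of a (batch, max_count) intermediate tensor.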

Author


I was thinking about it. :) It may lead to us creating some large intermediate tensors, but it should work fine! Let me make that change and then we can get @apaszke's opinion.

bernoullis = torch.bernoulli(self.probs.unsqueeze(-1).expand(shape))
if self.total_count.min() != max_count:
    arange = torch.arange(max_count, out=self.total_count.new_empty(max_count))
    bernoullis *= (arange < self.total_count.unsqueeze(-1)).type_as(bernoullis)
Author


Super neat trick that you came up with, @fritzo. :)

@neerajprad
Author

neerajprad commented Apr 18, 2018

@fritzo - could you take another look? I added a few tests, and used your suggestion for vectorized sampling over N.

@neerajprad neerajprad changed the title Allowing for vectorized counts in Binomial log prob Binomial with vectorized total count Apr 18, 2018

@fritzo fritzo left a comment


Looks good. We could use a couple more tests as commented below.


def sample(self, sample_shape=torch.Size()):
    shape = self._extended_shape(sample_shape) + (self.total_count,)
    max_count = int(self.total_count.max().item())

I'd move this into the no_grad() context just to be safe.

        self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3, 2)))
        self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)

    def test_binomial_shape_vectorized_n(self):

It would also be nice to draw a ton of samples and assert that (sample <= total_count).all(), just to make sure we got the masking correct. Maybe also numerically test the mean of Binomial(total_count=torch.tensor([0,1,2,5,10]), probs=0.5)?
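Roughly, the suggested checks could look like this (the sample count and tolerance are illustrative, and a modern `torch.distributions.Binomial` with tensor `total_count` is assumed):

```python
import torch
from torch.distributions import Binomial

torch.manual_seed(0)
total_count = torch.tensor([0., 1., 2., 5., 10.])
dist = Binomial(total_count, probs=torch.tensor(0.5))
samples = dist.sample((10000,))          # shape: (10000, 5)
# Masking check: no sample may exceed its own total_count.
assert bool((samples <= total_count).all())
# Mean check: the empirical mean should approach n * p = total_count / 2.
assert torch.allclose(samples.mean(0), total_count * 0.5, atol=0.2)
```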

Author


Good idea. We can test it at the boundary itself, i.e. with p=1; that should be sufficient to validate the masking. I did that locally, but I should add it as a test. Will add the other test for the mean too.


@fritzo fritzo left a comment


Looks great! This should also serve as a template for how to support inhomogeneous Multinomial.

    def test_binomial_vectorized_count(self):
        set_rng_seed(0)
        total_count = torch.tensor([[4., 7.], [3., 8.]])
        bin0 = Binomial(total_count, torch.tensor(1.))

With a probs this high, all you need is a single sample 😉

Author


Yup, this is doing an exact match! I am drawing many more samples for bin1 with p=0.5 to make sure that the invariant holds.
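A minimal standalone version of that boundary check, using the same total_count values as the test snippet above (a Binomial that accepts a tensor total_count is assumed): with p=1 every trial succeeds, so a single draw must equal total_count exactly.

```python
import torch
from torch.distributions import Binomial

total_count = torch.tensor([[4., 7.], [3., 8.]])
# At p=1 the sample is deterministic, so one draw suffices
# for an exact comparison against total_count.
bin0 = Binomial(total_count, torch.tensor(1.))
sample = bin0.sample()
assert torch.equal(sample, total_count)
```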

@neerajprad
Author

cc. @1Reinier

@neerajprad
Author

Sent upstream to pytorch#6720.

@neerajprad neerajprad closed this Apr 18, 2018
