Skip to content

Conversation

@joh4n
Copy link
Contributor

@joh4n joh4n commented Aug 15, 2018

Support broadcasting in _kl_categorical_categorical

this makes it possible to do:

import torch.distributions as dist
import torch
p_dist = dist.Categorical(torch.ones(1,10))
q_dist = dist.Categorical(torch.ones(100,10))
dist.kl_divergence(p_dist, q_dist)

@vishwakftw
Copy link
Contributor

Could you please add a test in test_distributions.py?

@joh4n
Copy link
Contributor Author

joh4n commented Aug 15, 2018

I added a basic test for it

(binomial30, binomial30),
(binomial_vectorized_count, binomial_vectorized_count),
(categorical, categorical),
(Categorical(torch.ones(1, 10)), Categorical(torch.ones(3, 10))),

This comment was marked as off-topic.

This comment was marked as off-topic.

Copy link
Contributor

@vishwakftw vishwakftw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thank you for fixing this.

Copy link
Contributor

@facebook-github-bot facebook-github-bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

soumith is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants