-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Disallow scalar parameters in Dirichlet and Categorical #11589
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
torch/distributions/dirichlet.py
Outdated
| def __init__(self, concentration, validate_args=None): | ||
| if concentration.dim() < 1: | ||
| raise ValueError("`concentration` parameter must be at least one-dimensional.") | ||
| self.concentration, = broadcast_all(concentration) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
Nice catch @neerajprad ! I think we should also add this to |
Good point; I'll address it in this PR itself.
Hmm..I didn't quite know that we had a real_vector constraint as well! :) It seems like we are only using it in |
|
IIRC the purpose of the |
|
Thanks for providing that context, @fritzo. I’ll remove the last commit. I don’t have a strong opinion on its removal, and will leave it until further discussion. |
2aec45e to
45e314d
Compare
|
I think this should be good to merge, pending any further comments. |
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
soumith is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
This adds a small check in
DirichletandCategorical__init__methods to ensure that scalar parameters are not admissible.Motivation
Currently,
Dirichletthrows no error when provided with a scalar parameter, but if weexpanda scalar instance, it inherits the empty event shape from the original instance and gives unexpected results.The alternative to this check is to promote
event_shapeto betorch.Size((1,))if the original instance was a scalar, but that seems to add a bit more complexity (and changes the behavior ofexpandin that it would affect theevent_shapeas well as thebatch_shapenow). Does this seem reasonable? cc. @alicanb, @fritzo.Additionally, based on review comments, this removes
real_vectorconstraint. This was only being used inMultivariateNormal, but I am happy to revert this if we want to keep it around for backwards compatibility.