-
Notifications
You must be signed in to change notification settings - Fork 26.3k
[distributions] Make more distributions jittable #10321
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
cc @apaszke |
|
Yes, |
torch/distributions/categorical.py
Outdated
| def log_prob(self, value): | ||
| if self._validate_args: | ||
| self._validate_sample(value) | ||
| value_shape = torch._C._infer_size(value.size(), self.batch_shape) if self.batch_shape else value.size() |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/distributions/categorical.py
Outdated
| if self._validate_args: | ||
| self._validate_sample(value) | ||
| value_shape = torch._C._infer_size(value.size(), self.batch_shape) if self.batch_shape else value.size() | ||
| param_shape = value_shape + (self._num_events,) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
zou3519
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't know much about distributions, but the semantics of the change look fine to me
|
Build failed due to timeout. Could someone restart the failing build? |
| return log_pmf.gather(-1, value.unsqueeze(-1).long()).squeeze(-1) | ||
| value = value.long().unsqueeze(-1) | ||
| value, log_pmf = torch.broadcast_tensors(value, self.logits) | ||
| value = value[..., :1] |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
|
@apaszke I've added fixes for |
|
@apaszke Another look at this? |
apaszke
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My head hurts a bit when I try to figure out what's happening, because I don't remember which value is supposed to have what dimensions (or where does batch_dim or event_dim go), but I don't see anything alarming.
| .format(loc.shape, cov_factor.shape, cov_diag.shape)) | ||
| self.loc = loc_[..., 0] | ||
| self.cov_diag = cov_diag_[..., 0] | ||
| batch_shape, event_shape = self.loc.shape[:-1], self.loc.shape[-1:] |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
neerajprad
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for making these distributions JITable.
| .format(loc.shape, cov_factor.shape, cov_diag.shape)) | ||
| self.loc = loc_[..., 0] | ||
| self.cov_diag = cov_diag_[..., 0] | ||
| batch_shape, event_shape = self.loc.shape[:-1], self.loc.shape[-1:] |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
apaszke has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
|
I've rebased this, if the tests still pass I will merge the changes edit: nevermind, looks like this was merged already but not closed (fcfb1c1#diff-55f664a63e28ba5ef39157ffbe5eea2c). Closing this PR because it is already merged |
|
Closing b/c already merged |
This uses @zou3519's new
torch.broadcast_tensors()#10075 to makeCategorical.log_prob()and the*Normal.__init__()methods jittable. Previously.log_prob()was failing due to calls totorch._C.infer_size()with errors likeAfter this change I'm able to jit many more of Pyro's tests.
Questions for reviewers
This assumes that
broadcast_tensorswill create avaluewith stride 0 in the rightmost dimension; that way we never create a huge tensor. Is this assumption valid?