Conversation

@fritzo
Collaborator

@fritzo fritzo commented Aug 7, 2018

This uses @zou3519's new torch.broadcast_tensors() #10075 to make Categorical.log_prob() and the *Normal.__init__() methods jittable. Previously .log_prob() failed due to calls to torch._C._infer_size() with errors like

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
>       value_shape = torch._C._infer_size(value.size(), self.batch_shape) if self.batch_shape else value.size()
E       RuntimeError: expected int at position 0, but got: Tensor

After this change I'm able to jit many more of Pyro's tests.
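A minimal sketch of the pattern this PR adopts (names here are illustrative, not the actual distribution code): instead of computing a broadcast shape with `torch._C._infer_size`, which the tracer rejects because it expects ints rather than Tensors, broadcast the tensors themselves.

```python
import torch

# Illustrative sketch: broadcast actual tensors instead of inferring a shape.
value = torch.zeros(4)       # e.g. a sample
batch = torch.zeros(3, 4)    # e.g. batched distribution parameters

# broadcast_tensors works on Tensors directly, so it traces under the JIT.
value_b, batch_b = torch.broadcast_tensors(value, batch)
assert value_b.shape == batch_b.shape == (3, 4)
```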

Questions for reviewers

This assumes that broadcast_tensors will create a value with stride 0 in the rightmost dimension; that way we never create a huge tensor. Is this assumption valid?


@fritzo fritzo changed the title Make Categorical.log_prob() jittable [distributions] Make Categorical.log_prob() jittable Aug 7, 2018
@fritzo
Collaborator Author

fritzo commented Aug 7, 2018

cc @apaszke

@zou3519
Contributor

zou3519 commented Aug 7, 2018

Yes, broadcast_tensors will create a value with stride 0 in the rightmost dimension because broadcasting expands tensors (as opposed to using repeat, which creates new tensors).
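A quick check of the stride-0 claim (an illustrative snippet, not part of the PR): broadcasting is implemented with `expand()`, which returns a view, so the broadcast result allocates no new storage along the expanded dimension.

```python
import torch

value = torch.tensor([0, 2, 1]).unsqueeze(-1)   # shape (3, 1)
logits = torch.randn(3, 5)                      # batch of 3, 5 categories

value_b, logits_b = torch.broadcast_tensors(value, logits)
assert value_b.shape == (3, 5)
# Rightmost stride is 0: the expanded dimension is a view, not a copy,
# so no huge tensor is materialized.
assert value_b.stride() == (1, 0)
```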

    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        value_shape = torch._C._infer_size(value.size(), self.batch_shape) if self.batch_shape else value.size()


        if self._validate_args:
            self._validate_sample(value)
        value_shape = torch._C._infer_size(value.size(), self.batch_shape) if self.batch_shape else value.size()
        param_shape = value_shape + (self._num_events,)


Contributor

@zou3519 zou3519 left a comment


I don't know much about distributions, but the semantics of the change look fine to me

@fritzo
Collaborator Author

fritzo commented Aug 8, 2018

Build failed due to timeout. Could someone restart the failing build?

        return log_pmf.gather(-1, value.unsqueeze(-1).long()).squeeze(-1)
        value = value.long().unsqueeze(-1)
        value, log_pmf = torch.broadcast_tensors(value, self.logits)
        value = value[..., :1]
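A stand-alone sketch of this broadcast-then-gather pattern, assuming `log_pmf` holds normalized log-probabilities (names are illustrative; the real code lives in `Categorical.log_prob`):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(3, 5)        # batch_shape (3,), 5 events
value = torch.tensor([0, 2, 1])   # one category index per batch entry

log_pmf = F.log_softmax(logits, dim=-1)
value_ = value.long().unsqueeze(-1)                     # (3, 1)
# Broadcasting expands value_ along the event dim with stride 0 (no copy).
value_, log_pmf = torch.broadcast_tensors(value_, log_pmf)
value_ = value_[..., :1]                                # keep one index column
log_prob = log_pmf.gather(-1, value_).squeeze(-1)       # shape (3,)
assert log_prob.shape == (3,)
```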


@fritzo fritzo changed the title [distributions] Make Categorical.log_prob() jittable [distributions] Make more distributions jittable Aug 9, 2018
@fritzo
Collaborator Author

fritzo commented Aug 9, 2018

@apaszke I've added fixes for MultivariateNormal and LowRankMultivariateNormal to this PR. I hope this helps clarify the pattern of using torch.broadcast_tensors to broadcast tensors of different target shapes.
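An illustrative sketch of that pattern (not the actual MultivariateNormal code): pad `loc` with a trailing singleton dimension so its event dim lines up against the covariance's trailing `(d, d)` dims, broadcast, then drop the padding.

```python
import torch

d = 3
loc = torch.zeros(d)                  # event_shape (3,)
cov = torch.eye(d).expand(5, d, d)    # batch_shape (5,)

loc_ = loc.unsqueeze(-1)              # (3, 1) aligns with cov's (d, d)
loc_b, cov_b = torch.broadcast_tensors(loc_, cov)
loc_full = loc_b[..., 0]              # (5, 3): batch shapes now agree
assert loc_full.shape == (5, 3)
assert cov_b.shape == (5, 3, 3)
```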

@soumith soumith added the "ready for review (this tag is deprecated)" label Aug 14, 2018
@ssnl
Collaborator

ssnl commented Aug 14, 2018

@apaszke Another look at this?

Contributor

@apaszke apaszke left a comment


My head hurts a bit when I try to figure out what's happening, because I don't remember which value is supposed to have what dimensions (or where batch_dim or event_dim go), but I don't see anything alarming.

                             .format(loc.shape, cov_factor.shape, cov_diag.shape))
        self.loc = loc_[..., 0]
        self.cov_diag = cov_diag_[..., 0]
        batch_shape, event_shape = self.loc.shape[:-1], self.loc.shape[-1:]


Contributor

@neerajprad neerajprad left a comment


Thanks for making these distributions JITable.


Contributor

@facebook-github-bot facebook-github-bot left a comment


apaszke has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@zou3519
Contributor

zou3519 commented Sep 25, 2018

I've rebased this; if the tests still pass, I will merge the changes.

edit: nevermind, looks like this was merged already but not closed (fcfb1c1#diff-55f664a63e28ba5ef39157ffbe5eea2c). Closing this PR because it is already merged

@zou3519
Contributor

zou3519 commented Sep 25, 2018

Closing b/c already merged

Labels

open source, ready for review (this tag is deprecated)

8 participants