-
Notifications
You must be signed in to change notification settings - Fork 26.3k
Implement torch.broadcast_tensors #10075
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Nice! |
|
yes let's match numpy and do |
|
Should it be named "broadcast_arrays" or "broadcast_tensors"? @fmassa and I were thinking broadcast_tensors because we call our data "tensors" but numpy calls their data "arrays". I'll also try to implement varargs for this via a python wrapper function, it shouldn't be too bad. |
|
Nice! The additional behavior of # in torch/distributions/utils.py
def broadcast_all(*values):
"""docstring"""
if not all(map(torch.is_tensor, values)):
# promote floats to tensors
new_tensor = torch.tensor
for value in values:
if torch.is_tensor(value):
new_tensor = value.new_tensor
break
values = [v if torch.is_tensor(v) else new_tensor(v) for v in values]
return torch.broadcast_arrays(*values)This would also be a great test to see that distributions are compatible with the new version 😄 |
|
Or even put the scalar broadcasting in |
|
@vadimkantorov I think there was a discussion around using |
This exposes expand_outplace to python. Fixes pytorch#8076. Fixes pytorch#10041. I didn't name it torch.broadcast because numpy.broadcast does something slightly different (it returns an object with the correct shape information). Test Plan: new test_torch, test_autograd tests.
- s/broadcast_all/broadcast_tensors/ - broadcast_tensors now takes varargs
2909f12 to
c5e418b
Compare
|
This should be good for review, despite the hanging tests. I updated the following: |
torch/distributions/utils.py
Outdated
| return values | ||
| if not all(map(torch.is_tensor, values)): | ||
| # promote numbers to tensors of dtype torch.get_default_dtype() | ||
| def default_promotion(v): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
torch/distributions/utils.py
Outdated
| scalar_idxs = [i for i in range(len(values)) if isinstance(values[i], Number)] | ||
| tensor_idxs = [i for i in range(len(values)) if values[i].__class__.__name__ == 'Tensor'] | ||
| if len(scalar_idxs) + len(tensor_idxs) != len(values): | ||
| if not all(map(lambda v: torch.is_tensor(v) or isinstance(v, Number), values)): |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
facebook-github-bot
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: This exposes expand_outplace to python. Fixes #8076. Fixes #10041. I didn't name it torch.broadcast because numpy.broadcast does something slightly different (it returns an object with the correct shape information). Pull Request resolved: pytorch/pytorch#10075 Differential Revision: D9125816 Pulled By: zou3519 fbshipit-source-id: ebe17c8bb54a73ec84b8f76ce14aff3e9c56f4d1
|
This change is great. How can I start to use it? It does not appear to have landed in master yet. EDIT Sorry, I was pointing to a fork 😊 |
|
It should be on master, the following works for me on a latest checkout: |
Summary: This exposes expand_outplace to python. Fixes pytorch#8076. Fixes pytorch#10041. I didn't name it torch.broadcast because numpy.broadcast does something slightly different (it returns an object with the correct shape information). Pull Request resolved: pytorch#10075 Differential Revision: D9125816 Pulled By: zou3519 fbshipit-source-id: ebe17c8bb54a73ec84b8f76ce14aff3e9c56f4d1
Summary: This uses zou3519's new `torch.broadcast_tensors()` #10075 to make `Categorical.log_prob()` and the `*Normal.__init__()` methods jittable. Previously `.log_prob()` was failing due to calls to `torch._C.infer_size()` with errors like ``` def log_prob(self, value): if self._validate_args: self._validate_sample(value) > value_shape = torch._C._infer_size(value.size(), self.batch_shape) if self.batch_shape else value.size() E RuntimeError: expected int at position 0, but got: Tensor ``` After this change I'm able to jit many more of Pyro's tests. Reviewed By: ezyang Differential Revision: D9477487 Pulled By: apaszke fbshipit-source-id: 5f39b29c6b8fa606ad30b02fefe2dfb618e883d6
Summary: This uses zou3519's new `torch.broadcast_tensors()` pytorch#10075 to make `Categorical.log_prob()` and the `*Normal.__init__()` methods jittable. Previously `.log_prob()` was failing due to calls to `torch._C.infer_size()` with errors like ``` def log_prob(self, value): if self._validate_args: self._validate_sample(value) > value_shape = torch._C._infer_size(value.size(), self.batch_shape) if self.batch_shape else value.size() E RuntimeError: expected int at position 0, but got: Tensor ``` After this change I'm able to jit many more of Pyro's tests. Reviewed By: ezyang Differential Revision: D9477487 Pulled By: apaszke fbshipit-source-id: 5f39b29c6b8fa606ad30b02fefe2dfb618e883d6
This exposes expand_outplace to python. Fixes #8076. Fixes #10041.
I didn't name it torch.broadcast because numpy.broadcast does something
slightly different (it returns an object with the correct shape
information).
Test Plan: new test_torch, test_autograd tests.
cc @soumith @fritzo