Conversation

@zou3519 (Contributor) commented Jul 31, 2018

This exposes expand_outplace to Python. Fixes #8076. Fixes #10041.

I didn't name it torch.broadcast because numpy.broadcast does something
slightly different (it returns an object with the correct shape
information).
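
(For reference, a quick sketch of the difference between the two numpy APIs; illustrative only, not part of the PR:)

```python
import numpy as np

x = np.random.randn(2, 2)
y = np.random.randn(1)

b = np.broadcast(x, y)   # an object carrying the broadcast shape, no data
print(b.shape)           # (2, 2)

xb, yb = np.broadcast_arrays(x, y)  # actual broadcasted views of the inputs
print(xb.shape, yb.shape)           # (2, 2) (2, 2)
```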

Test Plan: new test_torch, test_autograd tests.

cc @soumith @fritzo

@fmassa (Member) commented Jul 31, 2018

Nice!
I believe the equivalent numpy function is numpy.broadcast_arrays; should we aim to keep the same name?

@soumith (Contributor) commented Jul 31, 2018

Yes, let's match numpy and do broadcast_arrays.

@zou3519 (Contributor, Author) commented Jul 31, 2018

Should it be named "broadcast_arrays" or "broadcast_tensors"? @fmassa and I were thinking broadcast_tensors because we call our data "tensors" but numpy calls their data "arrays".

I'll also try to implement varargs for this via a Python wrapper function; it shouldn't be too bad. (A rough sketch of the varargs behavior is below.)
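
(For illustration, a pure-Python sketch of the varargs semantics; the PR itself implements this in C++ via expand_outplace, so this helper is illustrative only:)

```python
import torch

def broadcast_tensors(*tensors):
    # Compute the common broadcast shape, right-aligning dimensions
    # as in numpy, then expand every input to that shape.
    ndim = max(t.dim() for t in tensors)
    shape = [1] * ndim
    for t in tensors:
        for i, s in enumerate(t.shape):
            j = ndim - t.dim() + i  # right-align this dimension
            if s != 1:
                if shape[j] not in (1, s):
                    raise RuntimeError('shape mismatch')
                shape[j] = s
    return tuple(t.expand(*shape) for t in tensors)

a, b = broadcast_tensors(torch.randn(2, 2), torch.randn(1))
print(a.shape, b.shape)  # torch.Size([2, 2]) torch.Size([2, 2])
```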

@fritzo (Collaborator) commented Jul 31, 2018

Nice! The additional behavior of torch.distributions.utils.broadcast_all() is to handle Python floats. I believe we can update broadcast_all() to use your broadcast_arrays() as follows (in either this PR or a follow-up PR):

# in torch/distributions/utils.py
def broadcast_all(*values):
    """docstring"""
    if not all(map(torch.is_tensor, values)):
        # Promote floats to tensors. If any input is already a tensor,
        # use its new_tensor so promoted values share its dtype/device.
        new_tensor = torch.tensor
        for value in values:
            if torch.is_tensor(value):
                new_tensor = value.new_tensor
                break
        values = [v if torch.is_tensor(v) else new_tensor(v) for v in values]
    return torch.broadcast_arrays(*values)

This would also be a great test to see that distributions are compatible with the new version 😄
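
(As a usage sketch, the existing helper already behaves this way with mixed number/tensor inputs; a minimal example:)

```python
import torch
from torch.distributions.utils import broadcast_all

x = torch.randn(3, 1)
a, b = broadcast_all(x, 2.0)  # the float is promoted, then broadcast with x
print(a.shape, b.shape)       # torch.Size([3, 1]) torch.Size([3, 1])
```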

@vadimkantorov (Contributor) commented Jul 31, 2018

Or even put the scalar broadcasting in torch.broadcast_all/arrays? :) If torch.where used it, that would automatically solve torch.where not accepting scalars (see the sketch below).
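
(A sketch of the idea; where_scalars is a hypothetical wrapper, not an existing API:)

```python
import torch

def where_scalars(cond, a, b):
    # Promote Python numbers to 0-dim tensors of the default dtype,
    # then let torch.where broadcast everything to a common shape.
    def to_tensor(v):
        return v if torch.is_tensor(v) else torch.tensor(
            v, dtype=torch.get_default_dtype())
    return torch.where(cond, to_tensor(a), to_tensor(b))

x = torch.randn(4)
print(where_scalars(x > 0, x, 0.0))  # zeros out the negative entries
```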

@fmassa (Member) commented Jul 31, 2018

@vadimkantorov I think there was a discussion about using torch.as_tensor (or something equivalent) at the beginning of every PyTorch function, so that it also works for numbers and numpy arrays.
I'm not sure what the status of that is. @gchanan did we decide on something there?

zou3519 added 2 commits July 31, 2018 11:14

  • This exposes expand_outplace to Python. Fixes pytorch#8076. Fixes pytorch#10041. I didn't name it torch.broadcast because numpy.broadcast does something slightly different (it returns an object with the correct shape information). Test Plan: new test_torch, test_autograd tests.
  • s/broadcast_all/broadcast_tensors/; broadcast_tensors now takes varargs.
zou3519 force-pushed the pytorch-broadcast branch from 2909f12 to c5e418b on July 31, 2018 at 18:14
zou3519 changed the title from "Implement torch.broadcast_all" to "Implement torch.broadcast_tensors" on Jul 31, 2018
@zou3519 (Contributor, Author) commented Aug 1, 2018

This should be good for review despite the hanging tests.

I updated the following:

  • Renamed torch.broadcast_all to torch.broadcast_tensors (@fmassa).
  • Changed torch.distributions.utils.broadcast_all() to use torch.broadcast_tensors (@fritzo); all the distribution tests pass, so it looks good to go (see the sketch below).
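
(For reference, a sketch of roughly what the updated helper might look like after this change; pieced together from the discussion above as an assumption, so details may differ from the actual diff:)

```python
from numbers import Number
import torch

def broadcast_all(*values):
    # Reject anything that is neither a Tensor nor a Python number.
    if not all(torch.is_tensor(v) or isinstance(v, Number) for v in values):
        raise ValueError('Input arguments must all be instances of '
                         'Number or torch.Tensor.')
    if not all(map(torch.is_tensor, values)):
        # Promote numbers to tensors of dtype torch.get_default_dtype();
        # if a tensor is present, match its dtype/device via new_tensor.
        def new_tensor(v):
            return torch.tensor(v, dtype=torch.get_default_dtype())
        for value in values:
            if torch.is_tensor(value):
                new_tensor = value.new_tensor
                break
        values = [v if torch.is_tensor(v) else new_tensor(v) for v in values]
    return torch.broadcast_tensors(*values)
```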

@facebook-github-bot (Contributor) commented:

zou3519 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

zou3519 deleted the pytorch-broadcast branch on August 2, 2018 at 02:33
zdevito pushed a commit to zdevito/ATen that referenced this pull request Aug 2, 2018
Summary:
This exposes expand_outplace to Python. Fixes #8076. Fixes #10041.

I didn't name it torch.broadcast because numpy.broadcast does something
slightly different (it returns an object with the correct shape
information).
Pull Request resolved: pytorch/pytorch#10075

Differential Revision: D9125816

Pulled By: zou3519

fbshipit-source-id: ebe17c8bb54a73ec84b8f76ce14aff3e9c56f4d1
@fritzo (Collaborator) commented Aug 7, 2018

This change is great. How can I start to use it? It does not appear to have landed in master yet.

EDIT: Sorry, I was pointing to a fork 😊

@zou3519 (Contributor, Author) commented Aug 7, 2018

It should be on master; the following works for me on a latest checkout:

```python
import torch

x = torch.randn(2, 2)
y = torch.randn(1)
torch.broadcast_tensors(x, y)  # returns two tensors, each of shape (2, 2)
```

goodlux pushed a commit to goodlux/pytorch that referenced this pull request Aug 15, 2018 (same summary as the zdevito/ATen commit above).
facebook-github-bot pushed a commit that referenced this pull request Aug 23, 2018
Summary:
This uses zou3519's new `torch.broadcast_tensors()` #10075 to make `Categorical.log_prob()` and the `*Normal.__init__()` methods jittable. Previously `.log_prob()` was failing due to calls to `torch._C.infer_size()` with errors like
```
    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
>       value_shape = torch._C._infer_size(value.size(), self.batch_shape) if self.batch_shape else value.size()
E       RuntimeError: expected int at position 0, but got: Tensor
```
After this change I'm able to jit many more of Pyro's tests.

Reviewed By: ezyang

Differential Revision: D9477487

Pulled By: apaszke

fbshipit-source-id: 5f39b29c6b8fa606ad30b02fefe2dfb618e883d6
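
(Aside: a minimal sketch of the jit-friendly pattern this commit describes, deriving the broadcast shape from the tensors themselves instead of calling torch._C._infer_size on their sizes; illustrative only:)

```python
import torch

value = torch.randn(5, 1)
batch = torch.randn(1, 3)

# Trace-unfriendly: _infer_size takes torch.Size arguments, which the
# tracer turns into Tensors, producing "expected int ... but got: Tensor".
# shape = torch._C._infer_size(value.size(), batch.size())

# Trace-friendly: broadcast the tensors and read off the result's shape.
shape = torch.broadcast_tensors(value, batch)[0].size()
print(shape)  # torch.Size([5, 3])
```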
PenghuiCheng pushed a commit to PenghuiCheng/pytorch that referenced this pull request Sep 11, 2018 (same summary as the commit above).
ezyang added the merged label Jun 26, 2019
Merging this pull request may close these issues: [JIT] Support torch.distributions.utils.broadcast_all() · [pytorch] [feature request] Add torch.broadcast (e.g. for using with torch.stack)