Commit 6b338c8

zou3519 authored and facebook-github-bot committed
Implement torch.broadcast_tensors (#10075)
Summary: This exposes expand_outplace to python. Fixes #8076. Fixes #10041.

I didn't name it torch.broadcast because numpy.broadcast does something slightly different (it returns an object with the correct shape information).

Pull Request resolved: #10075
Differential Revision: D9125816
Pulled By: zou3519
fbshipit-source-id: ebe17c8bb54a73ec84b8f76ce14aff3e9c56f4d1
1 parent 191482f commit 6b338c8
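
Note: the naming contrast mentioned in the summary is easy to see side by side. The REPL sketch below is illustrative (it assumes NumPy is installed and a build containing this commit), not part of the commit itself:

>>> import numpy as np
>>> np.broadcast(np.empty((2, 1)), np.empty((1, 3))).shape  # an object carrying shape info
(2, 3)
>>> [t.size() for t in torch.broadcast_tensors(torch.empty(2, 1), torch.empty(1, 3))]  # actual expanded tensors
[torch.Size([2, 3]), torch.Size([2, 3])]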

6 files changed (+65, −31 lines)

aten/src/ATen/native/TensorShape.cpp

Lines changed: 4 additions & 0 deletions
@@ -12,6 +12,10 @@
 namespace at {
 namespace native {
 
+std::vector<Tensor> broadcast_tensors(TensorList tensors) {
+  return expand_outplace(tensors);
+}
+
 static void check_cat_no_zero_dim(TensorList tensors) {
   for(size_t i = 0; i < tensors.size(); ++i) {
     auto& t = tensors[i];
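
For readers unfamiliar with ATen internals: expand_outplace computes the common broadcast shape of its inputs and returns each tensor expanded to that shape, without copying data. A rough Python-level sketch of the same idea, using torch._C._infer_size (the internal shape-inference helper that the old _broadcast_shape in torch/distributions/utils.py below also used); this is an illustration, not the ATen code:

import torch

def _broadcast_sketch(*tensors):
    # fold the common broadcast shape over all inputs, then expand each one
    shape = torch.Size()
    for t in tensors:
        shape = torch._C._infer_size(t.size(), shape)
    return [t.expand(shape) for t in tensors]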

aten/src/ATen/native/native_functions.yaml

Lines changed: 3 additions & 0 deletions
@@ -249,6 +249,9 @@
 - func: blackman_window(int64_t window_length, bool periodic, TensorOptions options={}) -> Tensor
   variants: function
 
+- func: broadcast_tensors(TensorList tensors) -> TensorList
+  variants: function
+
 - func: cat(TensorList tensors, int64_t dim=0) -> Tensor
   variants: function
 
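
A brief note on the schema: `variants: function` asks the code generator to emit only a free function, so the op is reachable as torch.broadcast_tensors but not as a Tensor method. A hedged sanity check on a build with this commit:

>>> hasattr(torch, 'broadcast_tensors')
True
>>> hasattr(torch.Tensor, 'broadcast_tensors')
False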

test/test_autograd.py

Lines changed: 10 additions & 0 deletions
@@ -1850,6 +1850,16 @@ def backward(ctx, grad_output):
         out.sum().backward()
         self.assertEqual(x.grad.data, y_data)
 
+    def test_broadcast_tensors(self):
+        f_args_variable = (torch.randn(3, requires_grad=True),
+                           torch.randn(1, 2, 1, requires_grad=True),
+                           torch.randn(1, 1, requires_grad=True),
+                           torch.randn(5, 1, 1, requires_grad=True))
+        f_args_tensor = deepcopy(unpack_variables(f_args_variable))
+        run_functional_checks(self, "test_broadcast_tensors", "broadcast",
+                              lambda a, b, c, d: torch.broadcast_tensors(a, b, c, d),
+                              True, f_args_variable, f_args_tensor)
+
     def test_cat(self):
         f_args_variable = (torch.randn(1, S, S, requires_grad=True),
                            torch.randn(2, S, S, requires_grad=True),
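
The functional checks above verify gradients numerically. Conceptually, the backward of the implicit expand sums the incoming gradient over the broadcast dimensions, so each input's grad keeps that input's original shape. A small illustrative session (the shapes are what matter; values are random):

>>> x = torch.randn(1, 3, requires_grad=True)
>>> y = torch.randn(2, 1, requires_grad=True)
>>> a, b = torch.broadcast_tensors(x, y)
>>> (a * b).sum().backward()
>>> x.grad.size(), y.grad.size()
(torch.Size([1, 3]), torch.Size([2, 1]))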

test/test_torch.py

Lines changed: 11 additions & 0 deletions
@@ -2939,6 +2939,17 @@ def test_broadcast_empty(self):
                          torch.randn(0, 7, 0, 6, 5, 0, 1) + torch.randn(1, 1, 5, 1, 7))
         self.assertRaises(RuntimeError, lambda: torch.randn(7, 0) + torch.randn(2, 1))
 
+    def test_broadcast_tensors(self):
+        x0 = torch.randn(2, 1, 3)
+        x1 = torch.randn(3)
+        x2 = torch.randn(3, 1)
+        expected_size = (2, 3, 3)
+
+        y0, y1, y2 = torch.broadcast_tensors(x0, x1, x2)
+        self.assertTrue(y0.size() == expected_size)
+        self.assertTrue(y1.size() == expected_size)
+        self.assertTrue(y2.size() == expected_size)
+
     @staticmethod
     def _test_contiguous(self, cast):
         x = cast(torch.randn(1, 16, 5, 5))
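
Why the expected size is (2, 3, 3): shapes are right-aligned, missing leading dimensions are treated as 1, and each aligned dimension takes the larger size:

    (2, 1, 3)
    (      3)   -> treated as (1, 1, 3)
    (   3, 1)   -> treated as (1, 3, 1)
    ---------
    (2, 3, 3)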

torch/distributions/utils.py

Lines changed: 14 additions & 31 deletions
@@ -32,30 +32,19 @@ def _finfo(tensor):
     return _FINFO[tensor.storage_type()]
 
 
-def _broadcast_shape(shapes):
-    r"""
-    Given a list of tensor sizes, returns the size of the resulting broadcasted
-    tensor.
-
-    Args:
-        shapes (list of torch.Size): list of tensor sizes
-    """
-    shape = torch.Size()
-    for s in shapes:
-        shape = torch._C._infer_size(s, shape)
-    return shape
+# promote numbers to tensors of dtype torch.get_default_dtype()
+def _default_promotion(v):
+    return torch.tensor(v, dtype=torch.get_default_dtype())
 
 
 def broadcast_all(*values):
     r"""
     Given a list of values (possibly containing numbers), returns a list where each
     value is broadcasted based on the following rules:
-      - `torch.*Tensor` instances are broadcasted as per the `broadcasting rules
-        <http://pytorch.org/docs/master/notes/broadcasting.html>`_
+      - `torch.*Tensor` instances are broadcasted as per :ref:`_broadcasting-semantics`.
       - numbers.Number instances (scalars) are upcast to tensors having
         the same size and type as the first tensor passed to `values`. If all the
-        values are scalars, then they are upcasted to Tensors having size
-        `(1,)`.
+        values are scalars, then they are upcasted to scalar Tensors.
 
     Args:
         values (list of `numbers.Number` or `torch.*Tensor`)
@@ -64,22 +53,16 @@ def broadcast_all(*values):
         ValueError: if any of the values is not a `numbers.Number` or
             `torch.*Tensor` instance
     """
-    values = list(values)
-    scalar_idxs = [i for i in range(len(values)) if isinstance(values[i], Number)]
-    tensor_idxs = [i for i in range(len(values)) if values[i].__class__.__name__ == 'Tensor']
-    if len(scalar_idxs) + len(tensor_idxs) != len(values):
+    if not all(torch.is_tensor(v) or isinstance(v, Number) for v in values):
         raise ValueError('Input arguments must all be instances of numbers.Number or torch.tensor.')
-    if tensor_idxs:
-        broadcast_shape = _broadcast_shape([values[i].size() for i in tensor_idxs])
-        for idx in tensor_idxs:
-            values[idx] = values[idx].expand(broadcast_shape)
-        template = values[tensor_idxs[0]]
-        for idx in scalar_idxs:
-            values[idx] = template.new(template.size()).fill_(values[idx])
-    else:
-        for idx in scalar_idxs:
-            values[idx] = torch.tensor(float(values[idx]))
-    return values
+    if not all(map(torch.is_tensor, values)):
+        new_tensor = _default_promotion
+        for value in values:
+            if torch.is_tensor(value):
+                new_tensor = value.new_tensor
+                break
+        values = [v if torch.is_tensor(v) else new_tensor(v) for v in values]
+    return torch.broadcast_tensors(*values)
 
 
 def _sum_rightmost(value, dim):
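
Two consequences of this rewrite are worth noting: numbers are promoted with the first tensor's new_tensor, so they inherit its dtype and device, and all-scalar inputs now yield 0-dim tensors rather than size (1,) tensors. An illustrative session, assuming a build with this commit:

>>> from torch.distributions.utils import broadcast_all
>>> loc, scale = broadcast_all(torch.zeros(2, 1), 1.)
>>> loc.size(), scale.size()
(torch.Size([2, 1]), torch.Size([2, 1]))
>>> [t.size() for t in broadcast_all(1., 2.)]
[torch.Size([]), torch.Size([])]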

torch/functional.py

Lines changed: 23 additions & 0 deletions
@@ -10,6 +10,7 @@
     'argmin',
     'btrifact',
     'btriunpack',
+    'broadcast_tensors',
     'isfinite',
     'isinf',
     'isnan',
@@ -19,6 +20,28 @@
 ]
 
 
+def broadcast_tensors(*tensors):
+    r"""broadcast_tensors(*tensors) -> List of Tensors
+
+    Broadcasts the given tensors according to :ref:`_broadcasting-semantics`.
+
+    Args:
+        *tensors: any number of tensors of the same type
+
+    Example::
+
+        >>> x = torch.arange(3).view(1, 3)
+        >>> y = torch.arange(2).view(2, 1)
+        >>> a, b = torch.broadcast_tensors(x, y)
+        >>> a.size()
+        torch.Size([2, 3])
+        >>> a
+        tensor([[0, 1, 2],
+                [0, 1, 2]])
+    """
+    return torch._C._VariableFunctions.broadcast_tensors(tensors)
+
+
 def split(tensor, split_size_or_sections, dim=0):
     r"""Splits the tensor into chunks.
 
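
One caveat that follows from the expand-based implementation: the outputs are views that may map several elements onto the same storage (an expanded dimension has stride 0), so in-place writes to them can alias. Continuing the docstring example above (illustrative):

>>> b
tensor([[0, 0, 0],
        [1, 1, 1]])
>>> b.stride()  # dim 1 was expanded from size 1, hence stride 0
(1, 0)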
