@@ -32,30 +32,19 @@ def _finfo(tensor):
     return _FINFO[tensor.storage_type()]
 
 
-def _broadcast_shape(shapes):
-    r"""
-    Given a list of tensor sizes, returns the size of the resulting broadcasted
-    tensor.
-
-    Args:
-        shapes (list of torch.Size): list of tensor sizes
-    """
-    shape = torch.Size()
-    for s in shapes:
-        shape = torch._C._infer_size(s, shape)
-    return shape
+# promote numbers to tensors of dtype torch.get_default_dtype()
+def _default_promotion(v):
+    return torch.tensor(v, dtype=torch.get_default_dtype())
 
 
 def broadcast_all(*values):
     r"""
     Given a list of values (possibly containing numbers), returns a list where each
     value is broadcasted based on the following rules:
-      - `torch.*Tensor` instances are broadcasted as per the `broadcasting rules
-        <http://pytorch.org/docs/master/notes/broadcasting.html>`_
+      - `torch.*Tensor` instances are broadcasted as per :ref:`_broadcasting-semantics`.
       - numbers.Number instances (scalars) are upcast to tensors having
         the same size and type as the first tensor passed to `values`. If all the
-        values are scalars, then they are upcasted to Tensors having size
-        `(1,)`.
+        values are scalars, then they are upcasted to scalar Tensors.
 
     Args:
         values (list of `numbers.Number` or `torch.*Tensor`)
@@ -64,22 +53,16 @@ def broadcast_all(*values):
         ValueError: if any of the values is not a `numbers.Number` or
             `torch.*Tensor` instance
     """
-    values = list(values)
-    scalar_idxs = [i for i in range(len(values)) if isinstance(values[i], Number)]
-    tensor_idxs = [i for i in range(len(values)) if values[i].__class__.__name__ == 'Tensor']
-    if len(scalar_idxs) + len(tensor_idxs) != len(values):
+    if not all(torch.is_tensor(v) or isinstance(v, Number) for v in values):
         raise ValueError('Input arguments must all be instances of numbers.Number or torch.tensor.')
-    if tensor_idxs:
-        broadcast_shape = _broadcast_shape([values[i].size() for i in tensor_idxs])
-        for idx in tensor_idxs:
-            values[idx] = values[idx].expand(broadcast_shape)
-        template = values[tensor_idxs[0]]
-        for idx in scalar_idxs:
-            values[idx] = template.new(template.size()).fill_(values[idx])
-    else:
-        for idx in scalar_idxs:
-            values[idx] = torch.tensor(float(values[idx]))
-    return values
+    if not all(map(torch.is_tensor, values)):
+        new_tensor = _default_promotion
+        for value in values:
+            if torch.is_tensor(value):
+                new_tensor = value.new_tensor
+                break
+        values = [v if torch.is_tensor(v) else new_tensor(v) for v in values]
+    return torch.broadcast_tensors(*values)
 
 
 def _sum_rightmost(value, dim):
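
A minimal usage sketch of the rewritten `broadcast_all`, illustrating the two promotion paths the diff introduces (the tensor shapes, example values, and printed results below are illustrative assumptions, not taken from the patch):

```python
import torch
from torch.distributions.utils import broadcast_all

x = torch.zeros(3, 4, dtype=torch.float64)

# A Python number is promoted with x.new_tensor, so it inherits x's dtype
# (and device); both arguments are then broadcast to a common shape.
a, b = broadcast_all(x, 2.5)
print(a.shape, b.shape, b.dtype)  # torch.Size([3, 4]) torch.Size([3, 4]) torch.float64

# With only scalars, each is promoted to a 0-dim tensor of the default dtype
# (float32 unless changed), instead of the old shape-(1,) tensors.
p, q = broadcast_all(0.3, 0.7)
print(p.shape, p.dtype)  # torch.Size([]) torch.float32
```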