@@ -172,15 +172,17 @@ Tensor& addr_out(Tensor &result, const Tensor& self, const Tensor& vec1, const T
 }

 Tensor dot(const Tensor& self, const Tensor& tensor) {
-  if (self.dim() != 1) {
-    AT_ERROR("Expected argument self to have 1 dimension, but has %d", self.dim());
-  }
-  if (tensor.dim() != 1) {
-    AT_ERROR("Expected argument tensor to have 1 dimension, but has %d", tensor.dim());
-  }
+  check_1d(self, "self", "dot");
+  check_1d(tensor, "tensor", "dot");
   return self._dot(tensor);
 }

+Tensor& dot_out(Tensor& result, const Tensor& self, const Tensor& tensor) {
+  result.resize_({});
+  // dispatching through type ensures we don't allow mismatched types.
+  return self.type().fill_(result, self.dot(tensor));
+}
+
 /*
 Matrix product of two Tensors.
 The behavior depends on the dimensionality of the Tensors as follows:
@@ -200,18 +202,21 @@ The behavior depends on the dimensionality of the Tensors as follows:
   must be broadcastable). For example, if tensor1 is a (j x 1 x n x m) Tensor
   and tensor2 is a (k x m x p) Tensor, the returned tensor will be an (j x k x n x p) Tensor.
 */
-Tensor matmul(const Tensor & tensor1, const Tensor & tensor2) {
+Tensor matmul(at::optional<Tensor> out_opt, const Tensor& tensor1, const Tensor& tensor2) {
   auto dim_tensor1 = tensor1.dim();
   auto dim_tensor2 = tensor2.dim();
+  auto has_out = out_opt.has_value();
+  Tensor out = out_opt.value_or(Tensor());

   if (dim_tensor1 == 1 && dim_tensor2 == 1) {
-    return tensor1.dot(tensor2);
+    return has_out ? at::native::dot_out(out, tensor1, tensor2) : tensor1.dot(tensor2);
   } else if (dim_tensor1 == 2 && dim_tensor2 == 1) {
-    return tensor1.mv(tensor2);
+    return has_out ? at::native::mv_out(out, tensor1, tensor2) : tensor1.mv(tensor2);
   } else if (dim_tensor1 == 1 && dim_tensor2 == 2) {
-    return tensor1.unsqueeze(0).mm(tensor2).squeeze_(0);
+    return has_out ? at::native::mm_out(out, tensor1.unsqueeze(0), tensor2).squeeze_(0)
+                   : tensor1.unsqueeze(0).mm(tensor2).squeeze_(0);
   } else if (dim_tensor1 == 2 && dim_tensor2 == 2) {
-    return tensor1.mm(tensor2);
+    return has_out ? at::native::mm_out(out, tensor1, tensor2) : tensor1.mm(tensor2);
   } else if (dim_tensor1 >= 3 && (dim_tensor2 == 1 || dim_tensor2 == 2)) {
     // optimization: use mm instead of bmm by folding tensor1's batch into
     // its leading matrix dimension.
@@ -227,7 +232,9 @@ Tensor matmul(const Tensor & tensor1, const Tensor & tensor2) {

     // fold the batch into the first dimension
     Tensor t1 = tensor1.contiguous().view({-1, size1[size1.size() - 1]});
-    return at::_unsafe_view(t1.mm(t2), output_size);
+    Tensor output = has_out ? at::_unsafe_view(at::mm_out(out, t1, t2), output_size)
+                            : at::_unsafe_view(t1.mm(t2), output_size);
+    return has_out ? out.set_(output) : output;
   } else if ((dim_tensor1 >= 1 && dim_tensor2 >= 1) && (dim_tensor1 >= 3 || dim_tensor2 >= 3)) {
     // We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list);
     // we track m1 vs m2 separately even though they must match for nicer error messages
@@ -260,8 +267,6 @@ Tensor matmul(const Tensor & tensor1, const Tensor & tensor2) {
     Tensor tensor1_expanded = tensor1.expand(tensor1_expand_size).contiguous().view(tensor1_bmm_view);
     Tensor tensor2_expanded = tensor2.expand(tensor2_expand_size).contiguous().view(tensor2_bmm_view);

-    Tensor output = tensor1_expanded.bmm(tensor2_expanded);
-
     // reshape batches back into result
     std::vector<int64_t> output_shape(expand_batch_portion);
     if (dim_tensor1 > 1) {
@@ -270,13 +275,26 @@ Tensor matmul(const Tensor & tensor1, const Tensor & tensor2) {
     if (dim_tensor2 > 1) {
       output_shape.push_back(p);
     }
-    return at::_unsafe_view(output, output_shape);
+
+    Tensor output = has_out ? at::_unsafe_view(at::bmm_out(out, tensor1_expanded, tensor2_expanded), output_shape)
+                            : at::_unsafe_view(tensor1_expanded.bmm(tensor2_expanded), output_shape);
+
+    return has_out ? out.set_(output) : output;
   }

   AT_ERROR("both arguments to matmul need to be at least 1D, but they are %dD and %dD",
            dim_tensor1, dim_tensor2);

 }

+Tensor matmul(const Tensor & tensor1, const Tensor & tensor2) {
+  return at::native::matmul(at::nullopt, tensor1, tensor2);
+}
+
+Tensor& matmul_out(Tensor &result, const Tensor & tensor1, const Tensor & tensor2) {
+  at::native::matmul(at::optional<Tensor>(result), tensor1, tensor2);
+  return result;
+}
+
 }
 }
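The dot checks are folded into a check_1d helper whose definition is not part of the hunks shown above. As a rough idea of what the helper factors out, here is a minimal sketch, assuming it reports through AT_ERROR like the checks it replaces; the real definition elsewhere in the commit may differ:

static void check_1d(const Tensor& t, const char* arg_name, const char* fn_name) {
  // Hypothetical sketch only; the actual helper is defined elsewhere in the commit.
  if (t.dim() != 1) {
    AT_ERROR("Expected argument %s to %s to have 1 dimension, but has %d",
             arg_name, fn_name, t.dim());
  }
}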
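With dot_out and the out-aware matmul in place, callers can route the product into a preallocated tensor instead of allocating a fresh result. A rough usage sketch from the C++ side; the at::rand/at::empty factory calls and the generated at::matmul_out binding are assumptions about the surrounding ATen API rather than part of this diff:

#include <ATen/ATen.h>

int main() {
  at::Tensor a = at::rand({2, 3, 4});
  at::Tensor b = at::rand({4, 5});
  at::Tensor result = at::empty({2, 3, 5});
  // Writes the batched (2 x 3 x 5) product into result; equivalent in value
  // to at::matmul(a, b), but reuses the caller-provided tensor.
  at::matmul_out(result, a, b);
  return 0;
}

In the folded-mm and bmm branches the reshaped product comes out of at::_unsafe_view, so the overload hands it to the caller's tensor via out.set_(output) rather than copying the data.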