Commit a08091a

Implement matmul_out and dot_out. (#6961)
* Implement matmul_out and dot_out.
* Fix autograd by only calling _out variants if we have an out ourselves.
* Disallow mismatched types in dot_out.
* Make sure out variant doesn't have a method.
* Do proper type conversion.
1 parent 4949394 commit a08091a
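
As a quick usage sketch of what the new out= variants enable at the Python level (tensor names and shapes below are illustrative, not taken from the commit; dot rejects an out tensor of a different type, per the "Disallow mismatched types in dot_out" note):

import torch

# dot with a preallocated 0-dim result tensor (must match the input type)
v1, v2 = torch.randn(10), torch.randn(10)
res = torch.randn(())
torch.dot(v1, v2, out=res)

# matmul with a preallocated result tensor
a, b = torch.randn(3, 4, 5), torch.randn(3, 5, 6)
out = torch.zeros(3, 4, 6)
torch.matmul(a, b, out=out)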

3 files changed: 49 additions & 15 deletions

aten/src/ATen/native/LinearAlgebra.cpp
Lines changed: 33 additions & 15 deletions

@@ -172,15 +172,17 @@ Tensor& addr_out(Tensor &result, const Tensor& self, const Tensor& vec1, const T
 }
 
 Tensor dot(const Tensor& self, const Tensor& tensor) {
-  if (self.dim() != 1) {
-    AT_ERROR("Expected argument self to have 1 dimension, but has %d", self.dim());
-  }
-  if (tensor.dim() != 1) {
-    AT_ERROR("Expected argument tensor to have 1 dimension, but has %d", tensor.dim());
-  }
+  check_1d(self, "self", "dot");
+  check_1d(tensor, "tensor", "dot");
   return self._dot(tensor);
 }
 
+Tensor& dot_out(Tensor& result, const Tensor& self, const Tensor& tensor) {
+  result.resize_({});
+  // dispatching through type ensures we don't allow mismatched types.
+  return self.type().fill_(result, self.dot(tensor));
+}
+
 /*
 Matrix product of two Tensors.
 The behavior depends on the dimensionality of the Tensors as follows:
@@ -200,18 +202,21 @@ The behavior depends on the dimensionality of the Tensors as follows:
   must be broadcastable). For example, if tensor1 is a (j x 1 x n x m) Tensor
   and tensor2 is a (k x m x p) Tensor, the returned tensor will be an (j x k x n x p) Tensor.
 */
-Tensor matmul(const Tensor & tensor1, const Tensor & tensor2) {
+Tensor matmul(at::optional<Tensor> out_opt, const Tensor& tensor1, const Tensor& tensor2) {
   auto dim_tensor1 = tensor1.dim();
   auto dim_tensor2 = tensor2.dim();
+  auto has_out = out_opt.has_value();
+  Tensor out = out_opt.value_or(Tensor());
 
   if (dim_tensor1 == 1 && dim_tensor2 == 1) {
-    return tensor1.dot(tensor2);
+    return has_out ? at::native::dot_out(out, tensor1, tensor2) : tensor1.dot(tensor2);
   } else if (dim_tensor1 == 2 && dim_tensor2 == 1) {
-    return tensor1.mv(tensor2);
+    return has_out ? at::native::mv_out(out, tensor1, tensor2) : tensor1.mv(tensor2);
   } else if (dim_tensor1 == 1 && dim_tensor2 == 2) {
-    return tensor1.unsqueeze(0).mm(tensor2).squeeze_(0);
+    return has_out ? at::native::mm_out(out, tensor1.unsqueeze(0), tensor2).squeeze_(0)
+                   : tensor1.unsqueeze(0).mm(tensor2).squeeze_(0);
   } else if (dim_tensor1 == 2 && dim_tensor2 == 2) {
-    return tensor1.mm(tensor2);
+    return has_out ? at::native::mm_out(out, tensor1, tensor2) : tensor1.mm(tensor2);
   } else if (dim_tensor1 >= 3 && (dim_tensor2 == 1 || dim_tensor2 == 2)) {
     // optimization: use mm instead of bmm by folding tensor1's batch into
     // its leading matrix dimension.
@@ -227,7 +232,9 @@ Tensor matmul(const Tensor & tensor1, const Tensor & tensor2) {
 
     // fold the batch into the first dimension
     Tensor t1 = tensor1.contiguous().view({-1, size1[size1.size() - 1]});
-    return at::_unsafe_view(t1.mm(t2), output_size);
+    Tensor output = has_out ? at::_unsafe_view(at::mm_out(out, t1, t2), output_size)
+                            : at::_unsafe_view(t1.mm(t2), output_size);
+    return has_out ? out.set_(output) : output;
   } else if ((dim_tensor1 >= 1 && dim_tensor2 >= 1) && (dim_tensor1 >= 3 || dim_tensor2 >= 3)) {
     // We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list);
     // we track m1 vs m2 separately even though they must match for nicer error messages
@@ -260,8 +267,6 @@ Tensor matmul(const Tensor & tensor1, const Tensor & tensor2) {
     Tensor tensor1_expanded = tensor1.expand(tensor1_expand_size).contiguous().view(tensor1_bmm_view);
     Tensor tensor2_expanded = tensor2.expand(tensor2_expand_size).contiguous().view(tensor2_bmm_view);
 
-    Tensor output = tensor1_expanded.bmm(tensor2_expanded);
-
     // reshape batches back into result
     std::vector<int64_t> output_shape(expand_batch_portion);
     if (dim_tensor1 > 1) {
@@ -270,13 +275,26 @@ Tensor matmul(const Tensor & tensor1, const Tensor & tensor2) {
     if (dim_tensor2 > 1) {
       output_shape.push_back(p);
     }
-    return at::_unsafe_view(output, output_shape);
+
+    Tensor output = has_out ? at::_unsafe_view(at::bmm_out(out, tensor1_expanded, tensor2_expanded), output_shape)
+                            : at::_unsafe_view(tensor1_expanded.bmm(tensor2_expanded), output_shape);
+
+    return has_out ? out.set_(output) : output;
   }
 
   AT_ERROR("both arguments to matmul need to be at least 1D, but they are %dD and %dD",
            dim_tensor1, dim_tensor2);
 
 }
 
+Tensor matmul(const Tensor & tensor1, const Tensor & tensor2) {
+  return at::native::matmul(at::nullopt, tensor1, tensor2);
+}
+
+Tensor& matmul_out(Tensor &result, const Tensor & tensor1, const Tensor & tensor2) {
+  at::native::matmul(at::optional<Tensor>(result), tensor1, tensor2);
+  return result;
+}
+
 }
 }

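For reference, a small Python sketch (not part of the commit; the check helper and shapes are made up for illustration) of the dimensionality cases the dispatch in matmul distinguishes, each routed through the new out= path as well:

import torch

def check(t1, t2):
    expected = torch.matmul(t1, t2)    # regular path
    out = torch.zeros_like(expected)   # preallocated result of the right shape and type
    torch.matmul(t1, t2, out=out)      # new out= path added by this commit
    assert torch.allclose(out, expected)

check(torch.randn(5), torch.randn(5))               # 1D x 1D -> dot_out, 0-dim result
check(torch.randn(4, 5), torch.randn(5))            # 2D x 1D -> mv_out
check(torch.randn(5), torch.randn(5, 6))            # 1D x 2D -> unsqueeze + mm_out + squeeze
check(torch.randn(4, 5), torch.randn(5, 6))         # 2D x 2D -> mm_out
check(torch.randn(3, 4, 5), torch.randn(5, 6))      # >=3D x 2D -> batch folded into mm_out
check(torch.randn(3, 4, 5), torch.randn(3, 5, 6))   # >=3D batched -> bmm_out
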
aten/src/ATen/native/native_functions.yaml
Lines changed: 6 additions & 0 deletions

@@ -281,6 +281,9 @@
 
 - func: dot(Tensor self, Tensor tensor) -> Tensor
 
+- func: dot_out(Tensor result, Tensor self, Tensor tensor) -> Tensor
+  variants: function
+
 - func: einsum(std::string equation, TensorList tensors) -> Tensor
   variants: function
 
@@ -468,6 +471,9 @@
 
 - func: matmul(Tensor self, Tensor other) -> Tensor
 
+- func: matmul_out(Tensor result, Tensor self, Tensor other) -> Tensor
+  variants: function
+
 - func: max_values(Tensor self, int64_t dim, bool keepdim=false) -> Tensor
 
 - func: max_pool1d(Tensor self, IntList[1] kernel_size, IntList[1] stride={}, IntList[1] padding=0, IntList[1] dilation=1, bool ceil_mode=false) -> (Tensor, Tensor)

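The variants: function lines mean the generated out variants are exposed only as free functions, not as Tensor methods (the "Make sure out variant doesn't have a method" item in the commit message). Roughly, at the Python level, assuming the usual binding generation (an illustration, not code from the commit):

import torch

a, b = torch.randn(2, 3), torch.randn(3, 4)
out = torch.zeros(2, 4)

torch.matmul(a, b, out=out)           # the function form accepts out=
assert not hasattr(a, 'matmul_out')   # no Tensor.matmul_out method is generated
assert not hasattr(a, 'dot_out')      # likewise, no Tensor.dot_out method
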
test/test_torch.py
Lines changed: 10 additions & 0 deletions

@@ -92,6 +92,9 @@ def test_dot(self):
             for i, j in zip(v1, v2):
                 res2 += i * j
             self.assertEqual(res1, res2)
+            out = torch.randn(()).type(tname)
+            torch.dot(v1, v2, out=out)
+            self.assertEqual(res1, out)
 
         # Test 0-strided
         for tname, _prec in types.items():
@@ -102,6 +105,9 @@ def test_dot(self):
             for i, j in zip(v1, v2):
                 res2 += i * j
             self.assertEqual(res1, res2)
+            out = torch.randn(()).type(tname)
+            torch.dot(v1, v2, out=out)
+            self.assertEqual(res1, out)
 
     def test_ger(self):
         types = {
@@ -2604,6 +2610,10 @@ def maybe_squeeze_result(l, r, result):
             # test torch.matmul function as well
             torch_result = maybe_squeeze_result(l, r, torch.matmul(l, r))
             self.assertEqual(truth, torch_result)
+            # test torch.matmul with out
+            out = torch.zeros_like(torch_result)
+            torch.matmul(l, r, out=out)
+            self.assertEqual(truth, maybe_squeeze_result(l, r, out))
 
             # compare to bmm
             bmm_result = (torch.bmm(lhs_expanded.contiguous().view(-1, *lhs_mat_dims),
