48 changes: 33 additions & 15 deletions aten/src/ATen/native/LinearAlgebra.cpp
@@ -172,15 +172,17 @@ Tensor& addr_out(Tensor &result, const Tensor& self, const Tensor& vec1, const T
}

Tensor dot(const Tensor& self, const Tensor& tensor) {
if (self.dim() != 1) {
AT_ERROR("Expected argument self to have 1 dimension, but has %d", self.dim());
}
if (tensor.dim() != 1) {
AT_ERROR("Expected argument tensor to have 1 dimension, but has %d", tensor.dim());
}
check_1d(self, "self", "dot");
check_1d(tensor, "tensor", "dot");
return self._dot(tensor);
}

Tensor& dot_out(Tensor& result, const Tensor& self, const Tensor& tensor) {
result.resize_({});
// dispatching through type ensures we don't allow mismatched types.
return self.type().fill_(result, self.dot(tensor));
}
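At the Python level the new dot_out is reached through the usual out= keyword, which is exactly what the test_torch.py changes below exercise. A minimal sketch, assuming 1-D float inputs and a 0-dim result tensor:

```python
import torch

v1 = torch.randn(100)
v2 = torch.randn(100)
out = torch.randn(())         # 0-dim tensor; result.resize_({}) above keeps it scalar-shaped
torch.dot(v1, v2, out=out)    # writes the dot product into `out`
# Passing a non-1-D tensor would raise: the check_1d calls above reject it.
```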

/*
Matrix product of two Tensors.
The behavior depends on the dimensionality of the Tensors as follows:
@@ -200,18 +202,21 @@ The behavior depends on the dimensionality of the Tensors as follows:
must be broadcastable). For example, if tensor1 is a (j x 1 x n x m) Tensor
and tensor2 is a (k x m x p) Tensor, the returned tensor will be a (j x k x n x p) Tensor.
*/
Tensor matmul(const Tensor & tensor1, const Tensor & tensor2) {
Tensor matmul(at::optional<Tensor> out_opt, const Tensor& tensor1, const Tensor& tensor2) {
auto dim_tensor1 = tensor1.dim();
auto dim_tensor2 = tensor2.dim();
auto has_out = out_opt.has_value();
Tensor out = out_opt.value_or(Tensor());

if (dim_tensor1 == 1 && dim_tensor2 == 1) {
return tensor1.dot(tensor2);
return has_out ? at::native::dot_out(out, tensor1, tensor2) : tensor1.dot(tensor2);
} else if (dim_tensor1 == 2 && dim_tensor2 == 1) {
return tensor1.mv(tensor2);
return has_out ? at::native::mv_out(out, tensor1, tensor2) : tensor1.mv(tensor2);
} else if (dim_tensor1 == 1 && dim_tensor2 == 2) {
return tensor1.unsqueeze(0).mm(tensor2).squeeze_(0);
return has_out ? at::native::mm_out(out, tensor1.unsqueeze(0), tensor2).squeeze_(0)
: tensor1.unsqueeze(0).mm(tensor2).squeeze_(0);
} else if (dim_tensor1 == 2 && dim_tensor2 == 2) {
return tensor1.mm(tensor2);
return has_out ? at::native::mm_out(out, tensor1, tensor2) : tensor1.mm(tensor2);
} else if (dim_tensor1 >= 3 && (dim_tensor2 == 1 || dim_tensor2 == 2)) {
// optimization: use mm instead of bmm by folding tensor1's batch into
// its leading matrix dimension.
@@ -227,7 +232,9 @@ Tensor matmul(const Tensor & tensor1, const Tensor & tensor2) {

// fold the batch into the first dimension
Tensor t1 = tensor1.contiguous().view({-1, size1[size1.size() - 1]});
return at::_unsafe_view(t1.mm(t2), output_size);
Tensor output = has_out ? at::_unsafe_view(at::mm_out(out, t1, t2), output_size)
: at::_unsafe_view(t1.mm(t2), output_size);
return has_out ? out.set_(output) : output;
} else if ((dim_tensor1 >= 1 && dim_tensor2 >= 1) && (dim_tensor1 >= 3 || dim_tensor2 >= 3)) {
// We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list);
// we track m1 vs m2 separately even though they must match for nicer error messages
@@ -260,8 +267,6 @@ Tensor matmul(const Tensor & tensor1, const Tensor & tensor2) {
Tensor tensor1_expanded = tensor1.expand(tensor1_expand_size).contiguous().view(tensor1_bmm_view);
Tensor tensor2_expanded = tensor2.expand(tensor2_expand_size).contiguous().view(tensor2_bmm_view);

Tensor output = tensor1_expanded.bmm(tensor2_expanded);

// reshape batches back into result
std::vector<int64_t> output_shape(expand_batch_portion);
if (dim_tensor1 > 1) {
@@ -270,13 +275,26 @@ Tensor matmul(const Tensor & tensor1, const Tensor & tensor2) {
if (dim_tensor2 > 1) {
output_shape.push_back(p);
}
return at::_unsafe_view(output, output_shape);

Tensor output = has_out ? at::_unsafe_view(at::bmm_out(out, tensor1_expanded, tensor2_expanded), output_shape)
: at::_unsafe_view(tensor1_expanded.bmm(tensor2_expanded), output_shape);

return has_out ? out.set_(output) : output;
}

AT_ERROR("both arguments to matmul need to be at least 1D, but they are %dD and %dD",
dim_tensor1, dim_tensor2);

}
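For orientation, the dim-based dispatch above can be mirrored at the Python level. This is a simplified sketch under assumptions: matmul_reference is a hypothetical helper, and it ignores the batch-folding and _unsafe_view details of the C++ implementation; it only shows which primitive each shape combination routes to and how the optional out threads through.

```python
import torch

def matmul_reference(t1, t2, out=None):
    # 1-D · 1-D -> dot product (0-dim result)
    if t1.dim() == 1 and t2.dim() == 1:
        return torch.dot(t1, t2, out=out)
    # 2-D · 1-D -> matrix-vector product
    if t1.dim() == 2 and t2.dim() == 1:
        return torch.mv(t1, t2, out=out)
    # 1-D · 2-D -> prepend a 1, mm, then squeeze it back out
    if t1.dim() == 1 and t2.dim() == 2:
        return torch.mm(t1.unsqueeze(0), t2, out=out).squeeze_(0)
    # 2-D · 2-D -> plain matrix multiply
    if t1.dim() == 2 and t2.dim() == 2:
        return torch.mm(t1, t2, out=out)
    # >= 3-D inputs: the C++ code broadcasts batch dimensions and uses bmm
    # (or folds the batch into a single mm, as in the branch above)
    return torch.matmul(t1, t2, out=out)
```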

Tensor matmul(const Tensor & tensor1, const Tensor & tensor2) {
return at::native::matmul(at::nullopt, tensor1, tensor2);
}

Tensor& matmul_out(Tensor &result, const Tensor & tensor1, const Tensor & tensor2) {
at::native::matmul(at::optional<Tensor>(result), tensor1, tensor2);
return result;
}

}
}
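The folding optimization in the >= 3-D by 1-D/2-D branch can be pictured with concrete shapes (the sizes below are assumptions for illustration): collapse tensor1's batch dimensions into its row dimension, perform a single mm, then view the result back into the batched shape.

```python
import torch

t1 = torch.randn(4, 3, 5, 6)   # batch dims (4, 3), matrix part (5, 6)
t2 = torch.randn(6, 2)

folded = t1.contiguous().view(-1, t1.size(-1))   # (4*3*5, 6): batch folded into rows
result = folded.mm(t2).view(4, 3, 5, 2)          # one mm, then reshape back

assert torch.allclose(result, torch.matmul(t1, t2))
```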
6 changes: 6 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -282,6 +282,9 @@

- func: dot(Tensor self, Tensor tensor) -> Tensor

- func: dot_out(Tensor result, Tensor self, Tensor tensor) -> Tensor
variants: function

- func: einsum(std::string equation, TensorList tensors) -> Tensor
variants: function

@@ -469,6 +472,9 @@

- func: matmul(Tensor self, Tensor other) -> Tensor

- func: matmul_out(Tensor result, Tensor self, Tensor other) -> Tensor
variants: function

- func: max_values(Tensor self, int64_t dim, bool keepdim=false) -> Tensor

- func: max_pool1d(Tensor self, IntList[1] kernel_size, IntList[1] stride={}, IntList[1] padding=0, IntList[1] dilation=1, bool ceil_mode=false) -> (Tensor, Tensor)
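With variants: function, the out variants declared here are generated as free functions rather than Tensor methods; through the existing Python bindings they surface as the out= keyword of the corresponding torch functions. A hedged illustration (the binding plumbing itself is outside this diff):

```python
import torch

a = torch.randn(3, 4)
b = torch.randn(4, 5)
res = torch.empty(3, 5)
torch.matmul(a, b, out=res)   # routed to the matmul_out declared above
```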
10 changes: 10 additions & 0 deletions test/test_torch.py
@@ -92,6 +92,9 @@ def test_dot(self):
for i, j in zip(v1, v2):
res2 += i * j
self.assertEqual(res1, res2)
out = torch.randn(()).type(tname)
torch.dot(v1, v2, out=out)
self.assertEqual(res1, out)

# Test 0-strided
for tname, _prec in types.items():
@@ -102,6 +105,9 @@ def test_dot(self):
for i, j in zip(v1, v2):
res2 += i * j
self.assertEqual(res1, res2)
out = torch.randn(()).type(tname)
torch.dot(v1, v2, out=out)
self.assertEqual(res1, out)

def test_ger(self):
types = {
@@ -2525,6 +2531,10 @@ def maybe_squeeze_result(l, r, result):
# test torch.matmul function as well
torch_result = maybe_squeeze_result(l, r, torch.matmul(l, r))
self.assertEqual(truth, torch_result)
# test torch.matmul with out
out = torch.zeros_like(torch_result)
torch.matmul(l, r, out=out)
self.assertEqual(truth, maybe_squeeze_result(l, r, out))

# compare to bmm
bmm_result = (torch.bmm(lhs_expanded.contiguous().view(-1, *lhs_mat_dims),