Commit b9e68e0

IvanYashchuk authored and facebook-github-bot committed
Fix the bug in THCTensor_(baddbmm) and ATen's addmm_cuda for strided views input (#42425)
Summary: Fixes #42418. The problem was that non-contiguous batched matrices were passed to `gemmStridedBatched`. The following code fails on master and works with the proposed patch:

```python
import torch

x = torch.tensor([[1., 2, 3], [4., 5, 6]], device='cuda:0')
c = torch.as_strided(x, size=[2, 2, 2], stride=[3, 1, 1])
torch.einsum('...ab,...bc->...ac', c, c)
```

Pull Request resolved: #42425

Reviewed By: glaringlee

Differential Revision: D22925266

Pulled By: ngimel

fbshipit-source-id: a72d56d26c7381b7793a047d76bcc5bd45a9602c
1 parent 317b9d3 commit b9e68e0
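The root cause is easiest to see on the strides themselves. Below is a minimal illustration (not code from the patch) of why the old stride check accepted this overlapping view while the new one rejects it; it runs on CPU since only sizes and strides matter here:

```python
import torch

# The repro's view: each 2x2 matrix is cut from x with stride [1, 1],
# so its two "rows" overlap in memory.
x = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
c = torch.as_strided(x, size=[2, 2, 2], stride=[3, 1, 1])

mat = c[0]  # sizes (2, 2), strides (1, 1)
(rows, cols), (s0, s1) = mat.size(), mat.stride()

# Old check: any nonzero companion stride counted as a valid cuBLAS
# layout, so this view went straight to gemmStridedBatched.
old_ok = (s0 == 1 and s1 != 0) or (s1 == 1 and s0 != 0)        # True

# New check: the leading dimension must be at least the length of the
# unit-stride dimension, otherwise consecutive columns/rows alias.
new_ok = (s0 == 1 and s1 >= max(1, rows)) or \
         (s1 == 1 and s0 >= max(1, cols))                      # False

print(old_ok, new_ok)  # True False -> the patched code copies to contiguous
```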

File tree: 3 files changed, +32 −8 lines changed


aten/src/ATen/native/cuda/LinearAlgebra.cu

Lines changed: 3 additions & 2 deletions

```diff
@@ -34,11 +34,12 @@ Tensor bmm_cuda(const Tensor& self, const Tensor& mat2) {
 Tensor prepare_matrix_for_cublas(Tensor& tensor, bool& transpose_tensor) {
   Tensor tensor_;
   IntArrayRef tensor_strides = tensor.strides();
+  IntArrayRef tensor_sizes = tensor.sizes();
 
-  if ((tensor_strides[0] == 1) && (tensor_strides[1] != 0)) {
+  if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max<int64_t>(1, tensor_sizes[0]))) {
     tensor_ = tensor;
     transpose_tensor = false;
-  } else if ((tensor_strides[1] == 1) && (tensor_strides[0] != 0)) {
+  } else if ((tensor_strides[1] == 1) && (tensor_strides[0] >= std::max<int64_t>(1, tensor_sizes[1]))) {
     tensor_ = tensor;
     transpose_tensor = true;
   } else {
```
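In cuBLAS terms the patched condition says: a matrix may be passed as-is only if one dimension has stride 1 and the other stride is a valid leading dimension, i.e. at least the length of the unit-stride dimension. A hedged Python mirror of that check follows; `prepare_matrix_for_cublas_py` is a hypothetical name, and the copy branch is abbreviated (the hunk above does not show what the real `else` branch does, so treat that detail as an assumption):

```python
import torch

def prepare_matrix_for_cublas_py(tensor):
    # Python mirror of the patched C++ check, for illustration only.
    (r, c), (s0, s1) = tensor.size(), tensor.stride()
    if s0 == 1 and s1 >= max(1, r):
        return tensor, False   # usable column-major, ld = s1
    if s1 == 1 and s0 >= max(1, c):
        return tensor, True    # usable row-major (pass transposed), ld = s0
    # Neither layout is safe (e.g. overlapping strides): copy first.
    # Assumption: the real else branch clones to a contiguous tensor.
    return tensor.contiguous(), True

x = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
overlapping = torch.as_strided(x, size=[2, 2], stride=[1, 1])
fixed, transpose = prepare_matrix_for_cublas_py(overlapping)
assert fixed.is_contiguous()   # the fix routes this view through a copy
```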

aten/src/THC/generic/THCTensorMathBlas.cu

Lines changed: 12 additions & 6 deletions

```diff
@@ -51,13 +51,15 @@ void THCTensor_(baddbmm)(THCState *state, THCTensor *result, THCTensor *t,
   char transpose_batch1, transpose_batch2;
   int64_t lda, ldb, ldc;
   THCTensor *result_, *batch1_, *batch2_;
-  if (result->stride(1) == 1)
+  if (result->stride(1) == 1 &&
+     (result->size(2) == 1 || result->stride(2) >= std::max<int64_t>(1, result->size(1))))
   {
     transpose_result = false;
     result_ = result;
     ldc = result_->stride(2);
   }
-  else if (result->stride(2) == 1)
+  else if (result->stride(2) == 1 &&
+          (result->size(1) == 1 || result->stride(1) >= std::max<int64_t>(1, result->size(2))))
   {
     transpose_result = true;
@@ -80,15 +82,19 @@ void THCTensor_(baddbmm)(THCState *state, THCTensor *result, THCTensor *t,
     ldc = result_->stride(2);
   }
 
+  const int64_t m = result->size(transpose_result ? 2 : 1);
+  const int64_t n = result->size(transpose_result ? 1 : 2);
+  const int64_t k = batch1->size(transpose_result ? 1 : 2);
+
   if (batch1->stride(transpose_result ? 2 : 1) == 1 &&
-      batch1->stride(transpose_result ? 1 : 2) != 0)
+      batch1->stride(transpose_result ? 1 : 2) >= std::max<int64_t>(1, m))
   {
     transpose_batch1 = 'n';
     batch1_ = batch1;
     lda = batch1_->stride(transpose_result ? 1 : 2);
   }
   else if (batch1->stride(transpose_result ? 1 : 2) == 1 &&
-           batch1->stride(transpose_result ? 2 : 1) != 0)
+           batch1->stride(transpose_result ? 2 : 1) >= std::max<int64_t>(1, k))
   {
     transpose_batch1 = 't';
     batch1_ = batch1;
@@ -107,14 +113,14 @@ void THCTensor_(baddbmm)(THCState *state, THCTensor *result, THCTensor *t,
   }
 
   if (batch2->stride(transpose_result ? 2 : 1) == 1 &&
-      batch2->stride(transpose_result ? 1 : 2) != 0)
+      batch2->stride(transpose_result ? 1 : 2) >= std::max<int64_t>(1, k))
   {
     transpose_batch2 = 'n';
     batch2_ = batch2;
     ldb = batch2_->stride(transpose_result ? 1 : 2);
   }
   else if (batch2->stride(transpose_result ? 1 : 2) == 1 &&
-           batch2->stride(transpose_result ? 2 : 1) != 0)
+           batch2->stride(transpose_result ? 2 : 1) >= std::max<int64_t>(1, n))
   {
     transpose_batch2 = 't';
     batch2_ = batch2;
```
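The same rule applies per operand of the batched gemm: an `'n'` operand's leading dimension must span all m rows it claims to hold, a `'t'` operand's all k. Here is a sketch of the batch1 branch in Python, with hypothetical names, assuming the fall-through case copies the batch to contiguous as the surrounding THC code does:

```python
import torch

def pick_batch1_layout(batch1, transpose_result, m, k):
    # Mirrors the patched checks. Which dim holds the m rows vs. the k
    # columns swaps when the result itself is handled transposed.
    row_dim = 2 if transpose_result else 1
    col_dim = 1 if transpose_result else 2
    if batch1.stride(row_dim) == 1 and batch1.stride(col_dim) >= max(1, m):
        return 'n', batch1.stride(col_dim)      # usable as-is, lda
    if batch1.stride(col_dim) == 1 and batch1.stride(row_dim) >= max(1, k):
        return 't', batch1.stride(row_dim)      # usable transposed, lda
    return None, None  # neither layout is safe; the C code copies instead

x = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
overlapping = torch.as_strided(x, size=[2, 2, 2], stride=[3, 1, 1])
print(pick_batch1_layout(overlapping, False, m=2, k=2))  # (None, None)
```

Under the old `!= 0` test, the overlapping view above would have been classified `'n'` with lda = 1, which is what corrupted the `gemmStridedBatched` result.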

test/test_torch.py

Lines changed: 17 additions & 0 deletions

```diff
@@ -17083,6 +17083,23 @@ def genf_float(x, y):
 
         _test_mm(n, m, p, dtype, genf)
 
+    @onlyOnCPUAndCUDA
+    @dtypes(torch.float32, torch.float64)
+    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
+    def test_strided_mm_bmm(self, device, dtype):
+        # Tests strided view case with stride smaller than corresponding dimension size
+        x = torch.tensor([[1., 2., 3.], [4., 5., 6.]], dtype=dtype, device=device)
+        new_shape = [2, 2, 2]
+        new_stride = [3, 1, 1]
+        sx = torch.as_strided(x, size=new_shape, stride=new_stride)
+
+        torch_fn = lambda x: torch.bmm(x, x)  # noqa: E731
+        np_fn = lambda x: np.matmul(x, x)  # noqa: E731
+        self.compare_with_numpy(torch_fn, np_fn, sx)
+
+        torch_fn = lambda x: torch.mm(x, x)  # noqa: E731
+        self.compare_with_numpy(torch_fn, np_fn, sx[0])
+
     @onlyCPU
     @dtypes(torch.float)
     def test_bmm(self, device, dtype):
```
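For reference, a standalone version of what the new test asserts, runnable outside the test harness (shown on CPU for portability, though the regression itself was CUDA-specific):

```python
import numpy as np
import torch

x = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
sx = torch.as_strided(x, size=[2, 2, 2], stride=[3, 1, 1])

# bmm on the overlapping batched view must match NumPy's matmul...
assert np.allclose(torch.bmm(sx, sx).numpy(),
                   np.matmul(sx.numpy(), sx.numpy()))
# ...and mm on a single overlapping matrix must as well.
assert np.allclose(torch.mm(sx[0], sx[0]).numpy(),
                   np.matmul(sx[0].numpy(), sx[0].numpy()))
```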
