Commit b9e68e0
Fix the bug in THCTensor_(baddbmm) and ATen's addmm_cuda for strided views input (#42425)
Summary:
Fixes #42418.
The problem was that the non-contiguous batched matrices were passed to `gemmStridedBatched`.
The following code fails on master and works with the proposed patch:
```python
import torch
x = torch.tensor([[1., 2, 3], [4., 5, 6]], device='cuda:0')
c = torch.as_strided(x, size=[2, 2, 2], stride=[3, 1, 1])
torch.einsum('...ab,...bc->...ac', c, c)
```
Pull Request resolved: #42425
Reviewed By: glaringlee
Differential Revision: D22925266
Pulled By: ngimel
fbshipit-source-id: a72d56d26c7381b7793a047d76bcc5bd45a9602c1 parent 317b9d3 commit b9e68e0
File tree
3 files changed
+32
-8
lines changed- aten/src
- ATen/native/cuda
- THC/generic
- test
3 files changed
+32
-8
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
34 | 34 | | |
35 | 35 | | |
36 | 36 | | |
| 37 | + | |
37 | 38 | | |
38 | | - | |
| 39 | + | |
39 | 40 | | |
40 | 41 | | |
41 | | - | |
| 42 | + | |
42 | 43 | | |
43 | 44 | | |
44 | 45 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
51 | 51 | | |
52 | 52 | | |
53 | 53 | | |
54 | | - | |
| 54 | + | |
| 55 | + | |
55 | 56 | | |
56 | 57 | | |
57 | 58 | | |
58 | 59 | | |
59 | 60 | | |
60 | | - | |
| 61 | + | |
| 62 | + | |
61 | 63 | | |
62 | 64 | | |
63 | 65 | | |
| |||
80 | 82 | | |
81 | 83 | | |
82 | 84 | | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
83 | 89 | | |
84 | | - | |
| 90 | + | |
85 | 91 | | |
86 | 92 | | |
87 | 93 | | |
88 | 94 | | |
89 | 95 | | |
90 | 96 | | |
91 | | - | |
| 97 | + | |
92 | 98 | | |
93 | 99 | | |
94 | 100 | | |
| |||
107 | 113 | | |
108 | 114 | | |
109 | 115 | | |
110 | | - | |
| 116 | + | |
111 | 117 | | |
112 | 118 | | |
113 | 119 | | |
114 | 120 | | |
115 | 121 | | |
116 | 122 | | |
117 | | - | |
| 123 | + | |
118 | 124 | | |
119 | 125 | | |
120 | 126 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
17083 | 17083 | | |
17084 | 17084 | | |
17085 | 17085 | | |
| 17086 | + | |
| 17087 | + | |
| 17088 | + | |
| 17089 | + | |
| 17090 | + | |
| 17091 | + | |
| 17092 | + | |
| 17093 | + | |
| 17094 | + | |
| 17095 | + | |
| 17096 | + | |
| 17097 | + | |
| 17098 | + | |
| 17099 | + | |
| 17100 | + | |
| 17101 | + | |
| 17102 | + | |
17086 | 17103 | | |
17087 | 17104 | | |
17088 | 17105 | | |
| |||
0 commit comments