
Commit 0e44db8

t-vi authored and facebook-github-bot committed

Add check for backend of arguments to bmm cpu (#12434)

Summary: Fixes #12406. Thank you, jcjohnson, for reporting.

Pull Request resolved: #12434
Differential Revision: D10235799
Pulled By: soumith
fbshipit-source-id: 44ee35010bac3791901f604095f5b4bc66b0e7f8

1 parent db8d01b commit 0e44db8

File tree

2 files changed, +5 −0 lines changed


aten/src/ATen/native/LinearAlgebra.cpp

Lines changed: 1 addition & 0 deletions

@@ -283,6 +283,7 @@ static inline Tensor& bmm_out_or_baddbmm_(Tensor& self_or_result, const Tensor&
   TensorArg self_arg(self_or_result, is_bmm_out ? "self" : "result", 0);
   TensorArg b1_arg(batch1, "batch1", 1);
   TensorArg b2_arg(batch2, "batch2", 2);
+  checkBackend(c, {self_or_result, batch1, batch2}, Backend::CPU);
   checkDim(c, b1_arg, 3);
   checkDim(c, b2_arg, 3);
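The added checkBackend call makes the CPU bmm kernel reject any argument that lives on a different backend before the dimension checks run. A minimal pure-Python sketch of that idea (toy Tensor and check_backend names are hypothetical stand-ins for the ATen helpers, not the real API):

```python
class Tensor:
    """Toy stand-in for an ATen tensor, carrying only its backend tag."""
    def __init__(self, backend):
        self.backend = backend


def check_backend(fn_name, tensors, expected):
    # Mirrors the intent of ATen's checkBackend: every argument must
    # live on the expected backend, otherwise raise RuntimeError.
    for t in tensors:
        if t.backend != expected:
            raise RuntimeError(
                f"{fn_name}: expected backend {expected}, got {t.backend}"
            )


cpu_a, cpu_b, gpu = Tensor("CPU"), Tensor("CPU"), Tensor("CUDA")

check_backend("bmm", [cpu_a, cpu_b], "CPU")  # all CPU: passes silently

try:
    check_backend("bmm", [cpu_a, gpu], "CPU")  # mixed backends: rejected
except RuntimeError as e:
    print("rejected:", e)
```

Raising early keeps the error explicit instead of letting a CPU kernel dereference CUDA storage, which is the failure mode reported in #12406.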

test/test_torch.py

Lines changed: 4 additions & 0 deletions

@@ -1600,6 +1600,10 @@ def test_bmm(self):
         for i in range(num_batches):
             r = torch.mm(b1[i], b2[i])
             self.assertEqual(r, res[i])
+        if torch.cuda.is_available():
+            # check that mixed arguments are rejected
+            self.assertRaises(RuntimeError, lambda: torch.bmm(b1, b2.cuda()))
+            self.assertRaises(RuntimeError, lambda: torch.bmm(b1.cuda(), b2))

     def test_addbmm(self):
         # num_batches = 10

0 commit comments