Skip to content

Commit 9fe8a1b

Browse files
committed
Fix addr for non-float data types
1 parent 163fd96 commit 9fe8a1b

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

test/test_mps.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6371,6 +6371,7 @@ def test_addr(self, device="mps", dtype=torch.float32):
63716371
m1 = torch.randn(10, device=device).to(dtype)
63726372
m2 = torch.randn(25, device=device).to(dtype)
63736373
self._test_addr(torch.addr, M, m1, m2, beta=0)
6374+
63746375
class TestGatherScatter(TestCase):
63756376
def test_slicing_with_step(self):
63766377
# Slicing with step
@@ -8656,7 +8657,7 @@ class TestConsistency(TestCase):
86568657
'addcmul': ['f32', 'i16', 'i32', 'i64', 'u8'],
86578658
'addmm': ['f32'],
86588659
'addmv': ['f32'],
8659-
'addr': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
8660+
'addr': ['f32'],
86608661
'all': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
86618662
'allclose': ['f16', 'f32'],
86628663
'any': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],

0 commit comments

Comments
 (0)