@@ -14414,6 +14414,67 @@ def test_bernoulli_mem_overlap(self, device):
1441414414 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
1441514415 torch.bernoulli(torch.rand_like(x), out=x)
1441614416
14417+ def test_index_put_mem_overlap(self, device):
14418+ x = torch.rand((1,), device=device).expand((6,))
14419+ y = torch.rand((6,), device=device)
14420+ ind = torch.tensor([0, 2, 3], device=device)
14421+ value = torch.rand((3,), device=device)
14422+ with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
14423+ x.index_put_((ind,), value)
14424+
14425+ def test_masked_fill_mem_overlap(self, device):
14426+ x = torch.rand((1,), device=device).expand((6,))
14427+ mask = torch.tensor([True, False, True, True, False, False], device=device)
14428+ with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
14429+ x.masked_fill_(mask, 0.)
14430+
14431+ fill_val = torch.tensor(0., device=device)
14432+ with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
14433+ x.masked_fill_(mask, fill_val)
14434+
14435+ def test_masked_select_mem_overlap(self, device):
14436+ x = torch.rand((1,), device=device).expand((3,))
14437+ y = torch.rand((6,), device=device)
14438+ mask = torch.tensor([True, False, True, True, False, False], device=device)
14439+ with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
14440+ torch.masked_select(y, mask, out=x)
14441+
14442+ def test_masked_scatter_mem_overlap(self, device):
14443+ x = torch.rand((1,), device=device).expand((6,))
14444+ src = torch.rand((3,), device=device)
14445+ mask = torch.tensor([True, False, True, True, False, False], device=device)
14446+
14447+ with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
14448+ x.masked_scatter_(mask, src)
14449+
14450+ def test_index_select_mem_overlap(self, device):
14451+ x = torch.rand((1, 6), device=device).expand((2, 6))
14452+ y = torch.rand((3, 6), device=device)
14453+ ind = torch.tensor([0, 1], dtype=torch.int64, device=device)
14454+ with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
14455+ torch.index_select(y, 1, ind, out=x)
14456+
14457+ def test_cat_mem_overlap(self, device):
14458+ x = torch.rand((1, 3), device=device).expand((6, 3))
14459+ y = torch.rand((3, 3), device=device)
14460+ with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
14461+ torch.cat([y, y], out=x)
14462+
14463+ def test_scatter_mem_overlap(self, device):
14464+ x = torch.rand((1,), device=device).expand((6,))
14465+ src = torch.rand((3,), device=device)
14466+ ind = torch.tensor([0, 2, 3], device=device, dtype=torch.int64)
14467+
14468+ with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
14469+ x.scatter_(0, ind, src)
14470+
14471+ def test_gather_mem_overlap(self, device):
14472+ x = torch.rand((1,), device=device).expand((3,))
14473+ src = torch.rand((6,), device=device)
14474+ ind = torch.tensor([0, 2, 3], device=device, dtype=torch.int64)
14475+ with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
14476+ torch.gather(src, 0, ind, out=x)
14477+
1441714478 def test_linlogspace_mem_overlap(self, device):
1441814479 x = torch.rand(1, device=device).expand(10)
1441914480 with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
0 commit comments